all: sync with master; upd chlog
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -270,7 +271,13 @@ type ServerConfig struct {
|
||||
UDPListenAddrs []*net.UDPAddr // UDP listen address
|
||||
TCPListenAddrs []*net.TCPAddr // TCP listen address
|
||||
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
|
||||
OnDNSRequest func(d *proxy.DNSContext)
|
||||
|
||||
// AddrProcConf defines the configuration for the client IP processor.
|
||||
// If nil, [client.EmptyAddrProc] is used.
|
||||
//
|
||||
// TODO(a.garipov): The use of [client.EmptyAddrProc] is a crutch for tests.
|
||||
// Remove that.
|
||||
AddrProcConf *client.DefaultAddrProcConfig
|
||||
|
||||
FilteringConfig
|
||||
TLSConfig
|
||||
@@ -298,9 +305,6 @@ type ServerConfig struct {
|
||||
// DNS64Prefixes is a slice of NAT64 prefixes to be used for DNS64.
|
||||
DNS64Prefixes []netip.Prefix
|
||||
|
||||
// ResolveClients signals if the RDNS should resolve clients' addresses.
|
||||
ResolveClients bool
|
||||
|
||||
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
|
||||
// locally-served networks should be resolved via private PTR resolvers.
|
||||
UsePrivateRDNS bool
|
||||
@@ -340,6 +344,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
BeforeRequestHandler: s.beforeRequestHandler,
|
||||
RequestHandler: s.handleDNSRequest,
|
||||
HTTPSServerName: aghhttp.UserAgent(),
|
||||
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
|
||||
MaxGoroutines: int(srvConf.MaxGoroutines),
|
||||
UseDNS64: srvConf.UseDNS64,
|
||||
|
||||
57
internal/dnsforward/dialcontext.go
Normal file
57
internal/dnsforward/dialcontext.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames.
|
||||
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
log.Debug("dnsforward: dialing %q for network %q", addr, network)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
// TODO(a.garipov): Consider making configurable.
|
||||
Timeout: time.Minute * 5,
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
addrs, err := s.Resolve(host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving %q: %w", host, err)
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: resolving %q: %v", host, addrs)
|
||||
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("no addresses for host %q", host)
|
||||
}
|
||||
|
||||
var dialErrs []error
|
||||
for _, a := range addrs {
|
||||
addr = net.JoinHostPort(a.String(), port)
|
||||
conn, err = dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
dialErrs = append(dialErrs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Use errors.Join in Go 1.20.
|
||||
return nil, errors.List(fmt.Sprintf("dialing %q", addr), dialErrs...)
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
@@ -99,8 +100,17 @@ type Server struct {
|
||||
// must be a valid domain name plus dots on each side.
|
||||
localDomainSuffix string
|
||||
|
||||
ipset ipsetCtx
|
||||
privateNets netutil.SubnetSet
|
||||
ipset ipsetCtx
|
||||
privateNets netutil.SubnetSet
|
||||
|
||||
// addrProc, if not nil, is used to process clients' IP addresses with rDNS,
|
||||
// WHOIS, etc.
|
||||
addrProc client.AddressProcessor
|
||||
|
||||
// localResolvers is a DNS proxy instance used to resolve PTR records for
|
||||
// addresses considered private as per the [privateNets].
|
||||
//
|
||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||
localResolvers *proxy.Proxy
|
||||
sysResolvers aghnet.SystemResolvers
|
||||
|
||||
@@ -170,6 +180,9 @@ const (
|
||||
|
||||
// NewServer creates a new instance of the dnsforward.Server
|
||||
// Note: this function must be called only once
|
||||
//
|
||||
// TODO(a.garipov): How many constructors and initializers does this thing have?
|
||||
// Refactor!
|
||||
func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
var localDomainSuffix string
|
||||
if p.LocalDomain == "" {
|
||||
@@ -257,14 +270,25 @@ func (s *Server) WriteDiskConfig(c *FilteringConfig) {
|
||||
c.UpstreamDNS = stringutil.CloneSlice(sc.UpstreamDNS)
|
||||
}
|
||||
|
||||
// RDNSSettings returns the copy of actual RDNS configuration.
|
||||
func (s *Server) RDNSSettings() (localPTRResolvers []string, resolveClients, resolvePTR bool) {
|
||||
// LocalPTRResolvers returns the current local PTR resolver configuration.
|
||||
func (s *Server) LocalPTRResolvers() (localPTRResolvers []string) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
return stringutil.CloneSlice(s.conf.LocalPTRResolvers),
|
||||
s.conf.ResolveClients,
|
||||
s.conf.UsePrivateRDNS
|
||||
return stringutil.CloneSlice(s.conf.LocalPTRResolvers)
|
||||
}
|
||||
|
||||
// AddrProcConfig returns the current address processing configuration. Only
|
||||
// fields c.UsePrivateRDNS, c.UseRDNS, and c.UseWHOIS are filled.
|
||||
func (s *Server) AddrProcConfig() (c *client.DefaultAddrProcConfig) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
return &client.DefaultAddrProcConfig{
|
||||
UsePrivateRDNS: s.conf.UsePrivateRDNS,
|
||||
UseRDNS: s.conf.AddrProcConf.UseRDNS,
|
||||
UseWHOIS: s.conf.AddrProcConf.UseWHOIS,
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve - get IP addresses by host name from an upstream server.
|
||||
@@ -292,17 +316,13 @@ const (
|
||||
var _ rdns.Exchanger = (*Server)(nil)
|
||||
|
||||
// Exchange implements the [rdns.Exchanger] interface for *Server.
|
||||
func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
|
||||
func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if !s.conf.ResolveClients {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reversing ip: %w", err)
|
||||
return "", 0, fmt.Errorf("reversing ip: %w", err)
|
||||
}
|
||||
|
||||
arpa = dns.Fqdn(arpa)
|
||||
@@ -318,16 +338,17 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
ctx := &proxy.DNSContext{
|
||||
|
||||
dctx := &proxy.DNSContext{
|
||||
Proto: "udp",
|
||||
Req: req,
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
var resolver *proxy.Proxy
|
||||
if s.isPrivateIP(ip) {
|
||||
if s.privateNets.Contains(ip.AsSlice()) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
return "", nil
|
||||
return "", 0, nil
|
||||
}
|
||||
|
||||
resolver = s.localResolvers
|
||||
@@ -336,53 +357,48 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
|
||||
resolver = s.internalProxy
|
||||
}
|
||||
|
||||
if err = resolver.Resolve(ctx); err != nil {
|
||||
return "", err
|
||||
if err = resolver.Resolve(dctx); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
return hostFromPTR(ctx.Res)
|
||||
return hostFromPTR(dctx.Res)
|
||||
}
|
||||
|
||||
// hostFromPTR returns domain name from the PTR response or error.
|
||||
func hostFromPTR(resp *dns.Msg) (host string, err error) {
|
||||
func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) {
|
||||
// Distinguish between NODATA response and a failed request.
|
||||
if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError {
|
||||
return "", fmt.Errorf(
|
||||
return "", 0, fmt.Errorf(
|
||||
"received %s response: %w",
|
||||
dns.RcodeToString[resp.Rcode],
|
||||
ErrRDNSFailed,
|
||||
)
|
||||
}
|
||||
|
||||
var ttlSec uint32
|
||||
|
||||
for _, ans := range resp.Answer {
|
||||
ptr, ok := ans.(*dns.PTR)
|
||||
if ok {
|
||||
return strings.TrimSuffix(ptr.Ptr, "."), nil
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if ptr.Hdr.Ttl > ttlSec {
|
||||
host = ptr.Ptr
|
||||
ttlSec = ptr.Hdr.Ttl
|
||||
}
|
||||
}
|
||||
|
||||
return "", ErrRDNSNoData
|
||||
}
|
||||
if host != "" {
|
||||
// NOTE: Don't use [aghnet.NormalizeDomain] to retain original letter
|
||||
// case.
|
||||
host = strings.TrimSuffix(host, ".")
|
||||
ttl = time.Duration(ttlSec) * time.Second
|
||||
|
||||
// isPrivateIP returns true if the ip is private.
|
||||
func (s *Server) isPrivateIP(ip netip.Addr) (ok bool) {
|
||||
return s.privateNets.Contains(ip.AsSlice())
|
||||
}
|
||||
|
||||
// ShouldResolveClient returns false if ip is a loopback address, or ip is
|
||||
// private and resolving of private addresses is disabled.
|
||||
func (s *Server) ShouldResolveClient(ip netip.Addr) (ok bool) {
|
||||
if ip.IsLoopback() {
|
||||
return false
|
||||
return host, ttl, nil
|
||||
}
|
||||
|
||||
isPrivate := s.isPrivateIP(ip)
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
return s.conf.ResolveClients &&
|
||||
(s.conf.UsePrivateRDNS || !isPrivate)
|
||||
return "", 0, ErrRDNSNoData
|
||||
}
|
||||
|
||||
// Start starts the DNS server.
|
||||
@@ -457,23 +473,27 @@ func (s *Server) filterOurDNSAddrs(addrs []string) (filtered []string, err error
|
||||
return stringutil.FilterOut(addrs, ourAddrsSet.Has), nil
|
||||
}
|
||||
|
||||
// setupResolvers initializes the resolvers for local addresses. For internal
|
||||
// use only.
|
||||
func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||
// setupLocalResolvers initializes the resolvers for local addresses. For
|
||||
// internal use only.
|
||||
func (s *Server) setupLocalResolvers() (err error) {
|
||||
bootstraps := s.conf.BootstrapDNS
|
||||
if len(localAddrs) == 0 {
|
||||
localAddrs = s.sysResolvers.Get()
|
||||
resolvers := s.conf.LocalPTRResolvers
|
||||
|
||||
if len(resolvers) == 0 {
|
||||
resolvers = s.sysResolvers.Get()
|
||||
bootstraps = nil
|
||||
} else {
|
||||
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
|
||||
}
|
||||
|
||||
localAddrs, err = s.filterOurDNSAddrs(localAddrs)
|
||||
resolvers, err = s.filterOurDNSAddrs(resolvers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", localAddrs)
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
|
||||
|
||||
upsConfig, err := s.prepareUpstreamConfig(localAddrs, nil, &upstream.Options{
|
||||
uc, err := s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
@@ -486,10 +506,17 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||
|
||||
s.localResolvers = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UpstreamConfig: upsConfig,
|
||||
UpstreamConfig: uc,
|
||||
},
|
||||
}
|
||||
|
||||
if s.conf.UsePrivateRDNS &&
|
||||
// Only set the upstream config if there are any upstreams. It's safe
|
||||
// to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||
len(uc.Upstreams)+len(uc.DomainReservedUpstreams)+len(uc.SpecifiedDomainUpstreams) > 0 {
|
||||
s.dnsProxy.PrivateRDNSUpstreamConfig = uc
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -539,25 +566,48 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
return fmt.Errorf("preparing access: %w", err)
|
||||
}
|
||||
|
||||
s.registerHandlers()
|
||||
|
||||
// Set the proxy here because [setupLocalResolvers] sets its values.
|
||||
//
|
||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||
err = s.setupResolvers(s.conf.LocalPTRResolvers)
|
||||
s.dnsProxy = &proxy.Proxy{Config: proxyConfig}
|
||||
|
||||
err = s.setupLocalResolvers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up resolvers: %w", err)
|
||||
}
|
||||
|
||||
if s.conf.UsePrivateRDNS {
|
||||
proxyConfig.PrivateRDNSUpstreamConfig = s.localResolvers.UpstreamConfig
|
||||
}
|
||||
|
||||
s.dnsProxy = &proxy.Proxy{Config: proxyConfig}
|
||||
|
||||
s.recDetector.clear()
|
||||
|
||||
s.setupAddrProc()
|
||||
|
||||
s.registerHandlers()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupAddrProc initializes the address processor. For internal use only.
|
||||
func (s *Server) setupAddrProc() {
|
||||
// TODO(a.garipov): This is a crutch for tests; remove.
|
||||
if s.conf.AddrProcConf == nil {
|
||||
s.conf.AddrProcConf = &client.DefaultAddrProcConfig{}
|
||||
}
|
||||
if s.conf.AddrProcConf.AddressUpdater == nil {
|
||||
s.addrProc = client.EmptyAddrProc{}
|
||||
} else {
|
||||
c := s.conf.AddrProcConf
|
||||
c.DialContext = s.DialContext
|
||||
c.PrivateSubnets = s.privateNets
|
||||
c.UsePrivateRDNS = s.conf.UsePrivateRDNS
|
||||
s.addrProc = client.NewDefaultAddrProc(s.conf.AddrProcConf)
|
||||
|
||||
// Clear the initial addresses to not resolve them again.
|
||||
//
|
||||
// TODO(a.garipov): Consider ways of removing this once more client
|
||||
// logic is moved to package client.
|
||||
c.InitialAddresses = nil
|
||||
}
|
||||
}
|
||||
|
||||
// validateBlockingMode returns an error if the blocking mode data aren't valid.
|
||||
func validateBlockingMode(mode BlockingMode, blockingIPv4, blockingIPv6 net.IP) (err error) {
|
||||
switch mode {
|
||||
@@ -696,6 +746,11 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
|
||||
// TODO(a.garipov): This whole piece of API is weird and needs to be remade.
|
||||
if conf == nil {
|
||||
conf = &s.conf
|
||||
} else {
|
||||
closeErr := s.addrProc.Close()
|
||||
if closeErr != nil {
|
||||
log.Error("dnsforward: closing address processor: %s", closeErr)
|
||||
}
|
||||
}
|
||||
|
||||
err = s.Prepare(conf)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -39,11 +40,29 @@ func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
//
|
||||
// TODO(a.garipov): Use more.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
// testQuestionTarget is the common question target for tests.
|
||||
//
|
||||
// TODO(a.garipov): Use more.
|
||||
const testQuestionTarget = "target.example"
|
||||
|
||||
const (
|
||||
tlsServerName = "testdns.adguard.com"
|
||||
testMessagesCount = 10
|
||||
)
|
||||
|
||||
// testClientAddr is the common net.Addr for tests.
|
||||
//
|
||||
// TODO(a.garipov): Use more.
|
||||
var testClientAddr net.Addr = &net.TCPAddr{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
Port: 12345,
|
||||
}
|
||||
|
||||
func startDeferStop(t *testing.T, s *Server) {
|
||||
t.Helper()
|
||||
|
||||
@@ -53,6 +72,13 @@ func startDeferStop(t *testing.T, s *Server) {
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
}
|
||||
|
||||
// packageUpstreamVariableMu is used to serialize access to the package-level
|
||||
// variables of package upstream.
|
||||
//
|
||||
// TODO(s.chzhen): Move these parameters to upstream options and remove this
|
||||
// crutch.
|
||||
var packageUpstreamVariableMu = &sync.Mutex{}
|
||||
|
||||
func createTestServer(
|
||||
t *testing.T,
|
||||
filterConf *filtering.Config,
|
||||
@@ -61,6 +87,9 @@ func createTestServer(
|
||||
) (s *Server) {
|
||||
t.Helper()
|
||||
|
||||
packageUpstreamVariableMu.Lock()
|
||||
defer packageUpstreamVariableMu.Unlock()
|
||||
|
||||
rules := `||nxdomain.example.org
|
||||
||NULL.example.org^
|
||||
127.0.0.1 host.example.org
|
||||
@@ -307,11 +336,9 @@ func TestServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_timeout(t *testing.T) {
|
||||
const timeout time.Duration = time.Second
|
||||
|
||||
t.Run("custom", func(t *testing.T) {
|
||||
srvConf := &ServerConfig{
|
||||
UpstreamTimeout: timeout,
|
||||
UpstreamTimeout: testTimeout,
|
||||
FilteringConfig: FilteringConfig{
|
||||
BlockingMode: BlockingModeDefault,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
@@ -324,7 +351,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
err = s.Prepare(srvConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, timeout, s.conf.UpstreamTimeout)
|
||||
assert.Equal(t, testTimeout, s.conf.UpstreamTimeout)
|
||||
})
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
@@ -441,7 +468,14 @@ func TestServerRace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
resolver := &aghtest.Resolver{
|
||||
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
ip4, ip6 := aghtest.HostToIPs(host)
|
||||
|
||||
return []net.IP{ip4, ip6}, nil
|
||||
},
|
||||
}
|
||||
|
||||
safeSearchConf := filtering.SafeSearchConfig{
|
||||
Enabled: true,
|
||||
Google: true,
|
||||
@@ -480,7 +514,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
client := &dns.Client{}
|
||||
|
||||
yandexIP := net.IP{213, 180, 193, 56}
|
||||
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
|
||||
|
||||
testCases := []struct {
|
||||
host string
|
||||
@@ -545,7 +579,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||
|
||||
// Send a DNS request without question.
|
||||
_, _, err := (&dns.Client{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
Timeout: testTimeout,
|
||||
}).Exchange(&req, addr)
|
||||
|
||||
assert.NoErrorf(t, err, "got a response to an invalid query")
|
||||
@@ -928,7 +962,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
Upstream: aghtest.NewBlockUpstream(hostname, true),
|
||||
})
|
||||
|
||||
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
||||
ans4, _ := aghtest.HostToIPs(hostname)
|
||||
|
||||
filterConf := &filtering.Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
@@ -1266,25 +1300,57 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// doubleTTL is a helper function that returns a clone of DNS PTR with appended
|
||||
// copy of first answer record with doubled TTL.
|
||||
func doubleTTL(msg *dns.Msg) (resp *dns.Msg) {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(msg.Answer) == 0 {
|
||||
return msg
|
||||
}
|
||||
|
||||
rec := msg.Answer[0]
|
||||
ptr, ok := rec.(*dns.PTR)
|
||||
if !ok {
|
||||
return msg
|
||||
}
|
||||
|
||||
clone := *ptr
|
||||
clone.Hdr.Ttl *= 2
|
||||
msg.Answer = append(msg.Answer, &clone)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func TestServer_Exchange(t *testing.T) {
|
||||
const (
|
||||
onesHost = "one.one.one.one"
|
||||
twosHost = "two.two.two.two"
|
||||
localDomainHost = "local.domain"
|
||||
|
||||
defaultTTL = time.Second * 60
|
||||
)
|
||||
|
||||
var (
|
||||
onesIP = netip.MustParseAddr("1.1.1.1")
|
||||
twosIP = netip.MustParseAddr("2.2.2.2")
|
||||
localIP = netip.MustParseAddr("192.168.1.1")
|
||||
)
|
||||
|
||||
revExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
|
||||
onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
extUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "external.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, revExtIPv4, onesHost),
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost),
|
||||
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
},
|
||||
@@ -1320,53 +1386,65 @@ func TestServer_Exchange(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
srv.conf.ResolveClients = true
|
||||
srv.conf.UsePrivateRDNS = true
|
||||
|
||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
req netip.Addr
|
||||
wantErr error
|
||||
locUpstream upstream.Upstream
|
||||
req netip.Addr
|
||||
name string
|
||||
want string
|
||||
wantTTL time.Duration
|
||||
}{{
|
||||
name: "external_good",
|
||||
want: onesHost,
|
||||
wantErr: nil,
|
||||
locUpstream: nil,
|
||||
req: onesIP,
|
||||
wantTTL: defaultTTL,
|
||||
}, {
|
||||
name: "local_good",
|
||||
want: localDomainHost,
|
||||
wantErr: nil,
|
||||
locUpstream: locUpstream,
|
||||
req: localIP,
|
||||
wantTTL: defaultTTL,
|
||||
}, {
|
||||
name: "upstream_error",
|
||||
want: "",
|
||||
wantErr: aghtest.ErrUpstream,
|
||||
locUpstream: errUpstream,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "empty_answer_error",
|
||||
want: "",
|
||||
wantErr: ErrRDNSNoData,
|
||||
locUpstream: locUpstream,
|
||||
req: netip.MustParseAddr("192.168.1.2"),
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "invalid_answer",
|
||||
want: "",
|
||||
wantErr: ErrRDNSNoData,
|
||||
locUpstream: nonPtrUpstream,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "refused",
|
||||
want: "",
|
||||
wantErr: ErrRDNSFailed,
|
||||
locUpstream: refusingUpstream,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "longest_ttl",
|
||||
want: twosHost,
|
||||
wantErr: nil,
|
||||
locUpstream: nil,
|
||||
req: twosIP,
|
||||
wantTTL: defaultTTL * 2,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -1380,73 +1458,20 @@ func TestServer_Exchange(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
host, eerr := srv.Exchange(tc.req)
|
||||
host, ttl, eerr := srv.Exchange(tc.req)
|
||||
|
||||
require.ErrorIs(t, eerr, tc.wantErr)
|
||||
assert.Equal(t, tc.want, host)
|
||||
assert.Equal(t, tc.wantTTL, ttl)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("resolving_disabled", func(t *testing.T) {
|
||||
srv.conf.UsePrivateRDNS = false
|
||||
|
||||
host, eerr := srv.Exchange(localIP)
|
||||
host, _, eerr := srv.Exchange(localIP)
|
||||
|
||||
require.NoError(t, eerr)
|
||||
assert.Empty(t, host)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_ShouldResolveClient(t *testing.T) {
|
||||
srv := &Server{
|
||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
ip netip.Addr
|
||||
want require.BoolAssertionFunc
|
||||
name string
|
||||
resolve bool
|
||||
usePrivate bool
|
||||
}{{
|
||||
name: "default",
|
||||
ip: netip.MustParseAddr("1.1.1.1"),
|
||||
want: require.True,
|
||||
resolve: true,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "no_rdns",
|
||||
ip: netip.MustParseAddr("1.1.1.1"),
|
||||
want: require.False,
|
||||
resolve: false,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "loopback",
|
||||
ip: netip.MustParseAddr("127.0.0.1"),
|
||||
want: require.False,
|
||||
resolve: true,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "private_resolve",
|
||||
ip: netip.MustParseAddr("192.168.0.1"),
|
||||
want: require.True,
|
||||
resolve: true,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "private_no_resolve",
|
||||
ip: netip.MustParseAddr("192.168.0.1"),
|
||||
want: require.False,
|
||||
resolve: true,
|
||||
usePrivate: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv.conf.ResolveClients = tc.resolve
|
||||
srv.conf.UsePrivateRDNS = tc.usePrivate
|
||||
|
||||
ok := srv.ShouldResolveClient(tc.ip)
|
||||
tc.want(t, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,10 +50,10 @@ func (s *Server) beforeRequestHandler(
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// getClientRequestFilteringSettings looks up client filtering settings using
|
||||
// the client's IP address and ID, if any, from dctx.
|
||||
func (s *Server) getClientRequestFilteringSettings(dctx *dnsContext) *filtering.Settings {
|
||||
setts := s.dnsFilter.Settings()
|
||||
// clientRequestFilteringSettings looks up client filtering settings using the
|
||||
// client's IP address and ID, if any, from dctx.
|
||||
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
|
||||
setts = s.dnsFilter.Settings()
|
||||
setts.ProtectionEnabled = dctx.protectionEnabled
|
||||
if s.conf.FilterHandler != nil {
|
||||
ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr)
|
||||
|
||||
@@ -124,7 +124,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
cacheMinTTL := s.conf.CacheMinTTL
|
||||
cacheMaxTTL := s.conf.CacheMaxTTL
|
||||
cacheOptimistic := s.conf.CacheOptimistic
|
||||
resolveClients := s.conf.ResolveClients
|
||||
resolveClients := s.conf.AddrProcConf.UseRDNS
|
||||
usePrivateRDNS := s.conf.UsePrivateRDNS
|
||||
localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers)
|
||||
|
||||
@@ -314,8 +314,6 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
setIfNotNil(&s.conf.ProtectionEnabled, dc.ProtectionEnabled)
|
||||
setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled)
|
||||
setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6)
|
||||
setIfNotNil(&s.conf.ResolveClients, dc.ResolveClients)
|
||||
setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS)
|
||||
|
||||
return s.setConfigRestartable(dc)
|
||||
}
|
||||
@@ -335,6 +333,9 @@ func setIfNotNil[T any](currentPtr, newPtr *T) (hasSet bool) {
|
||||
// setConfigRestartable sets the parameters which trigger a restart.
|
||||
// shouldRestart is true if the server should be restarted to apply changes.
|
||||
// s.serverLock is expected to be locked.
|
||||
//
|
||||
// TODO(a.garipov): Some of these could probably be updated without a restart.
|
||||
// Inspect and consider refactoring.
|
||||
func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
for _, hasSet := range []bool{
|
||||
setIfNotNil(&s.conf.UpstreamDNS, dc.Upstreams),
|
||||
@@ -347,6 +348,8 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
setIfNotNil(&s.conf.CacheMinTTL, dc.CacheMinTTL),
|
||||
setIfNotNil(&s.conf.CacheMaxTTL, dc.CacheMaxTTL),
|
||||
setIfNotNil(&s.conf.CacheOptimistic, dc.CacheOptimistic),
|
||||
setIfNotNil(&s.conf.AddrProcConf.UseRDNS, dc.ResolveClients),
|
||||
setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS),
|
||||
} {
|
||||
shouldRestart = shouldRestart || hasSet
|
||||
if shouldRestart {
|
||||
|
||||
@@ -30,6 +30,7 @@ type dnsContext struct {
|
||||
setts *filtering.Settings
|
||||
|
||||
result *filtering.Result
|
||||
|
||||
// origResp is the response received from upstream. It is set when the
|
||||
// response is modified by filters.
|
||||
origResp *dns.Msg
|
||||
@@ -48,13 +49,13 @@ type dnsContext struct {
|
||||
// clientID is the ClientID from DoH, DoQ, or DoT, if provided.
|
||||
clientID string
|
||||
|
||||
// startTime is the time at which the processing of the request has started.
|
||||
startTime time.Time
|
||||
|
||||
// origQuestion is the question received from the client. It is set
|
||||
// when the request is modified by rewrites.
|
||||
origQuestion dns.Question
|
||||
|
||||
// startTime is the time at which the processing of the request has started.
|
||||
startTime time.Time
|
||||
|
||||
// protectionEnabled shows if the filtering is enabled, and if the
|
||||
// server's DNS filter is ready.
|
||||
protectionEnabled bool
|
||||
@@ -160,6 +161,22 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// mozillaFQDN is the domain used to signal the Firefox browser to not use its
|
||||
// own DoH server.
|
||||
//
|
||||
// See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet.
|
||||
const mozillaFQDN = "use-application-dns.net."
|
||||
|
||||
// healthcheckFQDN is a reserved domain-name used for healthchecking.
|
||||
//
|
||||
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
|
||||
// grant requests to register test names in the normal way to any person or
|
||||
// entity, making domain names under the .test TLD free to use in internal
|
||||
// purposes.
|
||||
//
|
||||
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
|
||||
const healthcheckFQDN = "healthcheck.adguardhome.test."
|
||||
|
||||
// processInitial terminates the following processing for some requests if
|
||||
// needed and enriches dctx with some client-specific information.
|
||||
//
|
||||
@@ -169,6 +186,8 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
defer log.Debug("dnsforward: finished processing initial")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
s.processClientIP(pctx.Addr)
|
||||
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
if s.conf.AAAADisabled && qt == dns.TypeAAAA {
|
||||
@@ -177,28 +196,13 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
if s.conf.OnDNSRequest != nil {
|
||||
s.conf.OnDNSRequest(pctx)
|
||||
}
|
||||
|
||||
// Disable Mozilla DoH.
|
||||
//
|
||||
// See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet.
|
||||
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == "use-application-dns.net." {
|
||||
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
// Handle a reserved domain healthcheck.adguardhome.test.
|
||||
//
|
||||
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
|
||||
// grant requests to register test names in the normal way to any person or
|
||||
// entity, making domain names under test. TLD free to use in internal
|
||||
// purposes.
|
||||
//
|
||||
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
|
||||
if q.Name == "healthcheck.adguardhome.test." {
|
||||
if q.Name == healthcheckFQDN {
|
||||
// Generate a NODATA negative response to make nslookup exit with 0.
|
||||
pctx.Res = s.makeResponse(pctx.Req)
|
||||
|
||||
@@ -213,11 +217,28 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
// Get the client-specific filtering settings.
|
||||
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus()
|
||||
dctx.setts = s.getClientRequestFilteringSettings(dctx)
|
||||
dctx.setts = s.clientRequestFilteringSettings(dctx)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// processClientIP sends the client IP address to s.addrProc, if needed.
|
||||
func (s *Server) processClientIP(addr net.Addr) {
|
||||
clientIP := netutil.NetAddrToAddrPort(addr).Addr()
|
||||
if clientIP == (netip.Addr{}) {
|
||||
log.Info("dnsforward: warning: bad client addr %q", addr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Do not assign s.addrProc to a local variable to then use, since this lock
|
||||
// also serializes the closure of s.addrProc.
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
s.addrProc.Process(clientIP)
|
||||
}
|
||||
|
||||
func (s *Server) setTableHostToIP(t hostToIPTable) {
|
||||
s.tableHostToIPLock.Lock()
|
||||
defer s.tableHostToIPLock.Unlock()
|
||||
@@ -698,6 +719,18 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
|
||||
if s.conf.UsePrivateRDNS {
|
||||
s.recDetector.add(*pctx.Req)
|
||||
if err := s.localResolvers.Resolve(pctx); err != nil {
|
||||
// Generate the server failure if the private upstream configuration
|
||||
// is empty.
|
||||
//
|
||||
// TODO(e.burkov): Get rid of this crutch once the local resolvers
|
||||
// logic is moved to the dnsproxy completely.
|
||||
if errors.Is(err, upstream.ErrNoUpstreams) {
|
||||
pctx.Res = s.genServerFailure(pctx.Req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
dctx.err = err
|
||||
|
||||
return resultCodeError
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -22,6 +23,96 @@ const (
|
||||
ddrTestFQDN = ddrTestDomainName + "."
|
||||
)
|
||||
|
||||
func TestServer_ProcessInitial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
target string
|
||||
wantRCode rules.RCode
|
||||
qType rules.RRType
|
||||
aaaaDisabled bool
|
||||
wantRC resultCode
|
||||
}{{
|
||||
name: "success",
|
||||
target: testQuestionTarget,
|
||||
wantRCode: -1,
|
||||
qType: dns.TypeA,
|
||||
aaaaDisabled: false,
|
||||
wantRC: resultCodeSuccess,
|
||||
}, {
|
||||
name: "aaaa_disabled",
|
||||
target: testQuestionTarget,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
qType: dns.TypeAAAA,
|
||||
aaaaDisabled: true,
|
||||
wantRC: resultCodeFinish,
|
||||
}, {
|
||||
name: "aaaa_disabled_a",
|
||||
target: testQuestionTarget,
|
||||
wantRCode: -1,
|
||||
qType: dns.TypeA,
|
||||
aaaaDisabled: true,
|
||||
wantRC: resultCodeSuccess,
|
||||
}, {
|
||||
name: "mozilla_canary",
|
||||
target: mozillaFQDN,
|
||||
wantRCode: dns.RcodeNameError,
|
||||
qType: dns.TypeA,
|
||||
aaaaDisabled: false,
|
||||
wantRC: resultCodeFinish,
|
||||
}, {
|
||||
name: "adguardhome_healthcheck",
|
||||
target: healthcheckFQDN,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
qType: dns.TypeA,
|
||||
aaaaDisabled: false,
|
||||
wantRC: resultCodeFinish,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := ServerConfig{
|
||||
FilteringConfig: FilteringConfig{
|
||||
AAAADisabled: tc.aaaaDisabled,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}
|
||||
|
||||
s := createTestServer(t, &filtering.Config{}, c, nil)
|
||||
|
||||
var gotAddr netip.Addr
|
||||
s.addrProc = &aghtest.AddressProcessor{
|
||||
OnProcess: func(ip netip.Addr) { gotAddr = ip },
|
||||
OnClose: func() (err error) { panic("not implemented") },
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: createTestMessageWithType(tc.target, tc.qType),
|
||||
Addr: testClientAddr,
|
||||
RequestID: 1234,
|
||||
},
|
||||
}
|
||||
|
||||
gotRC := s.processInitial(dctx)
|
||||
assert.Equal(t, tc.wantRC, gotRC)
|
||||
assert.Equal(t, netutil.NetAddrToAddrPort(testClientAddr).Addr(), gotAddr)
|
||||
|
||||
if tc.wantRCode > 0 {
|
||||
gotResp := dctx.proxyCtx.Res
|
||||
require.NotNil(t, gotResp)
|
||||
|
||||
assert.Equal(t, tc.wantRCode, gotResp.Rcode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
dohSVCB := &dns.SVCB{
|
||||
Priority: 1,
|
||||
@@ -64,7 +155,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
}{{
|
||||
name: "pass_host",
|
||||
wantRes: resultCodeSuccess,
|
||||
host: "example.net.",
|
||||
host: testQuestionTarget,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8043,
|
||||
@@ -234,33 +325,33 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||
testCases := []struct {
|
||||
wantIP netip.Addr
|
||||
name string
|
||||
host string
|
||||
wantIP netip.Addr
|
||||
wantRes resultCode
|
||||
isLocalCli bool
|
||||
}{{
|
||||
wantIP: knownIP,
|
||||
name: "local_client_success",
|
||||
host: "example.lan",
|
||||
wantIP: knownIP,
|
||||
wantRes: resultCodeSuccess,
|
||||
isLocalCli: true,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "local_client_unknown_host",
|
||||
host: "wronghost.lan",
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
isLocalCli: true,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "external_client_known_host",
|
||||
host: "example.lan",
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeFinish,
|
||||
isLocalCli: false,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "external_client_unknown_host",
|
||||
host: "wronghost.lan",
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeFinish,
|
||||
isLocalCli: false,
|
||||
}}
|
||||
@@ -332,52 +423,52 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
|
||||
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||
testCases := []struct {
|
||||
wantIP netip.Addr
|
||||
name string
|
||||
host string
|
||||
suffix string
|
||||
wantIP netip.Addr
|
||||
wantRes resultCode
|
||||
qtyp uint16
|
||||
}{{
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_external",
|
||||
host: examplecom,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_external_non_a",
|
||||
host: examplecom,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeCNAME,
|
||||
}, {
|
||||
wantIP: knownIP,
|
||||
name: "success_internal",
|
||||
host: examplelan,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: knownIP,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_internal_unknown",
|
||||
host: "example-new.lan",
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_internal_aaaa",
|
||||
host: examplelan,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
wantIP: knownIP,
|
||||
name: "success_custom_suffix",
|
||||
host: "example.custom",
|
||||
suffix: "custom",
|
||||
wantIP: knownIP,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}}
|
||||
@@ -560,10 +651,8 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
var dnsCtx *dnsContext
|
||||
setup := func(use bool) {
|
||||
proxyCtx = &proxy.DNSContext{
|
||||
Addr: &net.TCPAddr{
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
},
|
||||
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
|
||||
Addr: testClientAddr,
|
||||
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
|
||||
}
|
||||
dnsCtx = &dnsContext{
|
||||
proxyCtx: proxyCtx,
|
||||
@@ -42,11 +42,13 @@ func (s *Server) loadUpstreams() (upstreams []string, err error) {
|
||||
|
||||
// prepareUpstreamSettings sets upstream DNS server settings.
|
||||
func (s *Server) prepareUpstreamSettings() (err error) {
|
||||
// We're setting a customized set of RootCAs. The reason is that Go default
|
||||
// mechanism of loading TLS roots does not always work properly on some
|
||||
// routers so we're loading roots manually and pass it here.
|
||||
// Use a customized set of RootCAs, because Go's default mechanism of
|
||||
// loading TLS roots does not always work properly on some routers so we're
|
||||
// loading roots manually and pass it here.
|
||||
//
|
||||
// See [aghtls.SystemRootCAs].
|
||||
//
|
||||
// TODO(a.garipov): Investigate if that's true.
|
||||
upstream.RootCAs = s.conf.TLSv12Roots
|
||||
upstream.CipherSuites = s.conf.TLSCiphers
|
||||
|
||||
@@ -190,7 +192,7 @@ func (s *Server) resolveUpstreamsWithHosts(
|
||||
|
||||
// extractUpstreamHost returns the hostname of addr without port with an
|
||||
// assumption that any address passed here has already been successfully parsed
|
||||
// by [upstream.AddressToUpstream]. This function eesentially mirrors the logic
|
||||
// by [upstream.AddressToUpstream]. This function essentially mirrors the logic
|
||||
// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts].
|
||||
func extractUpstreamHost(addr string) (host string) {
|
||||
var err error
|
||||
|
||||
Reference in New Issue
Block a user