all: sync with master; upd chlog
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -39,11 +40,29 @@ func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
//
|
||||
// TODO(a.garipov): Use more.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
// testQuestionTarget is the common question target for tests.
|
||||
//
|
||||
// TODO(a.garipov): Use more.
|
||||
const testQuestionTarget = "target.example"
|
||||
|
||||
const (
|
||||
tlsServerName = "testdns.adguard.com"
|
||||
testMessagesCount = 10
|
||||
)
|
||||
|
||||
// testClientAddr is the common net.Addr for tests.
|
||||
//
|
||||
// TODO(a.garipov): Use more.
|
||||
var testClientAddr net.Addr = &net.TCPAddr{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
Port: 12345,
|
||||
}
|
||||
|
||||
func startDeferStop(t *testing.T, s *Server) {
|
||||
t.Helper()
|
||||
|
||||
@@ -53,6 +72,13 @@ func startDeferStop(t *testing.T, s *Server) {
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
}
|
||||
|
||||
// packageUpstreamVariableMu is used to serialize access to the package-level
|
||||
// variables of package upstream.
|
||||
//
|
||||
// TODO(s.chzhen): Move these parameters to upstream options and remove this
|
||||
// crutch.
|
||||
var packageUpstreamVariableMu = &sync.Mutex{}
|
||||
|
||||
func createTestServer(
|
||||
t *testing.T,
|
||||
filterConf *filtering.Config,
|
||||
@@ -61,6 +87,9 @@ func createTestServer(
|
||||
) (s *Server) {
|
||||
t.Helper()
|
||||
|
||||
packageUpstreamVariableMu.Lock()
|
||||
defer packageUpstreamVariableMu.Unlock()
|
||||
|
||||
rules := `||nxdomain.example.org
|
||||
||NULL.example.org^
|
||||
127.0.0.1 host.example.org
|
||||
@@ -307,11 +336,9 @@ func TestServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_timeout(t *testing.T) {
|
||||
const timeout time.Duration = time.Second
|
||||
|
||||
t.Run("custom", func(t *testing.T) {
|
||||
srvConf := &ServerConfig{
|
||||
UpstreamTimeout: timeout,
|
||||
UpstreamTimeout: testTimeout,
|
||||
FilteringConfig: FilteringConfig{
|
||||
BlockingMode: BlockingModeDefault,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
@@ -324,7 +351,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
err = s.Prepare(srvConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, timeout, s.conf.UpstreamTimeout)
|
||||
assert.Equal(t, testTimeout, s.conf.UpstreamTimeout)
|
||||
})
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
@@ -441,7 +468,14 @@ func TestServerRace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
resolver := &aghtest.Resolver{
|
||||
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
ip4, ip6 := aghtest.HostToIPs(host)
|
||||
|
||||
return []net.IP{ip4, ip6}, nil
|
||||
},
|
||||
}
|
||||
|
||||
safeSearchConf := filtering.SafeSearchConfig{
|
||||
Enabled: true,
|
||||
Google: true,
|
||||
@@ -480,7 +514,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
client := &dns.Client{}
|
||||
|
||||
yandexIP := net.IP{213, 180, 193, 56}
|
||||
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
|
||||
|
||||
testCases := []struct {
|
||||
host string
|
||||
@@ -545,7 +579,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||
|
||||
// Send a DNS request without question.
|
||||
_, _, err := (&dns.Client{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
Timeout: testTimeout,
|
||||
}).Exchange(&req, addr)
|
||||
|
||||
assert.NoErrorf(t, err, "got a response to an invalid query")
|
||||
@@ -928,7 +962,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
Upstream: aghtest.NewBlockUpstream(hostname, true),
|
||||
})
|
||||
|
||||
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
||||
ans4, _ := aghtest.HostToIPs(hostname)
|
||||
|
||||
filterConf := &filtering.Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
@@ -1266,25 +1300,57 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// doubleTTL is a helper function that returns a clone of DNS PTR with appended
|
||||
// copy of first answer record with doubled TTL.
|
||||
func doubleTTL(msg *dns.Msg) (resp *dns.Msg) {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(msg.Answer) == 0 {
|
||||
return msg
|
||||
}
|
||||
|
||||
rec := msg.Answer[0]
|
||||
ptr, ok := rec.(*dns.PTR)
|
||||
if !ok {
|
||||
return msg
|
||||
}
|
||||
|
||||
clone := *ptr
|
||||
clone.Hdr.Ttl *= 2
|
||||
msg.Answer = append(msg.Answer, &clone)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func TestServer_Exchange(t *testing.T) {
|
||||
const (
|
||||
onesHost = "one.one.one.one"
|
||||
twosHost = "two.two.two.two"
|
||||
localDomainHost = "local.domain"
|
||||
|
||||
defaultTTL = time.Second * 60
|
||||
)
|
||||
|
||||
var (
|
||||
onesIP = netip.MustParseAddr("1.1.1.1")
|
||||
twosIP = netip.MustParseAddr("2.2.2.2")
|
||||
localIP = netip.MustParseAddr("192.168.1.1")
|
||||
)
|
||||
|
||||
revExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
|
||||
onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
extUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "external.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, revExtIPv4, onesHost),
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost),
|
||||
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
},
|
||||
@@ -1320,53 +1386,65 @@ func TestServer_Exchange(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
srv.conf.ResolveClients = true
|
||||
srv.conf.UsePrivateRDNS = true
|
||||
|
||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
req netip.Addr
|
||||
wantErr error
|
||||
locUpstream upstream.Upstream
|
||||
req netip.Addr
|
||||
name string
|
||||
want string
|
||||
wantTTL time.Duration
|
||||
}{{
|
||||
name: "external_good",
|
||||
want: onesHost,
|
||||
wantErr: nil,
|
||||
locUpstream: nil,
|
||||
req: onesIP,
|
||||
wantTTL: defaultTTL,
|
||||
}, {
|
||||
name: "local_good",
|
||||
want: localDomainHost,
|
||||
wantErr: nil,
|
||||
locUpstream: locUpstream,
|
||||
req: localIP,
|
||||
wantTTL: defaultTTL,
|
||||
}, {
|
||||
name: "upstream_error",
|
||||
want: "",
|
||||
wantErr: aghtest.ErrUpstream,
|
||||
locUpstream: errUpstream,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "empty_answer_error",
|
||||
want: "",
|
||||
wantErr: ErrRDNSNoData,
|
||||
locUpstream: locUpstream,
|
||||
req: netip.MustParseAddr("192.168.1.2"),
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "invalid_answer",
|
||||
want: "",
|
||||
wantErr: ErrRDNSNoData,
|
||||
locUpstream: nonPtrUpstream,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "refused",
|
||||
want: "",
|
||||
wantErr: ErrRDNSFailed,
|
||||
locUpstream: refusingUpstream,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "longest_ttl",
|
||||
want: twosHost,
|
||||
wantErr: nil,
|
||||
locUpstream: nil,
|
||||
req: twosIP,
|
||||
wantTTL: defaultTTL * 2,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -1380,73 +1458,20 @@ func TestServer_Exchange(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
host, eerr := srv.Exchange(tc.req)
|
||||
host, ttl, eerr := srv.Exchange(tc.req)
|
||||
|
||||
require.ErrorIs(t, eerr, tc.wantErr)
|
||||
assert.Equal(t, tc.want, host)
|
||||
assert.Equal(t, tc.wantTTL, ttl)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("resolving_disabled", func(t *testing.T) {
|
||||
srv.conf.UsePrivateRDNS = false
|
||||
|
||||
host, eerr := srv.Exchange(localIP)
|
||||
host, _, eerr := srv.Exchange(localIP)
|
||||
|
||||
require.NoError(t, eerr)
|
||||
assert.Empty(t, host)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_ShouldResolveClient(t *testing.T) {
|
||||
srv := &Server{
|
||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
ip netip.Addr
|
||||
want require.BoolAssertionFunc
|
||||
name string
|
||||
resolve bool
|
||||
usePrivate bool
|
||||
}{{
|
||||
name: "default",
|
||||
ip: netip.MustParseAddr("1.1.1.1"),
|
||||
want: require.True,
|
||||
resolve: true,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "no_rdns",
|
||||
ip: netip.MustParseAddr("1.1.1.1"),
|
||||
want: require.False,
|
||||
resolve: false,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "loopback",
|
||||
ip: netip.MustParseAddr("127.0.0.1"),
|
||||
want: require.False,
|
||||
resolve: true,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "private_resolve",
|
||||
ip: netip.MustParseAddr("192.168.0.1"),
|
||||
want: require.True,
|
||||
resolve: true,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "private_no_resolve",
|
||||
ip: netip.MustParseAddr("192.168.0.1"),
|
||||
want: require.False,
|
||||
resolve: true,
|
||||
usePrivate: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv.conf.ResolveClients = tc.resolve
|
||||
srv.conf.UsePrivateRDNS = tc.usePrivate
|
||||
|
||||
ok := srv.ShouldResolveClient(tc.ip)
|
||||
tc.want(t, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user