all: sync with master
This commit is contained in:
@@ -14,6 +14,8 @@ import (
|
||||
)
|
||||
|
||||
// ValidateClientID returns an error if id is not a valid ClientID.
|
||||
//
|
||||
// Keep in sync with [client.ValidateClientID].
|
||||
func ValidateClientID(id string) (err error) {
|
||||
err = netutil.ValidateHostnameLabel(id)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -24,7 +25,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// ClientsContainer provides information about preconfigured DNS clients.
|
||||
@@ -40,7 +40,7 @@ type ClientsContainer interface {
|
||||
) (conf *proxy.CustomUpstreamConfig, err error)
|
||||
}
|
||||
|
||||
// Config represents the DNS filtering configuration of AdGuard Home. The zero
|
||||
// Config represents the DNS filtering configuration of AdGuard Home. The zero
|
||||
// Config is empty and ready for use.
|
||||
type Config struct {
|
||||
// Callbacks for other modules
|
||||
@@ -357,10 +357,6 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
|
||||
conf.DNSCryptResolverCert = c.ResolverCert
|
||||
}
|
||||
|
||||
if conf.UpstreamConfig == nil || len(conf.UpstreamConfig.Upstreams) == 0 {
|
||||
return nil, errors.Error("no default upstream servers configured")
|
||||
}
|
||||
|
||||
conf, err = prepareCacheConfig(conf,
|
||||
srvConf.CacheSize,
|
||||
srvConf.CacheMinTTL,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
func TestAnyNameMatches(t *testing.T) {
|
||||
|
||||
@@ -2,54 +2,56 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// upstreamConfigValidator parses the [*proxy.UpstreamConfig] and checks the
|
||||
// actual DNS availability of each upstream.
|
||||
// upstreamConfigValidator parses each section of an upstream configuration into
|
||||
// a corresponding [*proxy.UpstreamConfig] and checks the actual DNS
|
||||
// availability of each upstream.
|
||||
type upstreamConfigValidator struct {
|
||||
// general is the general upstream configuration.
|
||||
general []*upstreamResult
|
||||
// generalUpstreamResults contains upstream results of a general section.
|
||||
generalUpstreamResults map[string]*upstreamResult
|
||||
|
||||
// fallback is the fallback upstream configuration.
|
||||
fallback []*upstreamResult
|
||||
// fallbackUpstreamResults contains upstream results of a fallback section.
|
||||
fallbackUpstreamResults map[string]*upstreamResult
|
||||
|
||||
// private is the private upstream configuration.
|
||||
private []*upstreamResult
|
||||
// privateUpstreamResults contains upstream results of a private section.
|
||||
privateUpstreamResults map[string]*upstreamResult
|
||||
|
||||
// generalParseResults contains parsing results of a general section.
|
||||
generalParseResults []*parseResult
|
||||
|
||||
// fallbackParseResults contains parsing results of a fallback section.
|
||||
fallbackParseResults []*parseResult
|
||||
|
||||
// privateParseResults contains parsing results of a private section.
|
||||
privateParseResults []*parseResult
|
||||
}
|
||||
|
||||
// upstreamResult is a result of validation of an [upstream.Upstream] within an
|
||||
// upstreamResult is a result of parsing of an [upstream.Upstream] within an
|
||||
// [proxy.UpstreamConfig].
|
||||
type upstreamResult struct {
|
||||
// server is the parsed upstream. It is nil when there was an error during
|
||||
// parsing.
|
||||
// server is the parsed upstream.
|
||||
server upstream.Upstream
|
||||
|
||||
// err is the error either from parsing or from checking the upstream.
|
||||
// err is the upstream check error.
|
||||
err error
|
||||
|
||||
// original is the piece of configuration that have either been turned to an
|
||||
// upstream or caused an error.
|
||||
original string
|
||||
|
||||
// isSpecific is true if the upstream is domain-specific.
|
||||
isSpecific bool
|
||||
}
|
||||
|
||||
// compare compares two [upstreamResult]s. It returns 0 if they are equal, -1
|
||||
// if ur should be sorted before other, and 1 otherwise.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps it makes sense to sort the results with errors near
|
||||
// the end.
|
||||
func (ur *upstreamResult) compare(other *upstreamResult) (res int) {
|
||||
return strings.Compare(ur.original, other.original)
|
||||
// parseResult contains a original piece of upstream configuration and a
|
||||
// corresponding error.
|
||||
type parseResult struct {
|
||||
err *proxy.ParseError
|
||||
original string
|
||||
}
|
||||
|
||||
// newUpstreamConfigValidator parses the upstream configuration and returns a
|
||||
@@ -61,97 +63,99 @@ func newUpstreamConfigValidator(
|
||||
private []string,
|
||||
opts *upstream.Options,
|
||||
) (cv *upstreamConfigValidator) {
|
||||
cv = &upstreamConfigValidator{}
|
||||
cv = &upstreamConfigValidator{
|
||||
generalUpstreamResults: map[string]*upstreamResult{},
|
||||
fallbackUpstreamResults: map[string]*upstreamResult{},
|
||||
privateUpstreamResults: map[string]*upstreamResult{},
|
||||
}
|
||||
|
||||
for _, line := range general {
|
||||
cv.general = cv.insertLineResults(cv.general, line, opts)
|
||||
}
|
||||
for _, line := range fallback {
|
||||
cv.fallback = cv.insertLineResults(cv.fallback, line, opts)
|
||||
}
|
||||
for _, line := range private {
|
||||
cv.private = cv.insertLineResults(cv.private, line, opts)
|
||||
}
|
||||
conf, err := proxy.ParseUpstreamsConfig(general, opts)
|
||||
cv.generalParseResults = collectErrResults(general, err)
|
||||
insertConfResults(conf, cv.generalUpstreamResults)
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(fallback, opts)
|
||||
cv.fallbackParseResults = collectErrResults(fallback, err)
|
||||
insertConfResults(conf, cv.fallbackUpstreamResults)
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(private, opts)
|
||||
cv.privateParseResults = collectErrResults(private, err)
|
||||
insertConfResults(conf, cv.privateUpstreamResults)
|
||||
|
||||
return cv
|
||||
}
|
||||
|
||||
// insertLineResults parses line and inserts the result into s. It can insert
|
||||
// multiple results as well as none.
|
||||
func (cv *upstreamConfigValidator) insertLineResults(
|
||||
s []*upstreamResult,
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
) (result []*upstreamResult) {
|
||||
upstreams, isSpecific, err := splitUpstreamLine(line)
|
||||
if err != nil {
|
||||
return cv.insert(s, &upstreamResult{
|
||||
err: err,
|
||||
original: line,
|
||||
})
|
||||
// collectErrResults parses err and returns parsing results containing the
|
||||
// original upstream configuration line and the corresponding error. err can be
|
||||
// nil.
|
||||
func collectErrResults(lines []string, err error) (results []*parseResult) {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, upstreamAddr := range upstreams {
|
||||
var res *upstreamResult
|
||||
if upstreamAddr != "#" {
|
||||
res = cv.parseUpstream(upstreamAddr, opts)
|
||||
} else if !isSpecific {
|
||||
res = &upstreamResult{
|
||||
err: errNotDomainSpecific,
|
||||
original: upstreamAddr,
|
||||
}
|
||||
} else {
|
||||
// limit is a maximum length for upstream configuration lines.
|
||||
const limit = 80
|
||||
|
||||
wrapper, ok := err.(errors.WrapperSlice)
|
||||
if !ok {
|
||||
log.Debug("dnsforward: configvalidator: unwrapping: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
errs := wrapper.Unwrap()
|
||||
results = make([]*parseResult, 0, len(errs))
|
||||
for i, e := range errs {
|
||||
var parseErr *proxy.ParseError
|
||||
if !errors.As(e, &parseErr) {
|
||||
log.Debug("dnsforward: configvalidator: inserting unexpected error %d: %s", i, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
res.isSpecific = isSpecific
|
||||
s = cv.insert(s, res)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// insert inserts r into slice in a sorted order, except duplicates. slice must
|
||||
// not be nil.
|
||||
func (cv *upstreamConfigValidator) insert(
|
||||
s []*upstreamResult,
|
||||
r *upstreamResult,
|
||||
) (result []*upstreamResult) {
|
||||
i, has := slices.BinarySearchFunc(s, r, (*upstreamResult).compare)
|
||||
if has {
|
||||
log.Debug("dnsforward: duplicate configuration %q", r.original)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
return slices.Insert(s, i, r)
|
||||
}
|
||||
|
||||
// parseUpstream parses addr and returns the result of parsing. It returns nil
|
||||
// if the specified server points at the default upstream server which is
|
||||
// validated separately.
|
||||
func (cv *upstreamConfigValidator) parseUpstream(
|
||||
addr string,
|
||||
opts *upstream.Options,
|
||||
) (r *upstreamResult) {
|
||||
// Check if the upstream has a valid protocol prefix.
|
||||
//
|
||||
// TODO(e.burkov): Validate the domain name.
|
||||
if proto, _, ok := strings.Cut(addr, "://"); ok {
|
||||
if !slices.Contains(protocols, proto) {
|
||||
return &upstreamResult{
|
||||
err: fmt.Errorf("bad protocol %q", proto),
|
||||
original: addr,
|
||||
}
|
||||
idx := parseErr.Idx
|
||||
line := []rune(lines[idx])
|
||||
if len(line) > limit {
|
||||
line = line[:limit]
|
||||
line[limit-1] = '…'
|
||||
}
|
||||
|
||||
results = append(results, &parseResult{
|
||||
original: string(line),
|
||||
err: parseErr,
|
||||
})
|
||||
}
|
||||
|
||||
ups, err := upstream.AddressToUpstream(addr, opts)
|
||||
return results
|
||||
}
|
||||
|
||||
return &upstreamResult{
|
||||
server: ups,
|
||||
err: err,
|
||||
original: addr,
|
||||
// insertConfResults parses conf and inserts the upstream result into results.
|
||||
// It can insert multiple results as well as none.
|
||||
func insertConfResults(conf *proxy.UpstreamConfig, results map[string]*upstreamResult) {
|
||||
insertListResults(conf.Upstreams, results, false)
|
||||
|
||||
for _, ups := range conf.DomainReservedUpstreams {
|
||||
insertListResults(ups, results, true)
|
||||
}
|
||||
|
||||
for _, ups := range conf.SpecifiedDomainUpstreams {
|
||||
insertListResults(ups, results, true)
|
||||
}
|
||||
}
|
||||
|
||||
// insertListResults constructs upstream results from the upstream list and
|
||||
// inserts them into results. It can insert multiple results as well as none.
|
||||
func insertListResults(ups []upstream.Upstream, results map[string]*upstreamResult, specific bool) {
|
||||
for _, u := range ups {
|
||||
addr := u.Address()
|
||||
_, ok := results[addr]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
|
||||
results[addr] = &upstreamResult{
|
||||
server: u,
|
||||
isSpecific: specific,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,35 +191,30 @@ func (cv *upstreamConfigValidator) check() {
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(len(cv.general) + len(cv.fallback) + len(cv.private))
|
||||
wg.Add(len(cv.generalUpstreamResults) +
|
||||
len(cv.fallbackUpstreamResults) +
|
||||
len(cv.privateUpstreamResults))
|
||||
|
||||
for _, res := range cv.general {
|
||||
go cv.checkSrv(res, wg, commonChecker)
|
||||
for _, res := range cv.generalUpstreamResults {
|
||||
go checkSrv(res, wg, commonChecker)
|
||||
}
|
||||
for _, res := range cv.fallback {
|
||||
go cv.checkSrv(res, wg, commonChecker)
|
||||
for _, res := range cv.fallbackUpstreamResults {
|
||||
go checkSrv(res, wg, commonChecker)
|
||||
}
|
||||
for _, res := range cv.private {
|
||||
go cv.checkSrv(res, wg, arpaChecker)
|
||||
for _, res := range cv.privateUpstreamResults {
|
||||
go checkSrv(res, wg, arpaChecker)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// checkSrv runs hc on the server from res, if any, and stores any occurred
|
||||
// error in res. wg is always marked done in the end. It used to be called in
|
||||
// a separate goroutine.
|
||||
func (cv *upstreamConfigValidator) checkSrv(
|
||||
res *upstreamResult,
|
||||
wg *sync.WaitGroup,
|
||||
hc *healthchecker,
|
||||
) {
|
||||
// error in res. wg is always marked done in the end. It is intended to be
|
||||
// used as a goroutine.
|
||||
func checkSrv(res *upstreamResult, wg *sync.WaitGroup, hc *healthchecker) {
|
||||
defer log.OnPanic(fmt.Sprintf("dnsforward: checking upstream %s", res.server.Address()))
|
||||
defer wg.Done()
|
||||
|
||||
if res.server == nil {
|
||||
return
|
||||
}
|
||||
|
||||
res.err = hc.check(res.server)
|
||||
if res.err != nil && res.isSpecific {
|
||||
res.err = domainSpecificTestError{Err: res.err}
|
||||
@@ -225,65 +224,126 @@ func (cv *upstreamConfigValidator) checkSrv(
|
||||
// close closes all the upstreams that were successfully parsed. It enriches
|
||||
// the results with deferred closing errors.
|
||||
func (cv *upstreamConfigValidator) close() {
|
||||
for _, slice := range [][]*upstreamResult{cv.general, cv.fallback, cv.private} {
|
||||
for _, r := range slice {
|
||||
if r.server != nil {
|
||||
r.err = errors.WithDeferred(r.err, r.server.Close())
|
||||
}
|
||||
all := []map[string]*upstreamResult{
|
||||
cv.generalUpstreamResults,
|
||||
cv.fallbackUpstreamResults,
|
||||
cv.privateUpstreamResults,
|
||||
}
|
||||
|
||||
for _, m := range all {
|
||||
for _, r := range m {
|
||||
r.err = errors.WithDeferred(r.err, r.server.Close())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sections of the upstream configuration according to the text label of the
|
||||
// localization.
|
||||
//
|
||||
// Keep in sync with client/src/__locales/en.json.
|
||||
//
|
||||
// TODO(s.chzhen): Refactor.
|
||||
const (
|
||||
generalTextLabel = "upstream_dns"
|
||||
fallbackTextLabel = "fallback_dns_title"
|
||||
privateTextLabel = "local_ptr_title"
|
||||
)
|
||||
|
||||
// status returns all the data collected during parsing, healthcheck, and
|
||||
// closing of the upstreams. The returned map is keyed by the original upstream
|
||||
// configuration piece and contains the corresponding error or "OK" if there was
|
||||
// no error.
|
||||
func (cv *upstreamConfigValidator) status() (results map[string]string) {
|
||||
result := map[string]string{}
|
||||
// Names of the upstream configuration sections for logging.
|
||||
const (
|
||||
generalSection = "general"
|
||||
fallbackSection = "fallback"
|
||||
privateSection = "private"
|
||||
)
|
||||
|
||||
for _, res := range cv.general {
|
||||
resultToStatus("general", res, result)
|
||||
results = map[string]string{}
|
||||
|
||||
for original, res := range cv.generalUpstreamResults {
|
||||
upstreamResultToStatus(generalSection, string(original), res, results)
|
||||
}
|
||||
for _, res := range cv.fallback {
|
||||
resultToStatus("fallback", res, result)
|
||||
for original, res := range cv.fallbackUpstreamResults {
|
||||
upstreamResultToStatus(fallbackSection, string(original), res, results)
|
||||
}
|
||||
for _, res := range cv.private {
|
||||
resultToStatus("private", res, result)
|
||||
for original, res := range cv.privateUpstreamResults {
|
||||
upstreamResultToStatus(privateSection, string(original), res, results)
|
||||
}
|
||||
|
||||
return result
|
||||
parseResultToStatus(generalTextLabel, generalSection, cv.generalParseResults, results)
|
||||
parseResultToStatus(fallbackTextLabel, fallbackSection, cv.fallbackParseResults, results)
|
||||
parseResultToStatus(privateTextLabel, privateSection, cv.privateParseResults, results)
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// resultToStatus puts "OK" or an error message from res into resMap. section
|
||||
// is the name of the upstream configuration section, i.e. "general",
|
||||
// upstreamResultToStatus puts "OK" or an error message from res into resMap.
|
||||
// section is the name of the upstream configuration section, i.e. "general",
|
||||
// "fallback", or "private", and only used for logging.
|
||||
//
|
||||
// TODO(e.burkov): Currently, the HTTP handler expects that all the results are
|
||||
// put together in a single map, which may lead to collisions, see AG-27539.
|
||||
// Improve the results compilation.
|
||||
func resultToStatus(section string, res *upstreamResult, resMap map[string]string) {
|
||||
func upstreamResultToStatus(
|
||||
section string,
|
||||
original string,
|
||||
res *upstreamResult,
|
||||
resMap map[string]string,
|
||||
) {
|
||||
val := "OK"
|
||||
if res.err != nil {
|
||||
val = res.err.Error()
|
||||
}
|
||||
|
||||
prevVal := resMap[res.original]
|
||||
prevVal := resMap[original]
|
||||
switch prevVal {
|
||||
case "":
|
||||
resMap[res.original] = val
|
||||
resMap[original] = val
|
||||
case val:
|
||||
log.Debug("dnsforward: duplicating %s config line %q", section, res.original)
|
||||
log.Debug("dnsforward: duplicating %s config line %q", section, original)
|
||||
default:
|
||||
log.Debug(
|
||||
"dnsforward: warning: %s config line %q (%v) had different result %v",
|
||||
section,
|
||||
val,
|
||||
res.original,
|
||||
original,
|
||||
prevVal,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// parseResultToStatus puts parsing error messages from results into resMap.
|
||||
// section is the name of the upstream configuration section, i.e. "general",
|
||||
// "fallback", or "private", and only used for logging.
|
||||
//
|
||||
// Parsing error message has the following format:
|
||||
//
|
||||
// sectionTextLabel line: parsing error
|
||||
//
|
||||
// Where sectionTextLabel is a section text label of a localization and line is
|
||||
// a line number.
|
||||
func parseResultToStatus(
|
||||
textLabel string,
|
||||
section string,
|
||||
results []*parseResult,
|
||||
resMap map[string]string,
|
||||
) {
|
||||
for _, res := range results {
|
||||
original := res.original
|
||||
_, ok := resMap[original]
|
||||
if ok {
|
||||
log.Debug("dnsforward: duplicating %s parsing error %q", section, original)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
resMap[original] = fmt.Sprintf("%s %d: parsing error", textLabel, res.err.Idx+1)
|
||||
}
|
||||
}
|
||||
|
||||
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
||||
// the tested upstream domain-specific and therefore consider its errors
|
||||
// non-critical.
|
||||
@@ -342,7 +402,7 @@ func (h *healthchecker) check(u upstream.Upstream) (err error) {
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
} else if h.ansEmpty && len(reply.Answer) > 0 {
|
||||
return errWrongResponse
|
||||
return errors.Error("wrong response")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
@@ -101,21 +100,6 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
type answerMap = map[uint16][sectionsNum][]dns.RR
|
||||
|
||||
pt := testutil.PanicT{}
|
||||
newUps := func(answers answerMap) (u upstream.Upstream) {
|
||||
return aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
q := req.Question[0]
|
||||
require.Contains(pt, answers, q.Qtype)
|
||||
|
||||
answer := answers[q.Qtype]
|
||||
|
||||
resp = (&dns.Msg{}).SetReply(req)
|
||||
resp.Answer = answer[sectionAnswer]
|
||||
resp.Ns = answer[sectionAuthority]
|
||||
resp.Extra = answer[sectionAdditional]
|
||||
|
||||
return resp, nil
|
||||
})
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -265,13 +249,16 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
}}
|
||||
|
||||
localRR := newRR(t, ptr64Domain, dns.TypePTR, 3600, pointedDomain)
|
||||
localUps := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
require.Equal(pt, req.Question[0].Name, ptr64Domain)
|
||||
resp = (&dns.Msg{}).SetReply(req)
|
||||
resp.Answer = []dns.RR{localRR}
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
require.Len(pt, m.Question, 1)
|
||||
require.Equal(pt, m.Question[0].Name, ptr64Domain)
|
||||
resp := (&dns.Msg{
|
||||
Answer: []dns.RR{localRR},
|
||||
}).SetReply(m)
|
||||
|
||||
return resp, nil
|
||||
require.NoError(t, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
|
||||
client := &dns.Client{
|
||||
Net: "tcp",
|
||||
@@ -279,25 +266,44 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
// TODO(e.burkov): It seems [proxy.Proxy] isn't intended to be reused
|
||||
// right after stop, due to a data race in [proxy.Proxy.Init] method
|
||||
// when setting an OOB size. As a temporary workaround, recreate the
|
||||
// whole server for each test case.
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
UseDNS64: true,
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, localUps)
|
||||
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newUps(tc.upsAns)}
|
||||
upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
q := req.Question[0]
|
||||
require.Contains(pt, tc.upsAns, q.Qtype)
|
||||
|
||||
answer := tc.upsAns[q.Qtype]
|
||||
|
||||
resp := (&dns.Msg{
|
||||
Answer: answer[sectionAnswer],
|
||||
Ns: answer[sectionAuthority],
|
||||
Extra: answer[sectionAdditional],
|
||||
}).SetReply(req)
|
||||
|
||||
require.NoError(pt, w.WriteMsg(resp))
|
||||
})
|
||||
upsAddr := aghtest.StartLocalhostUpstream(t, upsHdlr).String()
|
||||
|
||||
// TODO(e.burkov): It seems [proxy.Proxy] isn't intended to be
|
||||
// reused right after stop, due to a data race in [proxy.Proxy.Init]
|
||||
// method when setting an OOB size. As a temporary workaround,
|
||||
// recreate the whole server for each test case.
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
UseDNS64: true,
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
UpstreamDNS: []string{upsAddr},
|
||||
},
|
||||
UsePrivateRDNS: true,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
|
||||
startDeferStop(t, s)
|
||||
|
||||
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -30,7 +31,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/netutil/sysresolv"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// DefaultTimeout is the default upstream timeout
|
||||
@@ -464,7 +464,8 @@ func (s *Server) Start() error {
|
||||
// startLocked starts the DNS server without locking. s.serverLock is expected
|
||||
// to be locked.
|
||||
func (s *Server) startLocked() error {
|
||||
err := s.dnsProxy.Start()
|
||||
// TODO(e.burkov): Use context properly.
|
||||
err := s.dnsProxy.Start(context.Background())
|
||||
if err == nil {
|
||||
s.isRunning = true
|
||||
}
|
||||
@@ -518,34 +519,30 @@ func (s *Server) prepareLocalResolvers(
|
||||
}
|
||||
|
||||
// setupLocalResolvers initializes and sets the resolvers for local addresses.
|
||||
// It assumes s.serverLock is locked or s not running.
|
||||
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) {
|
||||
uc, err := s.prepareLocalResolvers(boot)
|
||||
// It assumes s.serverLock is locked or s not running. It returns the upstream
|
||||
// configuration used for private PTR resolving, or nil if it's disabled. Note,
|
||||
// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
// It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
uc, err = s.prepareLocalResolvers(boot)
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.localResolvers = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UpstreamConfig: uc,
|
||||
},
|
||||
}
|
||||
|
||||
err = s.localResolvers.Init()
|
||||
s.localResolvers, err = proxy.New(&proxy.Config{
|
||||
UpstreamConfig: uc,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("initializing proxy: %w", err)
|
||||
return nil, fmt.Errorf("creating local resolvers: %w", err)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Should we also consider the DNS64 usage?
|
||||
if s.conf.UsePrivateRDNS &&
|
||||
// Only set the upstream config if there are any upstreams. It's safe
|
||||
// to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||
len(uc.Upstreams)+len(uc.DomainReservedUpstreams)+len(uc.SpecifiedDomainUpstreams) > 0 {
|
||||
s.dnsProxy.PrivateRDNSUpstreamConfig = uc
|
||||
}
|
||||
|
||||
return nil
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// Prepare initializes parameters of s using data from conf. conf must not be
|
||||
@@ -586,21 +583,22 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
return fmt.Errorf("preparing access: %w", err)
|
||||
}
|
||||
|
||||
// Set the proxy here because [setupLocalResolvers] sets its values.
|
||||
//
|
||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||
s.dnsProxy = &proxy.Proxy{Config: *proxyConfig}
|
||||
|
||||
err = s.setupLocalResolvers(boot)
|
||||
proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up resolvers: %w", err)
|
||||
}
|
||||
|
||||
err = s.setupFallbackDNS()
|
||||
proxyConfig.Fallbacks, err = s.setupFallbackDNS()
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up fallback dns servers: %w", err)
|
||||
}
|
||||
|
||||
s.dnsProxy, err = proxy.New(proxyConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating proxy: %w", err)
|
||||
}
|
||||
|
||||
s.recDetector.clear()
|
||||
|
||||
s.setupAddrProc()
|
||||
@@ -643,26 +641,25 @@ func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) {
|
||||
}
|
||||
|
||||
// setupFallbackDNS initializes the fallback DNS servers.
|
||||
func (s *Server) setupFallbackDNS() (err error) {
|
||||
func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) {
|
||||
fallbacks := s.conf.FallbackDNS
|
||||
fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty)
|
||||
if len(fallbacks) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
uc, err := proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{
|
||||
uc, err = proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{
|
||||
// TODO(s.chzhen): Investigate if other options are needed.
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
// TODO(e.burkov): Use bootstrap.
|
||||
})
|
||||
if err != nil {
|
||||
// Do not wrap the error because it's informative enough as is.
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.dnsProxy.Fallbacks = uc
|
||||
|
||||
return nil
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// setupAddrProc initializes the address processor. It assumes s.serverLock is
|
||||
@@ -730,19 +727,9 @@ func (s *Server) prepareInternalProxy() (err error) {
|
||||
return fmt.Errorf("invalid upstream mode: %w", err)
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Make a proper constructor for proxy.Proxy.
|
||||
p := &proxy.Proxy{
|
||||
Config: *conf,
|
||||
}
|
||||
s.internalProxy, err = proxy.New(conf)
|
||||
|
||||
err = p.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.internalProxy = p
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Stop stops the DNS server.
|
||||
@@ -761,14 +748,17 @@ func (s *Server) stopLocked() (err error) {
|
||||
// [upstream.Upstream] implementations.
|
||||
|
||||
if s.dnsProxy != nil {
|
||||
err = s.dnsProxy.Stop()
|
||||
// TODO(e.burkov): Use context properly.
|
||||
err = s.dnsProxy.Shutdown(context.Background())
|
||||
if err != nil {
|
||||
log.Error("dnsforward: closing primary resolvers: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
|
||||
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
|
||||
if s.localResolvers != nil {
|
||||
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
|
||||
}
|
||||
|
||||
for _, b := range s.bootResolvers {
|
||||
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@@ -63,8 +65,7 @@ func startDeferStop(t *testing.T, s *Server) {
|
||||
t.Helper()
|
||||
|
||||
err := s.Start()
|
||||
require.NoErrorf(t, err, "failed to start server: %s", err)
|
||||
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
}
|
||||
|
||||
@@ -72,7 +73,6 @@ func createTestServer(
|
||||
t *testing.T,
|
||||
filterConf *filtering.Config,
|
||||
forwardConf ServerConfig,
|
||||
localUps upstream.Upstream,
|
||||
) (s *Server) {
|
||||
t.Helper()
|
||||
|
||||
@@ -82,7 +82,8 @@ func createTestServer(
|
||||
@@||whitelist.example.org^
|
||||
||127.0.0.255`
|
||||
filters := []filtering.Filter{{
|
||||
ID: 0, Data: []byte(rules),
|
||||
ID: 0,
|
||||
Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(filterConf, filters)
|
||||
@@ -105,19 +106,6 @@ func createTestServer(
|
||||
err = s.Prepare(&forwardConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
// TODO(e.burkov): Try to move it higher.
|
||||
if localUps != nil {
|
||||
ups := []upstream.Upstream{localUps}
|
||||
s.localResolvers.UpstreamConfig.Upstreams = ups
|
||||
s.conf.UsePrivateRDNS = true
|
||||
s.dnsProxy.PrivateRDNSUpstreamConfig = &proxy.UpstreamConfig{
|
||||
Upstreams: ups,
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -181,7 +169,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
})
|
||||
|
||||
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
|
||||
s.conf.TLSConfig = tlsConf
|
||||
@@ -310,7 +298,7 @@ func TestServer(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
startDeferStop(t, s)
|
||||
|
||||
@@ -410,7 +398,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
startDeferStop(t, s)
|
||||
|
||||
@@ -490,7 +478,7 @@ func TestServerRace(t *testing.T) {
|
||||
ConfigModified: func() {},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
startDeferStop(t, s)
|
||||
|
||||
@@ -545,7 +533,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
@@ -628,7 +616,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
})
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
@@ -662,7 +650,7 @@ func TestBlockedRequest(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf, nil)
|
||||
}, forwardConf)
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
@@ -698,7 +686,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf, nil)
|
||||
}, forwardConf)
|
||||
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
atomic.AddUint32(&upsCalledCounter, 1)
|
||||
@@ -773,7 +761,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
})
|
||||
testUpstm := &aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
IPv4: testIPv4,
|
||||
@@ -811,7 +799,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf, nil)
|
||||
}, forwardConf)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
@@ -886,7 +874,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf, nil)
|
||||
}, forwardConf)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
@@ -933,7 +921,7 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeNullIP,
|
||||
}, forwardConf, nil)
|
||||
}, forwardConf)
|
||||
startDeferStop(t, s)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
@@ -1054,7 +1042,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, forwardConf, nil)
|
||||
}, forwardConf)
|
||||
startDeferStop(t, s)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
@@ -1102,7 +1090,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
startDeferStop(t, s)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
@@ -1330,6 +1318,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
|
||||
var eventsCalledCounter uint32
|
||||
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
|
||||
OnStart: func() (_ error) { panic("not implemented") },
|
||||
OnEvents: func() (e <-chan struct{}) {
|
||||
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
||||
|
||||
@@ -1481,6 +1470,8 @@ func TestServer_Exchange(t *testing.T) {
|
||||
onesIP = netip.MustParseAddr("1.1.1.1")
|
||||
twosIP = netip.MustParseAddr("2.2.2.2")
|
||||
localIP = netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
pt = testutil.PanicT{}
|
||||
)
|
||||
|
||||
onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
|
||||
@@ -1489,72 +1480,73 @@ func TestServer_Exchange(t *testing.T) {
|
||||
twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
extUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "external.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost),
|
||||
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
},
|
||||
}
|
||||
extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)),
|
||||
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
require.NoError(pt, w.WriteMsg(resp))
|
||||
})
|
||||
upsAddr := aghtest.StartLocalhostUpstream(t, extUpsHdlr).String()
|
||||
|
||||
revLocIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
locUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "local.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, localDomainHost),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
},
|
||||
}
|
||||
locUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, dns.Fqdn(localDomainHost)),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
errUpstream := aghtest.NewErrorUpstream()
|
||||
nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
|
||||
refusingUpstream := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return new(dns.Msg).SetRcode(req, dns.RcodeRefused), nil
|
||||
require.NoError(pt, w.WriteMsg(resp))
|
||||
})
|
||||
zeroTTLUps := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "zero.ttl.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = new(dns.Msg).SetReply(req)
|
||||
hdr := dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
}
|
||||
resp.Answer = []dns.RR{&dns.PTR{
|
||||
Hdr: hdr,
|
||||
Ptr: localDomainHost,
|
||||
}}
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
errUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeServerFailure)))
|
||||
})
|
||||
|
||||
srv := &Server{
|
||||
recDetector: newRecursionDetector(0, 1),
|
||||
internalProxy: &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UpstreamConfig: &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{extUpstream},
|
||||
nonPtrHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
hash := sha256.Sum256([]byte("some-host"))
|
||||
resp := (&dns.Msg{
|
||||
Answer: []dns.RR{&dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Rrtype: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
srv.conf.UsePrivateRDNS = true
|
||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
require.NoError(t, srv.internalProxy.Init())
|
||||
Txt: []string{hex.EncodeToString(hash[:])},
|
||||
}},
|
||||
}).SetReply(req)
|
||||
|
||||
require.NoError(pt, w.WriteMsg(resp))
|
||||
})
|
||||
refusingHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeRefused)))
|
||||
})
|
||||
|
||||
zeroTTLHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := (&dns.Msg{
|
||||
Answer: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
Ptr: dns.Fqdn(localDomainHost),
|
||||
}},
|
||||
}).SetReply(req)
|
||||
|
||||
require.NoError(pt, w.WriteMsg(resp))
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
req netip.Addr
|
||||
wantErr error
|
||||
locUpstream upstream.Upstream
|
||||
locUpstream dns.Handler
|
||||
name string
|
||||
want string
|
||||
wantTTL time.Duration
|
||||
@@ -1569,35 +1561,35 @@ func TestServer_Exchange(t *testing.T) {
|
||||
name: "local_good",
|
||||
want: localDomainHost,
|
||||
wantErr: nil,
|
||||
locUpstream: locUpstream,
|
||||
locUpstream: locUpsHdlr,
|
||||
req: localIP,
|
||||
wantTTL: defaultTTL,
|
||||
}, {
|
||||
name: "upstream_error",
|
||||
want: "",
|
||||
wantErr: aghtest.ErrUpstream,
|
||||
locUpstream: errUpstream,
|
||||
wantErr: ErrRDNSFailed,
|
||||
locUpstream: errUpsHdlr,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "empty_answer_error",
|
||||
want: "",
|
||||
wantErr: ErrRDNSNoData,
|
||||
locUpstream: locUpstream,
|
||||
locUpstream: locUpsHdlr,
|
||||
req: netip.MustParseAddr("192.168.1.2"),
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "invalid_answer",
|
||||
want: "",
|
||||
wantErr: ErrRDNSNoData,
|
||||
locUpstream: nonPtrUpstream,
|
||||
locUpstream: nonPtrHdlr,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "refused",
|
||||
want: "",
|
||||
wantErr: ErrRDNSFailed,
|
||||
locUpstream: refusingUpstream,
|
||||
locUpstream: refusingHdlr,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
@@ -1611,23 +1603,28 @@ func TestServer_Exchange(t *testing.T) {
|
||||
name: "zero_ttl",
|
||||
want: localDomainHost,
|
||||
wantErr: nil,
|
||||
locUpstream: zeroTTLUps,
|
||||
locUpstream: zeroTTLHdlr,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
pcfg := proxy.Config{
|
||||
UpstreamConfig: &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{tc.locUpstream},
|
||||
},
|
||||
}
|
||||
srv.localResolvers = &proxy.Proxy{
|
||||
Config: pcfg,
|
||||
}
|
||||
require.NoError(t, srv.localResolvers.Init())
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, tc.locUpstream).String()
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{upsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
UsePrivateRDNS: true,
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
|
||||
host, ttl, eerr := srv.Exchange(tc.req)
|
||||
|
||||
require.ErrorIs(t, eerr, tc.wantErr)
|
||||
@@ -1637,8 +1634,17 @@ func TestServer_Exchange(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("resolving_disabled", func(t *testing.T) {
|
||||
srv.conf.UsePrivateRDNS = false
|
||||
t.Cleanup(func() { srv.conf.UsePrivateRDNS = true })
|
||||
srv := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{upsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
LocalPTRResolvers: []string{},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
|
||||
host, _, eerr := srv.Exchange(localIP)
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
})
|
||||
|
||||
makeQ := func(qtype rules.RRType) (req *dns.Msg) {
|
||||
return &dns.Msg{
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
@@ -12,7 +13,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// beforeRequestHandler is the handler that is called before any other
|
||||
|
||||
@@ -6,16 +6,17 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// jsonDNSConfig is the JSON representation of the DNS server configuration.
|
||||
@@ -294,7 +295,7 @@ func (req *jsonDNSConfig) checkFallbacks() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = ValidateUpstreams(*req.Fallbacks)
|
||||
_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback servers: %w", err)
|
||||
}
|
||||
@@ -344,7 +345,7 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
// validateUpstreamDNSServers returns an error if any field of req is invalid.
|
||||
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
|
||||
if req.Upstreams != nil {
|
||||
err = ValidateUpstreams(*req.Upstreams)
|
||||
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("upstream servers: %w", err)
|
||||
}
|
||||
@@ -580,9 +581,6 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
|
||||
|
||||
opts := &upstream.Options{
|
||||
|
||||
@@ -83,7 +83,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
ConfigModified: func() {},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
s.sysResolvers = &emptySysResolvers{}
|
||||
|
||||
require.NoError(t, s.Start())
|
||||
@@ -164,7 +164,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
ConfigModified: func() {},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
s.sysResolvers = &emptySysResolvers{}
|
||||
|
||||
defaultConf := s.conf
|
||||
@@ -223,8 +223,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "upstream_dns_bad",
|
||||
wantSet: `validating dns config: ` +
|
||||
`upstream servers: validating upstream "!!!": not an ip:port`,
|
||||
wantSet: `validating dns config: upstream servers: parsing error at index 0: ` +
|
||||
`cannot prepare the upstream: invalid address !!!: bad hostname "!!!": ` +
|
||||
`bad top-level domain name label "!!!": bad top-level domain name label rune '!'`,
|
||||
}, {
|
||||
name: "bootstraps_bad",
|
||||
wantSet: `validating dns config: checking bootstrap a: not a bootstrap: ParseAddr("a"): ` +
|
||||
@@ -313,98 +314,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreams(t *testing.T) {
|
||||
const sdnsStamp = `sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_J` +
|
||||
`S3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczE` +
|
||||
`uYWRndWFyZC5jb20`
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErr string
|
||||
set []string
|
||||
}{{
|
||||
name: "empty",
|
||||
wantErr: ``,
|
||||
set: nil,
|
||||
}, {
|
||||
name: "comment",
|
||||
wantErr: ``,
|
||||
set: []string{"# comment"},
|
||||
}, {
|
||||
name: "no_default",
|
||||
wantErr: `no default upstreams specified`,
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]" + sdnsStamp,
|
||||
},
|
||||
}, {
|
||||
name: "with_default",
|
||||
wantErr: ``,
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]" + sdnsStamp,
|
||||
"8.8.8.8",
|
||||
},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "dhcp://fake.dns": bad protocol "dhcp"`,
|
||||
set: []string{"dhcp://fake.dns"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "1.2.3.4.5": not an ip:port`,
|
||||
set: []string{"1.2.3.4.5"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "123.3.7m": not an ip:port`,
|
||||
set: []string{"123.3.7m"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `splitting upstream line "[/host.com]tls://dns.adguard.com": ` +
|
||||
`missing separator`,
|
||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "[host.ru]#": not an ip:port`,
|
||||
set: []string{"[host.ru]#"},
|
||||
}, {
|
||||
name: "valid_default",
|
||||
wantErr: ``,
|
||||
set: []string{
|
||||
"1.1.1.1",
|
||||
"tls://1.1.1.1",
|
||||
"https://dns.adguard.com/dns-query",
|
||||
sdnsStamp,
|
||||
"udp://dns.google",
|
||||
"udp://8.8.8.8",
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]" + sdnsStamp,
|
||||
"[/пример.рф/]8.8.8.8",
|
||||
},
|
||||
}, {
|
||||
name: "bad_domain",
|
||||
wantErr: `splitting upstream line "[/!/]8.8.8.8": domain at index 0: ` +
|
||||
`bad domain name "!": bad top-level domain name label "!": ` +
|
||||
`bad top-level domain name label rune '!'`,
|
||||
set: []string{"[/!/]8.8.8.8"},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateUpstreams(tc.set)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
@@ -509,6 +418,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
},
|
||||
},
|
||||
&aghtest.FSWatcher{
|
||||
OnStart: func() (_ error) { panic("not implemented") },
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
OnAdd: func(_ string) (err error) { return nil },
|
||||
OnClose: func() (err error) { return nil },
|
||||
@@ -529,7 +439,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
})
|
||||
srv.etcHosts = upstream.NewHostsResolver(hc)
|
||||
startDeferStop(t, srv)
|
||||
|
||||
|
||||
@@ -2,13 +2,13 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// makeResponse creates a DNS response by req and sets necessary flags. It also
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
@@ -87,7 +86,7 @@ func TestServer_ProcessInitial(t *testing.T) {
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, c, nil)
|
||||
}, c)
|
||||
|
||||
var gotAddr netip.Addr
|
||||
s.addrProc = &aghtest.AddressProcessor{
|
||||
@@ -188,7 +187,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, c, nil)
|
||||
}, c)
|
||||
|
||||
resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns)
|
||||
dctx := &dnsContext{
|
||||
@@ -248,9 +247,9 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
host string
|
||||
want []*dns.SVCB
|
||||
wantRes resultCode
|
||||
portDoH int
|
||||
portDoT int
|
||||
portDoQ int
|
||||
addrsDoH []*net.TCPAddr
|
||||
addrsDoT []*net.TCPAddr
|
||||
addrsDoQ []*net.UDPAddr
|
||||
qtype uint16
|
||||
ddrEnabled bool
|
||||
}{{
|
||||
@@ -259,14 +258,14 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
host: testQuestionTarget,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8043,
|
||||
addrsDoH: []*net.TCPAddr{{Port: 8043}},
|
||||
}, {
|
||||
name: "pass_qtype",
|
||||
wantRes: resultCodeFinish,
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeA,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8043,
|
||||
addrsDoH: []*net.TCPAddr{{Port: 8043}},
|
||||
}, {
|
||||
name: "pass_disabled_tls",
|
||||
wantRes: resultCodeFinish,
|
||||
@@ -279,7 +278,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: false,
|
||||
portDoH: 8043,
|
||||
addrsDoH: []*net.TCPAddr{{Port: 8043}},
|
||||
}, {
|
||||
name: "dot",
|
||||
wantRes: resultCodeFinish,
|
||||
@@ -287,7 +286,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoT: 8043,
|
||||
addrsDoT: []*net.TCPAddr{{Port: 8043}},
|
||||
}, {
|
||||
name: "doh",
|
||||
wantRes: resultCodeFinish,
|
||||
@@ -295,7 +294,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8044,
|
||||
addrsDoH: []*net.TCPAddr{{Port: 8044}},
|
||||
}, {
|
||||
name: "doq",
|
||||
wantRes: resultCodeFinish,
|
||||
@@ -303,7 +302,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoQ: 8042,
|
||||
addrsDoQ: []*net.UDPAddr{{Port: 8042}},
|
||||
}, {
|
||||
name: "dot_doh",
|
||||
wantRes: resultCodeFinish,
|
||||
@@ -311,13 +310,35 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoT: 8043,
|
||||
portDoH: 8044,
|
||||
addrsDoT: []*net.TCPAddr{{Port: 8043}},
|
||||
addrsDoH: []*net.TCPAddr{{Port: 8044}},
|
||||
}}
|
||||
|
||||
_, certPem, keyPem := createServerTLSConfig(t)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled)
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
Config: Config{
|
||||
HandleDDR: tc.ddrEnabled,
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
TLSConfig: TLSConfig{
|
||||
ServerName: ddrTestDomainName,
|
||||
CertificateChainData: certPem,
|
||||
PrivateKeyData: keyPem,
|
||||
TLSListenAddrs: tc.addrsDoT,
|
||||
HTTPSListenAddrs: tc.addrsDoH,
|
||||
QUICListenAddrs: tc.addrsDoQ,
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
// TODO(e.burkov): Generate a certificate actually containing the
|
||||
// IP addresses.
|
||||
s.conf.hasIPAddrs = true
|
||||
|
||||
req := createTestMessageWithType(tc.host, tc.qtype)
|
||||
|
||||
@@ -358,41 +379,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
|
||||
return f
|
||||
}
|
||||
|
||||
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
|
||||
t.Helper()
|
||||
|
||||
s = &Server{
|
||||
dnsFilter: createTestDNSFilter(t),
|
||||
dnsProxy: &proxy.Proxy{
|
||||
Config: proxy.Config{},
|
||||
},
|
||||
conf: ServerConfig{
|
||||
Config: Config{
|
||||
HandleDDR: ddrEnabled,
|
||||
},
|
||||
TLSConfig: TLSConfig{
|
||||
ServerName: ddrTestDomainName,
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
}
|
||||
|
||||
if portDoT > 0 {
|
||||
s.dnsProxy.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
|
||||
s.conf.hasIPAddrs = true
|
||||
}
|
||||
|
||||
if portDoQ > 0 {
|
||||
s.dnsProxy.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
|
||||
}
|
||||
|
||||
if portDoH > 0 {
|
||||
s.conf.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
s := &Server{
|
||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
@@ -680,13 +666,16 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
intPTRAnswer = "some.local-client."
|
||||
)
|
||||
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
)
|
||||
|
||||
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
@@ -696,12 +685,14 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
// TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true.
|
||||
// Improve Config declaration for tests.
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{localUpsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, ups)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups}
|
||||
UsePrivateRDNS: true,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
startDeferStop(t, s)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -764,6 +755,16 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
const locDomain = "some.local."
|
||||
const reqAddr = "1.1.168.192.in-addr.arpa."
|
||||
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
|
||||
s := createTestServer(
|
||||
t,
|
||||
&filtering.Config{
|
||||
@@ -776,14 +777,10 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
UsePrivateRDNS: true,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
}),
|
||||
)
|
||||
|
||||
var proxyCtx *proxy.DNSContext
|
||||
|
||||
@@ -21,7 +21,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}, nil)
|
||||
})
|
||||
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
|
||||
@@ -2,10 +2,9 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
@@ -16,29 +15,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const (
|
||||
// errNotDomainSpecific is returned when the upstream should be
|
||||
// domain-specific, but isn't.
|
||||
errNotDomainSpecific errors.Error = "not a domain-specific upstream"
|
||||
|
||||
// errMissingSeparator is returned when the domain-specific part of the
|
||||
// upstream configuration line isn't closed.
|
||||
errMissingSeparator errors.Error = "missing separator"
|
||||
|
||||
// errDupSeparator is returned when the domain-specific part of the upstream
|
||||
// configuration line contains more than one ending separator.
|
||||
errDupSeparator errors.Error = "duplicated separator"
|
||||
|
||||
// errNoDefaultUpstreams is returned when there are no default upstreams
|
||||
// specified in the upstream configuration.
|
||||
errNoDefaultUpstreams errors.Error = "no default upstreams specified"
|
||||
|
||||
// errWrongResponse is returned when the checked upstream replies in an
|
||||
// unexpected way.
|
||||
errWrongResponse errors.Error = "wrong response"
|
||||
)
|
||||
|
||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||
@@ -199,84 +175,12 @@ func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
||||
// configuration or nil if it can't be built.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
|
||||
// slice already so that this function may be considered useless.
|
||||
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
|
||||
// No need to validate comments and empty lines.
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
// Consider this case valid since it means the default server should be
|
||||
// used.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
err = validateUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: net.DefaultResolver,
|
||||
Timeout: DefaultTimeout,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
} else if len(conf.Upstreams) == 0 {
|
||||
return nil, errNoDefaultUpstreams
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// validateUpstreamConfig validates each upstream from the upstream
|
||||
// configuration and returns an error if any upstream is invalid.
|
||||
//
|
||||
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
|
||||
func validateUpstreamConfig(conf []string) (err error) {
|
||||
for _, u := range conf {
|
||||
var ups []string
|
||||
var isSpecific bool
|
||||
ups, isSpecific, err = splitUpstreamLine(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
for _, addr := range ups {
|
||||
_, err = validateUpstream(addr, isSpecific)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUpstreams validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified.
|
||||
//
|
||||
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
|
||||
func ValidateUpstreams(upstreams []string) (err error) {
|
||||
_, err = newUpstreamConfig(upstreams)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. privateNets must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||
conf, err := newUpstreamConfig(upstreams)
|
||||
conf, err := proxy.ParseUpstreamsConfig(upstreams, &upstream.Options{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating config: %w", err)
|
||||
}
|
||||
@@ -308,66 +212,3 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
|
||||
|
||||
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
|
||||
}
|
||||
|
||||
// protocols are the supported URL schemes for upstreams.
|
||||
var protocols = []string{"h3", "https", "quic", "sdns", "tcp", "tls", "udp"}
|
||||
|
||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||
// upstream configuration. useDefault is true if the upstream is
|
||||
// domain-specific and is configured to point at the default upstream server
|
||||
// which is validated separately. The upstream is considered domain-specific
|
||||
// only if domains is at least not nil.
|
||||
func validateUpstream(u string, isSpecific bool) (useDefault bool, err error) {
|
||||
// The special server address '#' means that default server must be used.
|
||||
if useDefault = u == "#" && isSpecific; useDefault {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
// Check if the upstream has a valid protocol prefix.
|
||||
//
|
||||
// TODO(e.burkov): Validate the domain name.
|
||||
if proto, _, ok := strings.Cut(u, "://"); ok {
|
||||
if !slices.Contains(protocols, proto) {
|
||||
return false, fmt.Errorf("bad protocol %q", proto)
|
||||
}
|
||||
} else if _, err = netip.ParseAddr(u); err == nil {
|
||||
return false, nil
|
||||
} else if _, err = netip.ParseAddrPort(u); err == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
// splitUpstreamLine returns the upstreams and the specified domains. domains
|
||||
// is nil when the upstream is not domains-specific. Otherwise it may also be
|
||||
// empty.
|
||||
func splitUpstreamLine(upstreamStr string) (upstreams []string, isSpecific bool, err error) {
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return []string{upstreamStr}, false, nil
|
||||
}
|
||||
|
||||
defer func() { err = errors.Annotate(err, "splitting upstream line %q: %w", upstreamStr) }()
|
||||
|
||||
doms, ups, found := strings.Cut(upstreamStr[2:], "/]")
|
||||
if !found {
|
||||
return nil, false, errMissingSeparator
|
||||
} else if strings.Contains(ups, "/]") {
|
||||
return nil, false, errDupSeparator
|
||||
}
|
||||
|
||||
for i, host := range strings.Split(doms, "/") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
isSpecific = true
|
||||
}
|
||||
|
||||
return strings.Fields(ups), isSpecific, nil
|
||||
}
|
||||
|
||||
@@ -100,8 +100,7 @@ func TestUpstreamConfigValidator(t *testing.T) {
|
||||
name: "bad_specification",
|
||||
general: []string{"[/domain.example/]/]1.2.3.4"},
|
||||
want: map[string]string{
|
||||
"[/domain.example/]/]1.2.3.4": `splitting upstream line ` +
|
||||
`"[/domain.example/]/]1.2.3.4": duplicated separator`,
|
||||
"[/domain.example/]/]1.2.3.4": generalTextLabel + " 1: parsing error",
|
||||
},
|
||||
}, {
|
||||
name: "all_different",
|
||||
@@ -120,23 +119,9 @@ func TestUpstreamConfigValidator(t *testing.T) {
|
||||
fallback: []string{"[/example/" + goodUps},
|
||||
private: []string{"[/example//bad.123/]" + goodUps},
|
||||
want: map[string]string{
|
||||
`[/example/]/]` + goodUps: `splitting upstream line ` +
|
||||
`"[/example/]/]` + goodUps + `": duplicated separator`,
|
||||
`[/example/` + goodUps: `splitting upstream line ` +
|
||||
`"[/example/` + goodUps + `": missing separator`,
|
||||
`[/example//bad.123/]` + goodUps: `splitting upstream line ` +
|
||||
`"[/example//bad.123/]` + goodUps + `": domain at index 2: ` +
|
||||
`bad domain name "bad.123": ` +
|
||||
`bad top-level domain name label "123": all octets are numeric`,
|
||||
},
|
||||
}, {
|
||||
name: "non-specific_default",
|
||||
general: []string{
|
||||
"#",
|
||||
"[/example/]#",
|
||||
},
|
||||
want: map[string]string{
|
||||
"#": "not a domain-specific upstream",
|
||||
"[/example/]/]" + goodUps: generalTextLabel + " 1: parsing error",
|
||||
"[/example/" + goodUps: fallbackTextLabel + " 1: parsing error",
|
||||
"[/example//bad.123/]" + goodUps: privateTextLabel + " 1: parsing error",
|
||||
},
|
||||
}, {
|
||||
name: "bad_proto",
|
||||
@@ -144,7 +129,15 @@ func TestUpstreamConfigValidator(t *testing.T) {
|
||||
"bad://1.2.3.4",
|
||||
},
|
||||
want: map[string]string{
|
||||
"bad://1.2.3.4": `bad protocol "bad"`,
|
||||
"bad://1.2.3.4": generalTextLabel + " 1: parsing error",
|
||||
},
|
||||
}, {
|
||||
name: "truncated_line",
|
||||
general: []string{
|
||||
"This is a very long line. It will cause a parsing error and will be truncated here.",
|
||||
},
|
||||
want: map[string]string{
|
||||
"This is a very long line. It will cause a parsing error and will be truncated …": "upstream_dns 1: parsing error",
|
||||
},
|
||||
}}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user