all: sync with master; upd chlog
This commit is contained in:
@@ -614,6 +614,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
// setupFallbackDNS initializes the fallback DNS servers.
|
||||
func (s *Server) setupFallbackDNS() (err error) {
|
||||
fallbacks := s.conf.FallbackDNS
|
||||
fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty)
|
||||
if len(fallbacks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -372,6 +372,27 @@ func TestServer_timeout(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Prepare_fallbacks(t *testing.T) {
|
||||
srvConf := &ServerConfig{
|
||||
Config: Config{
|
||||
FallbackDNS: []string{
|
||||
"#tls://1.1.1.1",
|
||||
"8.8.8.8",
|
||||
},
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}
|
||||
|
||||
s, err := NewServer(DNSCreateParams{})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Prepare(srvConf)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, s.dnsProxy.Fallbacks)
|
||||
|
||||
assert.Len(t, s.dnsProxy.Fallbacks.Upstreams, 1)
|
||||
}
|
||||
|
||||
func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
@@ -140,15 +141,15 @@ func (s *Server) filterRewritten(
|
||||
|
||||
// checkHostRules checks the host against filters. It is safe for concurrent
|
||||
// use.
|
||||
func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) (
|
||||
r *filtering.Result,
|
||||
err error,
|
||||
) {
|
||||
func (s *Server) checkHostRules(
|
||||
host string,
|
||||
rrtype rules.RRType,
|
||||
setts *filtering.Settings,
|
||||
) (r *filtering.Result, err error) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
var res filtering.Result
|
||||
res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
|
||||
res, err := s.dnsFilter.CheckHostRules(host, rrtype, setts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -156,20 +157,21 @@ func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Set
|
||||
return &res, err
|
||||
}
|
||||
|
||||
// filterDNSResponse checks each resource record of the response's answer
|
||||
// section from pctx and returns a non-nil res if at least one of canonical
|
||||
// names or IP addresses in it matches the filtering rules.
|
||||
func (s *Server) filterDNSResponse(
|
||||
pctx *proxy.DNSContext,
|
||||
setts *filtering.Settings,
|
||||
) (res *filtering.Result, err error) {
|
||||
// filterDNSResponse checks each resource record of answer section of
|
||||
// dctx.proxyCtx.Res. It sets dctx.result and dctx.origResp if at least one of
|
||||
// canonical names, IP addresses, or HTTPS RR hints in it matches the filtering
|
||||
// rules, as well as sets dctx.proxyCtx.Res to the filtered response.
|
||||
func (s *Server) filterDNSResponse(dctx *dnsContext) (err error) {
|
||||
setts := dctx.setts
|
||||
if !setts.FilteringEnabled {
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, a := range pctx.Res.Answer {
|
||||
var res *filtering.Result
|
||||
pctx := dctx.proxyCtx
|
||||
for i, a := range pctx.Res.Answer {
|
||||
host := ""
|
||||
var rrtype uint16
|
||||
var rrtype rules.RRType
|
||||
switch a := a.(type) {
|
||||
case *dns.CNAME:
|
||||
host = strings.TrimSuffix(a.Target, ".")
|
||||
@@ -195,18 +197,19 @@ func (s *Server) filterDNSResponse(
|
||||
log.Debug("dnsforward: checked %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if res == nil {
|
||||
continue
|
||||
} else if res.IsFiltered {
|
||||
return fmt.Errorf("filtering answer at index %d: %w", i, err)
|
||||
} else if res != nil && res.IsFiltered {
|
||||
dctx.result = res
|
||||
dctx.origResp = pctx.Res
|
||||
pctx.Res = s.genDNSFilterMessage(pctx, res)
|
||||
|
||||
log.Debug("dnsforward: matched %q by response: %q", pctx.Req.Question[0].Name, host)
|
||||
|
||||
return res, nil
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeIPv6Hints deletes IPv6 hints from RR values.
|
||||
|
||||
@@ -328,26 +328,34 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
|
||||
}
|
||||
|
||||
res, rErr := s.filterDNSResponse(pctx, &filtering.Settings{
|
||||
ProtectionEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
})
|
||||
require.NoError(t, rErr)
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: pctx,
|
||||
setts: &filtering.Settings{
|
||||
ProtectionEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
fltErr := s.filterDNSResponse(dctx)
|
||||
require.NoError(t, fltErr)
|
||||
|
||||
res := dctx.result
|
||||
if tc.wantRule == "" {
|
||||
assert.Nil(t, res)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
want := &filtering.Result{
|
||||
wantResult := &filtering.Result{
|
||||
IsFiltered: true,
|
||||
Reason: filtering.FilteredBlockList,
|
||||
Rules: []*filtering.ResultRule{{
|
||||
Text: tc.wantRule,
|
||||
}},
|
||||
}
|
||||
assert.Equal(t, want, res)
|
||||
|
||||
assert.Equal(t, wantResult, res)
|
||||
assert.Equal(t, resp, dctx.origResp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -749,6 +749,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
|
||||
upsNum := len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams)
|
||||
|
||||
@@ -577,6 +577,14 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
name: "fallback_broken",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"fallback_dns": []string{goodUps, "#this.is.comment"},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
},
|
||||
name: "fallback_comment_mix",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -671,11 +671,11 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
// Apply filtering logic
|
||||
func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
||||
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing filtering before req")
|
||||
defer log.Debug("dnsforward: finished processing filtering before req")
|
||||
|
||||
if ctx.proxyCtx.Res != nil {
|
||||
if dctx.proxyCtx.Res != nil {
|
||||
// Go on since the response is already set.
|
||||
return resultCodeSuccess
|
||||
}
|
||||
@@ -684,8 +684,8 @@ func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode)
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
var err error
|
||||
if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
|
||||
ctx.err = err
|
||||
if dctx.result, err = s.filterDNSRequest(dctx); err != nil {
|
||||
dctx.err = err
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
@@ -857,7 +857,6 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
|
||||
log.Debug("dnsforward: started processing filtering after resp")
|
||||
defer log.Debug("dnsforward: finished processing filtering after resp")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
switch res := dctx.result; res.Reason {
|
||||
case filtering.NotFilteredAllowList:
|
||||
return resultCodeSuccess
|
||||
@@ -871,6 +870,7 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion
|
||||
if len(pctx.Res.Answer) > 0 {
|
||||
rr := s.genAnswerCNAME(pctx.Req, res.CanonName)
|
||||
@@ -880,13 +880,13 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
|
||||
|
||||
return resultCodeSuccess
|
||||
default:
|
||||
return s.filterAfterResponse(dctx, pctx)
|
||||
return s.filterAfterResponse(dctx)
|
||||
}
|
||||
}
|
||||
|
||||
// filterAfterResponse returns the result of filtering the response that wasn't
|
||||
// explicitly allowed or rewritten.
|
||||
func (s *Server) filterAfterResponse(dctx *dnsContext, pctx *proxy.DNSContext) (res resultCode) {
|
||||
func (s *Server) filterAfterResponse(dctx *dnsContext) (res resultCode) {
|
||||
// Check the response only if it's from an upstream. Don't check the
|
||||
// response if the protection is disabled since dnsrewrite rules aren't
|
||||
// applied to it anyway.
|
||||
@@ -894,17 +894,12 @@ func (s *Server) filterAfterResponse(dctx *dnsContext, pctx *proxy.DNSContext) (
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
result, err := s.filterDNSResponse(pctx, dctx.setts)
|
||||
err := s.filterDNSResponse(dctx)
|
||||
if err != nil {
|
||||
dctx.err = err
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
dctx.result = result
|
||||
dctx.origResp = pctx.Res
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user