Pull request: 1558 enable dnsrewrites on disabled protection

Merge in DNS/adguard-home from 1558-always-rewrite to master

Squashed commit of the following:

commit b8508b3b5fb688cad273a9259c09ccfc07948b2f
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Oct 20 19:17:22 2021 +0300

    all: imp log of changes

commit 97e3649b670786a2936e368a9505faf52f8e8804
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Oct 18 13:18:15 2021 +0300

    all: enable dnsrewrites on disabled protection
This commit is contained in:
Eugene Burkov
2021-10-20 19:52:13 +03:00
parent d7aafa7dc6
commit b8c4651dec
10 changed files with 178 additions and 132 deletions

View File

@@ -90,7 +90,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
s.processRestrictLocal,
s.processInternalIPAddrs,
s.processClientID,
processFilteringBeforeRequest,
s.processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream,
processDNSSECAfterResponse,
@@ -468,19 +468,18 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
}
// Apply filtering logic
func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if d.Res != nil {
return resultCodeSuccess // response is already set - nothing to do
func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
if ctx.proxyCtx.Res != nil {
// Go on since the response is already set.
return resultCodeSuccess
}
s.serverLock.RLock()
defer s.serverLock.RUnlock()
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
if !ctx.protectionEnabled {
ctx.protectionEnabled = s.conf.ProtectionEnabled
if s.dnsFilter == nil {
return resultCodeSuccess
}
@@ -489,8 +488,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
}
var err error
ctx.result, err = s.filterDNSRequest(ctx)
if err != nil {
if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
ctx.err = err
return resultCodeError
@@ -608,48 +606,50 @@ func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
res := ctx.result
var err error
switch res.Reason {
case filtering.Rewritten,
switch res := ctx.result; res.Reason {
case filtering.NotFilteredAllowList:
// Go on.
case
filtering.Rewritten,
filtering.RewrittenRule:
if len(ctx.origQuestion.Name) == 0 {
// origQuestion is set in case we get only CNAME without IP from rewrites table
// origQuestion is set in case we get only CNAME without IP from
// rewrites table.
break
}
d.Req.Question[0] = ctx.origQuestion
d.Res.Question[0] = ctx.origQuestion
if len(d.Res.Answer) != 0 {
answer := []dns.RR{}
answer = append(answer, s.genAnswerCNAME(d.Req, res.CanonName))
answer = append(answer, d.Res.Answer...)
d.Req.Question[0], d.Res.Question[0] = ctx.origQuestion, ctx.origQuestion
if len(d.Res.Answer) > 0 {
answer := append([]dns.RR{s.genAnswerCNAME(d.Req, res.CanonName)}, d.Res.Answer...)
d.Res.Answer = answer
}
case filtering.NotFilteredAllowList:
// nothing
default:
if !ctx.protectionEnabled || // filters are disabled: there's nothing to check for
!ctx.responseFromUpstream { // only check response if it's from an upstream server
// Check the response only if the it's from an upstream. Don't check
// the response if the protection is disabled since dnsrewrite rules
// aren't applied to it anyway.
if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil {
break
}
origResp2 := d.Res
ctx.result, err = s.filterDNSResponse(ctx)
origResp := d.Res
result, err := s.filterDNSResponse(ctx)
if err != nil {
ctx.err = err
return resultCodeError
}
if ctx.result != nil {
ctx.origResp = origResp2 // matched by response
} else {
ctx.result = &filtering.Result{}
if result != nil {
ctx.result = result
ctx.origResp = origResp
}
}
if ctx.result == nil {
ctx.result = &filtering.Result{}
}
return resultCodeSuccess
}

View File

@@ -909,6 +909,7 @@ func TestRewrite(t *testing.T) {
}},
}
f := filtering.New(c, nil)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
@@ -945,45 +946,56 @@ func TestRewrite(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("test.com.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String())
require.NoError(t, err)
subTestFunc := func(t *testing.T) {
req := createTestMessageWithType("test.com.", dns.TypeA)
reply, eerr := dns.Exchange(req, addr.String())
require.NoError(t, eerr)
require.Len(t, reply.Answer, 1)
require.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A)
require.True(t, ok)
a, ok := reply.Answer[0].(*dns.A)
require.True(t, ok)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String())
require.NoError(t, err)
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
assert.Empty(t, reply.Answer)
assert.Empty(t, reply.Answer)
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String())
require.NoError(t, err)
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
require.Len(t, reply.Answer, 2)
require.Len(t, reply.Answer, 2)
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String())
require.NoError(t, err)
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
// The original question is restored.
require.Len(t, reply.Question, 1)
// The original question is restored.
require.Len(t, reply.Question, 1)
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
require.Len(t, reply.Answer, 2)
require.Len(t, reply.Answer, 2)
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
}
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf.ProtectionEnabled = &val
s.setConfig(conf)
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
}
}
func publicKey(priv interface{}) interface{} {
@@ -1092,9 +1104,10 @@ func TestPTRResponseFromHosts(t *testing.T) {
require.ErrorIs(t, hc.Close(), closeCalled)
})
c := filtering.Config{
flt := filtering.New(&filtering.Config{
EtcHosts: hc,
}
}, nil)
flt.SetEnabled(true)
var snd *aghnet.SubnetDetector
snd, err = aghnet.NewSubnetDetector()
@@ -1104,7 +1117,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: filtering.New(&c, nil),
DNSFilter: flt,
SubnetDetector: snd,
})
require.NoError(t, err)
@@ -1112,32 +1125,41 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
err = s.Prepare(nil)
require.NoError(t, err)
err = s.Start()
require.NoError(t, err)
t.Cleanup(func() {
s.Close()
})
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
subTestFunc := func(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
resp, err := dns.Exchange(req, addr.String())
require.NoError(t, err)
resp, eerr := dns.Exchange(req, addr.String())
require.NoError(t, eerr)
require.Lenf(t, resp.Answer, 1, "%#v", resp)
require.Len(t, resp.Answer, 1)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok)
assert.Equal(t, "host.", ptr.Ptr)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok)
assert.Equal(t, "host.", ptr.Ptr)
}
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf.ProtectionEnabled = &val
s.setConfig(conf)
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
}
}
func TestNewServer(t *testing.T) {

View File

@@ -52,6 +52,7 @@ func (s *Server) beforeRequestHandler(
// the client's IP address and ID, if any, from ctx.
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
setts := s.dnsFilter.GetConfig()
setts.ProtectionEnabled = ctx.protectionEnabled
if s.conf.FilterHandler != nil {
ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr)
s.conf.FilterHandler(ip, ctx.clientID, &setts)
@@ -65,32 +66,31 @@ func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.S
func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
d := ctx.proxyCtx
req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".")
res, err := s.dnsFilter.CheckHost(host, req.Question[0].Qtype, ctx.setts)
if err != nil {
// Return immediately if there's an error
return nil, fmt.Errorf("filtering failed to check host %q: %w", host, err)
} else if res.IsFiltered {
log.Tracef("Host %s is filtered, reason - %q, matched rule: %q", host, res.Reason, res.Rules[0].Text)
q := req.Question[0]
host := strings.TrimSuffix(q.Name, ".")
res, err := s.dnsFilter.CheckHost(host, q.Qtype, ctx.setts)
switch {
case err != nil:
return nil, fmt.Errorf("failed to check host %q: %w", host, err)
case res.IsFiltered:
log.Tracef("host %q is filtered, reason %q, rule: %q", host, res.Reason, res.Rules[0].Text)
d.Res = s.genDNSFilterMessage(d, &res)
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
res.CanonName != "" &&
len(res.IPList) == 0 {
// Resolve the new canonical name, not the original host
// name. The original question is readded in
// processFilteringAfterResponse.
ctx.origQuestion = req.Question[0]
len(res.IPList) == 0:
// Resolve the new canonical name, not the original host name. The
// original question is readded in processFilteringAfterResponse.
ctx.origQuestion = q
req.Question[0].Name = dns.Fqdn(res.CanonName)
} else if res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0 {
case res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0:
resp := s.makeResponse(req)
hdr := dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
}
for _, h := range res.ReverseHosts {
hdr := dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
}
ptr := &dns.PTR{
Hdr: hdr,
Ptr: h,
@@ -100,7 +100,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
}
d.Res = resp
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts) {
case res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts):
resp := s.makeResponse(req)
name := host
@@ -110,11 +110,12 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
}
for _, ip := range res.IPList {
if req.Question[0].Qtype == dns.TypeA {
switch q.Qtype {
case dns.TypeA:
a := s.genAnswerA(req, ip.To4())
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
} else if req.Question[0].Qtype == dns.TypeAAAA {
case dns.TypeAAAA:
a := s.genAnswerAAAA(req, ip)
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
@@ -122,9 +123,8 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
}
d.Res = resp
} else if res.Reason == filtering.RewrittenRule {
err = s.filterDNSRewrite(req, res, d)
if err != nil {
case res.Reason == filtering.RewrittenRule:
if err = s.filterDNSRewrite(req, res, d); err != nil {
return nil, err
}
}
@@ -179,6 +179,7 @@ func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) {
continue
}
host = strings.TrimSuffix(host, ".")
res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
if err != nil {
return nil, err