Pull request: all: add srv handling to dnsrewrite filters

Closes #2498.
Updates #2533.

Squashed commit of the following:

commit 452e0e7d281c1f10bef069bf7a73205266b8f1e0
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Mar 12 19:33:18 2021 +0300

    all: add srv handling to dnsrewrite filters
This commit is contained in:
Ainar Garipov
2021-03-15 13:08:13 +03:00
parent 4c6bf68d4d
commit d970b79f2b
6 changed files with 72 additions and 77 deletions

View File

@@ -20,7 +20,8 @@ func (s *Server) filterDNSRewriteResponse(req *dns.Msg, rr rules.RRType, v rules
// the answer generation logic from the Server.
switch rr {
case dns.TypeA, dns.TypeAAAA:
case dns.TypeA,
dns.TypeAAAA:
ip, ok := v.(net.IP)
if !ok {
return nil, fmt.Errorf("value for rr type %d has type %T, not net.IP", rr, v)
@@ -62,6 +63,13 @@ func (s *Server) filterDNSRewriteResponse(req *dns.Msg, rr rules.RRType, v rules
}
return s.genAnswerSVCB(req, svcb), nil
case dns.TypeSRV:
srv, ok := v.(*rules.DNSSRV)
if !ok {
return nil, fmt.Errorf("value for rr type %d has type %T, not *rules.DNSSRV", rr, v)
}
return s.genAnswerSRV(req, srv), nil
default:
log.Debug("don't know how to handle dns rr type %d, skipping", rr)

View File

@@ -14,18 +14,24 @@ import (
func TestServer_FilterDNSRewrite(t *testing.T) {
// Helper data.
const domain = "example.com"
ip4 := net.IP{127, 0, 0, 1}
ip6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
mx := &rules.DNSMX{
mxVal := &rules.DNSMX{
Exchange: "mail.example.com",
Preference: 32,
}
svcb := &rules.DNSSVCB{
svcbVal := &rules.DNSSVCB{
Params: map[string]string{"alpn": "h3"},
Target: "example.com",
Target: domain,
Priority: 32,
}
const domain = "example.com"
srvVal := &rules.DNSSRV{
Priority: 32,
Weight: 60,
Port: 8080,
Target: domain,
}
// Helper functions and entities.
srv := &Server{}
@@ -125,7 +131,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
t.Run("noerror_mx", func(t *testing.T) {
req := makeQ(dns.TypeMX)
res := makeRes(dns.RcodeSuccess, dns.TypeMX, mx)
res := makeRes(dns.RcodeSuccess, dns.TypeMX, mxVal)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
@@ -136,13 +142,13 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
ans, ok := d.Res.Answer[0].(*dns.MX)
require.True(t, ok)
assert.Equal(t, mx.Exchange, ans.Mx)
assert.Equal(t, mx.Preference, ans.Preference)
assert.Equal(t, mxVal.Exchange, ans.Mx)
assert.Equal(t, mxVal.Preference, ans.Preference)
})
t.Run("noerror_svcb", func(t *testing.T) {
req := makeQ(dns.TypeSVCB)
res := makeRes(dns.RcodeSuccess, dns.TypeSVCB, svcb)
res := makeRes(dns.RcodeSuccess, dns.TypeSVCB, svcbVal)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
@@ -154,14 +160,14 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
require.True(t, ok)
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
assert.Equal(t, svcb.Target, ans.Target)
assert.Equal(t, svcb.Priority, ans.Priority)
assert.Equal(t, svcbVal.Params["alpn"], ans.Value[0].String())
assert.Equal(t, svcbVal.Target, ans.Target)
assert.Equal(t, svcbVal.Priority, ans.Priority)
})
t.Run("noerror_https", func(t *testing.T) {
req := makeQ(dns.TypeHTTPS)
res := makeRes(dns.RcodeSuccess, dns.TypeHTTPS, svcb)
res := makeRes(dns.RcodeSuccess, dns.TypeHTTPS, svcbVal)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
@@ -173,8 +179,27 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
require.True(t, ok)
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
assert.Equal(t, svcb.Target, ans.Target)
assert.Equal(t, svcb.Priority, ans.Priority)
assert.Equal(t, svcbVal.Params["alpn"], ans.Value[0].String())
assert.Equal(t, svcbVal.Target, ans.Target)
assert.Equal(t, svcbVal.Priority, ans.Priority)
})
t.Run("noerror_srv", func(t *testing.T) {
req := makeQ(dns.TypeSRV)
res := makeRes(dns.RcodeSuccess, dns.TypeSRV, srvVal)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1)
ans, ok := d.Res.Answer[0].(*dns.SRV)
require.True(t, ok)
assert.Equal(t, srvVal.Priority, ans.Priority)
assert.Equal(t, srvVal.Weight, ans.Weight)
assert.Equal(t, srvVal.Port, ans.Port)
assert.Equal(t, srvVal.Target, ans.Target)
})
}

View File

@@ -136,7 +136,7 @@ func (s *Server) genAnswerCNAME(req *dns.Msg, cname string) (ans *dns.CNAME) {
func (s *Server) genAnswerMX(req *dns.Msg, mx *rules.DNSMX) (ans *dns.MX) {
return &dns.MX{
Hdr: s.hdr(req, dns.TypePTR),
Hdr: s.hdr(req, dns.TypeMX),
Preference: mx.Preference,
Mx: mx.Exchange,
}
@@ -149,6 +149,16 @@ func (s *Server) genAnswerPTR(req *dns.Msg, ptr string) (ans *dns.PTR) {
}
}
func (s *Server) genAnswerSRV(req *dns.Msg, srv *rules.DNSSRV) (ans *dns.SRV) {
return &dns.SRV{
Hdr: s.hdr(req, dns.TypeSRV),
Priority: srv.Priority,
Weight: srv.Weight,
Port: srv.Port,
Target: srv.Target,
}
}
func (s *Server) genAnswerTXT(req *dns.Msg, strs []string) (ans *dns.TXT) {
return &dns.TXT{
Hdr: s.hdr(req, dns.TypeTXT),