Merge branch 'master' into 4403-internal-proxy

This commit is contained in:
Ainar Garipov
2022-08-29 15:27:44 +03:00
241 changed files with 15591 additions and 5972 deletions

View File

@@ -214,7 +214,7 @@ func validateAccessSet(list *accessListJSON) (err error) {
}
merged := allowed.Merge(disallowed)
err = merged.Validate(aghalg.StringIsBefore)
err = merged.Validate()
if err != nil {
return fmt.Errorf("items in allowed and disallowed clients intersect: %w", err)
}
@@ -223,13 +223,13 @@ func validateAccessSet(list *accessListJSON) (err error) {
}
// validateStrUniq returns an informative error if clients are not unique.
func validateStrUniq(clients []string) (uc aghalg.UniqChecker, err error) {
uc = make(aghalg.UniqChecker, len(clients))
func validateStrUniq(clients []string) (uc aghalg.UniqChecker[string], err error) {
uc = make(aghalg.UniqChecker[string], len(clients))
for _, c := range clients {
uc.Add(c)
}
return uc, uc.Validate(aghalg.StringIsBefore)
return uc, uc.Validate()
}
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {

View File

@@ -65,7 +65,7 @@ func clientIDFromClientServerName(
return "", err
}
return clientID, nil
return strings.ToLower(clientID), nil
}
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
@@ -104,7 +104,7 @@ func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err e
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
return strings.ToLower(clientID), nil
}
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
@@ -112,8 +112,8 @@ type tlsConn interface {
ConnectionState() (cs tls.ConnectionState)
}
// quicSession is a narrow interface for quic.Session to simplify testing.
type quicSession interface {
// quicConnection is a narrow interface for quic.Connection to simplify testing.
type quicConnection interface {
ConnectionState() (cs quic.ConnectionState)
}
@@ -148,16 +148,16 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
cliSrvName = tc.ConnectionState().ServerName
case proxy.ProtoQUIC:
qs, ok := pctx.QUICSession.(quicSession)
conn, ok := pctx.QUICConnection.(quicConnection)
if !ok {
return "", fmt.Errorf(
"proxy ctx quic session of proto %s is %T, want quic.Session",
"proxy ctx quic conn of proto %s is %T, want quic.Connection",
proto,
pctx.QUICSession,
pctx.QUICConnection,
)
}
cliSrvName = qs.ConnectionState().TLS.ServerName
cliSrvName = conn.ConnectionState().TLS.ServerName
}
clientID, err = clientIDFromClientServerName(

View File

@@ -29,17 +29,18 @@ func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) {
return cs
}
// testQUICSession is a quicSession for tests.
type testQUICSession struct {
// Session is embedded here simply to make testQUICSession a quic.Session
// without actually implementing all methods.
quic.Session
// testQUICConnection is a quicConnection for tests.
type testQUICConnection struct {
// Connection is embedded here simply to make testQUICConnection a
// quic.Connection without actually implementing all methods.
quic.Connection
serverName string
}
// ConnectionState implements the quicSession interface for testQUICSession.
func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
// ConnectionState implements the quicConnection interface for
// testQUICConnection.
func (c testQUICConnection) ConnectionState() (cs quic.ConnectionState) {
cs.TLS.ServerName = c.serverName
return cs
@@ -143,6 +144,22 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
wantErrMsg: `clientid check: client server name "cli.myexample.com" ` +
`doesn't match host server name "example.com"`,
strictSNI: true,
}, {
name: "tls_case",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "InSeNsItIvE.example.com",
wantClientID: "insensitive",
wantErrMsg: ``,
strictSNI: true,
}, {
name: "quic_case",
proto: proxy.ProtoQUIC,
hostSrvName: "example.com",
cliSrvName: "InSeNsItIvE.example.com",
wantClientID: "insensitive",
wantErrMsg: ``,
strictSNI: true,
}}
for _, tc := range testCases {
@@ -163,17 +180,17 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
}
}
var qs quic.Session
var qconn quic.Connection
if tc.proto == proxy.ProtoQUIC {
qs = testQUICSession{
qconn = testQUICConnection{
serverName: tc.cliSrvName,
}
}
pctx := &proxy.DNSContext{
Proto: tc.proto,
Conn: conn,
QUICSession: qs,
Proto: tc.proto,
Conn: conn,
QUICConnection: qconn,
}
clientID, err := srv.clientIDFromDNSContext(pctx)
@@ -210,6 +227,11 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
path: "/dns-query/cli/",
wantClientID: "cli",
wantErrMsg: "",
}, {
name: "clientid_case",
path: "/dns-query/InSeNsItIvE",
wantClientID: "insensitive",
wantErrMsg: ``,
}, {
name: "bad_url",
path: "/foo",

View File

@@ -5,12 +5,12 @@ import (
"crypto/x509"
"fmt"
"net"
"net/http"
"os"
"sort"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
@@ -122,6 +122,7 @@ type FilteringConfig struct {
EnableDNSSEC bool `yaml:"enable_dnssec"` // Set AD flag in outcoming DNS request
EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option
MaxGoroutines uint32 `yaml:"max_goroutines"` // Max. number of parallel goroutines for processing incoming requests
HandleDDR bool `yaml:"handle_ddr"` // Handle DDR requests
// IpsetList is the ipset configuration that allows AdGuard Home to add
// IP addresses of the specified domain names to an ipset list. Syntax:
@@ -133,8 +134,9 @@ type FilteringConfig struct {
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
type TLSConfig struct {
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
HTTPSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
// Reject connection if the client uses server name (in SNI) that doesn't match the certificate
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"`
@@ -151,7 +153,7 @@ type TLSConfig struct {
PrivateKeyData []byte `yaml:"-" json:"-"`
// ServerName is the hostname of the server. Currently, it is only being
// used for ClientID checking.
// used for ClientID checking and Discovery of Designated Resolvers (DDR).
ServerName string `yaml:"-" json:"-"`
cert tls.Certificate
@@ -191,7 +193,7 @@ type ServerConfig struct {
ConfigModified func()
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
HTTPRegister aghhttp.RegisterFunc
// ResolveClients signals if the RDNS should resolve clients' addresses.
ResolveClients bool
@@ -276,6 +278,11 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
return proxyConfig, nil
}
const (
defaultSafeBrowsingBlockHost = "standard-block.dns.adguard.com"
defaultParentalBlockHost = "family-block.dns.adguard.com"
)
// initDefaultSettings initializes default settings if nothing
// is configured
func (s *Server) initDefaultSettings() {
@@ -287,12 +294,12 @@ func (s *Server) initDefaultSettings() {
s.conf.BootstrapDNS = defaultBootstrap
}
if len(s.conf.ParentalBlockHost) == 0 {
s.conf.ParentalBlockHost = parentalBlockHost
if s.conf.ParentalBlockHost == "" {
s.conf.ParentalBlockHost = defaultParentalBlockHost
}
if len(s.conf.SafeBrowsingBlockHost) == 0 {
s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost
if s.conf.SafeBrowsingBlockHost == "" {
s.conf.SafeBrowsingBlockHost = defaultSafeBrowsingBlockHost
}
if s.conf.UDPListenAddrs == nil {

View File

@@ -76,6 +76,10 @@ const (
resultCodeError
)
// ddrHostFQDN is the FQDN used in Discovery of Designated Resolvers (DDR) requests.
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
const ddrHostFQDN = "_dns.resolver.arpa."
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
ctx := &dnsContext{
@@ -94,10 +98,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
mods := []modProcessFunc{
s.processRecursion,
s.processInitial,
s.processDDRQuery,
s.processDetermineLocal,
s.processInternalHosts,
s.processDHCPHosts,
s.processRestrictLocal,
s.processInternalIPAddrs,
s.processDHCPAddrs,
s.processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream,
@@ -135,7 +140,6 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish
}
return resultCodeSuccess
@@ -226,12 +230,10 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
)
}
lowhost := strings.ToLower(l.Hostname)
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
ip := netutil.CloneIP(l.IP)
ipToHost.Set(l.IP, lowhost)
ip := make(net.IP, 4)
copy(ip, l.IP.To4())
ipToHost.Set(ip, lowhost)
hostToIP[lowhost] = ip
}
@@ -242,6 +244,98 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
s.setTableIPToHost(ipToHost)
}
// processDDRQuery responds to SVCB query for a special use domain name
// _dns.resolver.arpa. The result contains different types of encryption
// supported by current user configuration.
//
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
func (s *Server) processDDRQuery(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
question := d.Req.Question[0]
if !s.conf.HandleDDR {
return resultCodeSuccess
}
if question.Name == ddrHostFQDN {
if s.dnsProxy.TLSListenAddr == nil && s.conf.HTTPSListenAddrs == nil &&
s.dnsProxy.QUICListenAddr == nil || question.Qtype != dns.TypeSVCB {
d.Res = s.makeResponse(d.Req)
return resultCodeFinish
}
d.Res = s.makeDDRResponse(d.Req)
return resultCodeFinish
}
return resultCodeSuccess
}
// makeDDRResponse creates DDR answer according to server configuration. The
// contructed SVCB resource records have the priority of 1 for each entry,
// similar to examples provided by https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
//
// TODO(a.meshkov): Consider setting the priority values based on the protocol.
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
resp = s.makeResponse(req)
// TODO(e.burkov): Think about storing the FQDN version of the server's
// name somewhere.
domainName := dns.Fqdn(s.conf.ServerName)
for _, addr := range s.conf.HTTPSListenAddrs {
values := []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"h2"}},
&dns.SVCBPort{Port: uint16(addr.Port)},
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
}
ans := &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: 1,
Target: domainName,
Value: values,
}
resp.Answer = append(resp.Answer, ans)
}
for _, addr := range s.dnsProxy.TLSListenAddr {
values := []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"dot"}},
&dns.SVCBPort{Port: uint16(addr.Port)},
}
ans := &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: 1,
Target: domainName,
Value: values,
}
resp.Answer = append(resp.Answer, ans)
}
for _, addr := range s.dnsProxy.QUICListenAddr {
values := []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"doq"}},
&dns.SVCBPort{Port: uint16(addr.Port)},
}
ans := &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: 1,
Target: domainName,
Value: values,
}
resp.Answer = append(resp.Answer, ans)
}
return resp
}
// processDetermineLocal determines if the client's IP address is from
// locally-served network and saves the result into the context.
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
@@ -252,7 +346,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
return rc
}
dctx.isLocalClient = s.subnetDetector.IsLocallyServedNetwork(ip)
dctx.isLocalClient = s.privateNets.Contains(ip)
return rc
}
@@ -280,11 +374,11 @@ func (s *Server) hostToIP(host string) (ip net.IP, ok bool) {
return ip, true
}
// processInternalHosts respond to A requests if the target hostname is known to
// processDHCPHosts respond to A requests if the target hostname is known to
// the server.
//
// TODO(a.garipov): Adapt to AAAA as well.
func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
if !s.dhcpServer.Enabled() {
return resultCodeSuccess
}
@@ -299,11 +393,10 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
reqHost := strings.ToLower(q.Name)
reqHost := strings.ToLower(q.Name[:len(q.Name)-1])
// TODO(a.garipov): Move everything related to DHCP local domain to the DHCP
// server.
host := strings.TrimSuffix(reqHost, s.localDomainSuffix)
if host == reqHost {
if !strings.HasSuffix(reqHost, s.localDomainSuffix) {
return resultCodeSuccess
}
@@ -316,7 +409,7 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeFinish
}
ip, ok := s.hostToIP(host)
ip, ok := s.hostToIP(reqHost)
if !ok {
// TODO(e.burkov): Inspect special cases when user want to apply some
// rules handled by other processors to the hosts with TLD.
@@ -373,8 +466,8 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
// Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally-served or at least
// don't need to be inaccessible externally.
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
// don't need to be accessible externally.
if !s.privateNets.Contains(ip) {
log.Debug("dns: addr %s is not from locally-served network", ip)
return resultCodeSuccess
@@ -413,7 +506,7 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
return "", false
}
var v interface{}
var v any
v, ok = s.tableIPToHost.Get(ip)
if !ok {
return "", false
@@ -430,7 +523,7 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
// Respond to PTR requests if the target IP is leased by our DHCP server and the
// requestor is inside the local network.
func (s *Server) processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
func (s *Server) processDHCPAddrs(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
if d.Res != nil {
return resultCodeSuccess
@@ -481,7 +574,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
if !s.privateNets.Contains(ip) {
return resultCodeSuccess
}

View File

@@ -4,35 +4,212 @@ import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServer_ProcessDetermineLocal(t *testing.T) {
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
s := &Server{
subnetDetector: snd,
const (
ddrTestDomainName = "dns.example.net"
ddrTestFQDN = ddrTestDomainName + "."
)
func TestServer_ProcessDDRQuery(t *testing.T) {
dohSVCB := &dns.SVCB{
Priority: 1,
Target: ddrTestFQDN,
Value: []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"h2"}},
&dns.SVCBPort{Port: 8044},
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
},
}
dotSVCB := &dns.SVCB{
Priority: 1,
Target: ddrTestFQDN,
Value: []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"dot"}},
&dns.SVCBPort{Port: 8043},
},
}
doqSVCB := &dns.SVCB{
Priority: 1,
Target: ddrTestFQDN,
Value: []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"doq"}},
&dns.SVCBPort{Port: 8042},
},
}
testCases := []struct {
name string
host string
want []*dns.SVCB
wantRes resultCode
portDoH int
portDoT int
portDoQ int
qtype uint16
ddrEnabled bool
}{{
name: "pass_host",
wantRes: resultCodeSuccess,
host: "example.net.",
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoH: 8043,
}, {
name: "pass_qtype",
wantRes: resultCodeFinish,
host: ddrHostFQDN,
qtype: dns.TypeA,
ddrEnabled: true,
portDoH: 8043,
}, {
name: "pass_disabled_tls",
wantRes: resultCodeFinish,
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
}, {
name: "pass_disabled_ddr",
wantRes: resultCodeSuccess,
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: false,
portDoH: 8043,
}, {
name: "dot",
wantRes: resultCodeFinish,
want: []*dns.SVCB{dotSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoT: 8043,
}, {
name: "doh",
wantRes: resultCodeFinish,
want: []*dns.SVCB{dohSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoH: 8044,
}, {
name: "doq",
wantRes: resultCodeFinish,
want: []*dns.SVCB{doqSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoQ: 8042,
}, {
name: "dot_doh",
wantRes: resultCodeFinish,
want: []*dns.SVCB{dotSVCB, dohSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoT: 8043,
portDoH: 8044,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled)
req := createTestMessageWithType(tc.host, tc.qtype)
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req,
},
}
res := s.processDDRQuery(dctx)
require.Equal(t, tc.wantRes, res)
if tc.wantRes != resultCodeFinish {
return
}
msg := dctx.proxyCtx.Res
require.NotNil(t, msg)
for _, v := range tc.want {
v.Hdr = s.hdr(req, dns.TypeSVCB)
}
assert.ElementsMatch(t, tc.want, msg.Answer)
})
}
}
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
t.Helper()
proxyConf := proxy.Config{}
if portDoT > 0 {
proxyConf.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
}
if portDoQ > 0 {
proxyConf.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
}
s = &Server{
dnsProxy: &proxy.Proxy{
Config: proxyConf,
},
conf: ServerConfig{
FilteringConfig: FilteringConfig{
HandleDDR: ddrEnabled,
},
TLSConfig: TLSConfig{
ServerName: ddrTestDomainName,
},
},
}
if portDoH > 0 {
s.conf.TLSConfig.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
}
return s
}
func TestServer_ProcessDetermineLocal(t *testing.T) {
s := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliIP net.IP
want bool
}{{
want: assert.True,
name: "local",
cliIP: net.IP{192, 168, 0, 1},
want: true,
}, {
want: assert.False,
name: "external",
cliIP: net.IP{250, 249, 0, 1},
want: false,
}, {
want: assert.False,
name: "invalid",
cliIP: net.IP{1, 2, 3, 4, 5},
}, {
want: assert.False,
name: "nil",
cliIP: nil,
}}
for _, tc := range testCases {
@@ -47,12 +224,12 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
}
s.processDetermineLocal(dctx)
assert.Equal(t, tc.want, dctx.isLocalClient)
tc.want(t, dctx.isLocalClient)
})
}
}
func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
knownIP := net.IP{1, 2, 3, 4}
testCases := []struct {
@@ -93,7 +270,7 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
dhcpServer: &testDHCP{},
localDomainSuffix: defaultLocalDomainSuffix,
tableHostToIP: hostToIPTable{
"example": knownIP,
"example." + defaultLocalDomainSuffix: knownIP,
},
}
@@ -115,7 +292,7 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
isLocalClient: tc.isLocalCli,
}
res := s.processInternalHosts(dctx)
res := s.processDHCPHosts(dctx)
require.Equal(t, tc.wantRes, res)
pctx := dctx.proxyCtx
if tc.wantRes == resultCodeFinish {
@@ -141,10 +318,10 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
}
}
func TestServer_ProcessInternalHosts(t *testing.T) {
func TestServer_ProcessDHCPHosts(t *testing.T) {
const (
examplecom = "example.com"
examplelan = "example.lan"
examplelan = "example." + defaultLocalDomainSuffix
)
knownIP := net.IP{1, 2, 3, 4}
@@ -193,41 +370,41 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
}, {
name: "success_custom_suffix",
host: "example.custom",
suffix: ".custom.",
suffix: "custom",
wantIP: knownIP,
wantRes: resultCodeSuccess,
qtyp: dns.TypeA,
}}
for _, tc := range testCases {
s := &Server{
dhcpServer: &testDHCP{},
localDomainSuffix: tc.suffix,
tableHostToIP: hostToIPTable{
"example." + tc.suffix: knownIP,
},
}
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: dns.Fqdn(tc.host),
Qtype: tc.qtyp,
Qclass: dns.ClassINET,
}},
}
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req,
},
isLocalClient: true,
}
t.Run(tc.name, func(t *testing.T) {
s := &Server{
dhcpServer: &testDHCP{},
localDomainSuffix: tc.suffix,
tableHostToIP: hostToIPTable{
"example": knownIP,
},
}
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: dns.Fqdn(tc.host),
Qtype: tc.qtyp,
Qclass: dns.ClassINET,
}},
}
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req,
},
isLocalClient: true,
}
res := s.processInternalHosts(dctx)
res := s.processDHCPHosts(dctx)
pctx := dctx.proxyCtx
assert.Equal(t, tc.wantRes, res)
if tc.wantRes == resultCodeFinish {

View File

@@ -33,11 +33,6 @@ const DefaultTimeout = 10 * time.Second
// requests between the BeforeRequestHandler stage and the actual processing.
const defaultClientIDCacheCount = 1024
const (
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
parentalBlockHost = "family-block.dns.adguard.com"
)
var defaultDNS = []string{
"https://dns10.quad9.net/dns-query",
}
@@ -66,7 +61,7 @@ type Server struct {
dnsFilter *filtering.DNSFilter // DNS filter instance
dhcpServer dhcpd.ServerInterface // DHCP server instance (optional)
queryLog querylog.QueryLog // Query log instance
stats stats.Stats
stats stats.Interface
access *accessCtx
// localDomainSuffix is the suffix used to detect internal hosts. It
@@ -74,7 +69,7 @@ type Server struct {
localDomainSuffix string
ipset ipsetCtx
subnetDetector *aghnet.SubnetDetector
privateNets netutil.SubnetSet
localResolvers *proxy.Proxy
sysResolvers aghnet.SystemResolvers
recDetector *recursionDetector
@@ -107,28 +102,17 @@ type Server struct {
// when no suffix is provided.
//
// See the documentation for Server.localDomainSuffix.
const defaultLocalDomainSuffix = ".lan."
const defaultLocalDomainSuffix = "lan"
// DNSCreateParams are parameters to create a new server.
type DNSCreateParams struct {
DNSFilter *filtering.DNSFilter
Stats stats.Stats
QueryLog querylog.QueryLog
DHCPServer dhcpd.ServerInterface
SubnetDetector *aghnet.SubnetDetector
Anonymizer *aghnet.IPMut
LocalDomain string
}
// domainNameToSuffix converts a domain name into a local domain suffix.
func domainNameToSuffix(tld string) (suffix string) {
l := len(tld) + 2
b := make([]byte, l)
b[0] = '.'
copy(b[1:], tld)
b[l-1] = '.'
return string(b)
DNSFilter *filtering.DNSFilter
Stats stats.Interface
QueryLog querylog.QueryLog
DHCPServer dhcpd.ServerInterface
PrivateNets netutil.SubnetSet
Anonymizer *aghnet.IPMut
LocalDomain string
}
const (
@@ -151,7 +135,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
return nil, fmt.Errorf("local domain: %w", err)
}
localDomainSuffix = domainNameToSuffix(p.LocalDomain)
localDomainSuffix = p.LocalDomain
}
if p.Anonymizer == nil {
@@ -161,7 +145,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
dnsFilter: p.DNSFilter,
stats: p.Stats,
queryLog: p.QueryLog,
subnetDetector: p.SubnetDetector,
privateNets: p.PrivateNets,
localDomainSuffix: localDomainSuffix,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{
@@ -173,7 +157,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
// TODO(e.burkov): Enable the refresher after the actual implementation
// passes the public testing.
s.sysResolvers, err = aghnet.NewSystemResolvers(0, nil)
s.sysResolvers, err = aghnet.NewSystemResolvers(nil)
if err != nil {
return nil, fmt.Errorf("initializing system resolvers: %w", err)
}
@@ -314,14 +298,16 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
StartTime: time.Now(),
}
resolver := s.internalProxy
if s.subnetDetector.IsLocallyServedNetwork(ip) {
var resolver *proxy.Proxy
if s.privateNets.Contains(ip) {
if !s.conf.UsePrivateRDNS {
return "", nil
}
resolver = s.localResolvers
s.recDetector.add(*req)
} else {
resolver = s.internalProxy
}
if err = resolver.Resolve(ctx); err != nil {

View File

@@ -17,13 +17,14 @@ import (
"testing/fstest"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns"
@@ -69,14 +70,11 @@ func createTestServer(
f := filtering.New(filterConf, filters)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
var err error
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -770,16 +768,11 @@ func TestBlockedCustomIP(t *testing.T) {
Data: []byte(rules),
}}
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
f := filtering.New(&filtering.Config{}, filters)
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -860,10 +853,7 @@ func TestBlockedByHosts(t *testing.T) {
func TestBlockedBySafeBrowsing(t *testing.T) {
const hostname = "wmconvirus.narod.ru"
sbUps := &aghtest.TestBlockUpstream{
Hostname: hostname,
Block: true,
}
sbUps := aghtest.NewBlockUpstream(hostname, true)
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
filterConf := &filtering.Config{
@@ -913,15 +903,10 @@ func TestRewrite(t *testing.T) {
f := filtering.New(c, nil)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -1000,7 +985,7 @@ func TestRewrite(t *testing.T) {
}
}
func publicKey(priv interface{}) interface{} {
func publicKey(priv any) any {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
@@ -1028,36 +1013,33 @@ func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
func TestPTRResponseFromDHCPLeases(t *testing.T) {
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
const localDomain = "lan"
var s *Server
s, err = NewServer(DNSCreateParams{
DNSFilter: filtering.New(&filtering.Config{}, nil),
DHCPServer: &testDHCP{},
SubnetDetector: snd,
s, err := NewServer(DNSCreateParams{
DNSFilter: filtering.New(&filtering.Config{}, nil),
DHCPServer: &testDHCP{},
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
LocalDomain: localDomain,
})
require.NoError(t, err)
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
s.conf.ProtectionEnabled = true
err = s.Prepare(nil)
require.NoError(t, err)
err = s.Start()
require.NoError(t, err)
t.Cleanup(s.Close)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR)
resp, err := dns.Exchange(req, addr.String())
require.NoError(t, err)
require.NoErrorf(t, err, "%s", addr)
require.Len(t, resp.Answer, 1)
@@ -1066,7 +1048,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok)
assert.Equal(t, "myhost.", ptr.Ptr)
assert.Equal(t, dns.Fqdn("myhost."+localDomain), ptr.Ptr)
}
func TestPTRResponseFromHosts(t *testing.T) {
@@ -1105,16 +1087,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
}, nil)
flt.SetEnabled(true)
var snd *aghnet.SubnetDetector
snd, err = aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: flt,
SubnetDetector: snd,
DHCPServer: &testDHCP{},
DNSFilter: flt,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -1197,25 +1174,48 @@ func TestNewServer(t *testing.T) {
}
func TestServer_Exchange(t *testing.T) {
extUpstream := &aghtest.Upstream{
Reverse: map[string][]string{
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
const (
onesHost = "one.one.one.one"
localDomainHost = "local.domain"
)
var (
onesIP = net.IP{1, 1, 1, 1}
localIP = net.IP{192, 168, 1, 1}
)
revExtIPv4, err := netutil.IPToReversedAddr(onesIP)
require.NoError(t, err)
extUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "external.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = aghalg.Coalesce(
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revExtIPv4, onesHost),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
return resp, nil
},
}
locUpstream := &aghtest.Upstream{
Reverse: map[string][]string{
"1.1.168.192.in-addr.arpa.": {"local.domain"},
"2.1.168.192.in-addr.arpa.": {},
revLocIPv4, err := netutil.IPToReversedAddr(localIP)
require.NoError(t, err)
locUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "local.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = aghalg.Coalesce(
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revLocIPv4, localDomainHost),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
return resp, nil
},
}
upstreamErr := errors.Error("upstream error")
errUpstream := &aghtest.TestErrUpstream{
Err: upstreamErr,
}
nonPtrUpstream := &aghtest.TestBlockUpstream{
Hostname: "some-host",
Block: true,
}
errUpstream := aghtest.NewErrorUpstream()
nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
srv := NewCustomServer(&proxy.Proxy{
Config: proxy.Config{
@@ -1227,11 +1227,8 @@ func TestServer_Exchange(t *testing.T) {
srv.conf.ResolveClients = true
srv.conf.UsePrivateRDNS = true
var err error
srv.subnetDetector, err = aghnet.NewSubnetDetector()
require.NoError(t, err)
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
localIP := net.IP{192, 168, 1, 1}
testCases := []struct {
name string
want string
@@ -1240,20 +1237,20 @@ func TestServer_Exchange(t *testing.T) {
req net.IP
}{{
name: "external_good",
want: "one.one.one.one",
want: onesHost,
wantErr: nil,
locUpstream: nil,
req: net.IP{1, 1, 1, 1},
req: onesIP,
}, {
name: "local_good",
want: "local.domain",
want: localDomainHost,
wantErr: nil,
locUpstream: locUpstream,
req: localIP,
}, {
name: "upstream_error",
want: "",
wantErr: upstreamErr,
wantErr: aghtest.ErrUpstream,
locUpstream: errUpstream,
req: localIP,
}, {

View File

@@ -22,7 +22,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
Preference: 32,
}
svcbVal := &rules.DNSSVCB{
Params: map[string]string{"alpn": "h3"},
Params: map[string]string{"alpn": "h3", "dohpath": "/dns-query"},
Target: dns.Fqdn(domain),
Priority: 32,
}
@@ -164,10 +164,20 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
require.Len(t, d.Res.Answer, 1)
ans, ok := d.Res.Answer[0].(*dns.SVCB)
require.True(t, ok)
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
assert.Equal(t, svcbVal.Params["alpn"], ans.Value[0].String())
require.True(t, ok)
require.Len(t, ans.Value, 2)
assert.ElementsMatch(
t,
[]dns.SVCBKey{dns.SVCB_ALPN, dns.SVCB_DOHPATH},
[]dns.SVCBKey{ans.Value[0].Key(), ans.Value[1].Key()},
)
assert.ElementsMatch(
t,
[]string{svcbVal.Params["alpn"], svcbVal.Params["dohpath"]},
[]string{ans.Value[0].String(), ans.Value[1].String()},
)
assert.Equal(t, svcbVal.Target, ans.Target)
assert.Equal(t, svcbVal.Priority, ans.Priority)
})
@@ -186,8 +196,18 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
ans, ok := d.Res.Answer[0].(*dns.HTTPS)
require.True(t, ok)
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
assert.Equal(t, svcbVal.Params["alpn"], ans.Value[0].String())
require.Len(t, ans.Value, 2)
assert.ElementsMatch(
t,
[]dns.SVCBKey{dns.SVCB_ALPN, dns.SVCB_DOHPATH},
[]dns.SVCBKey{ans.Value[0].Key(), ans.Value[1].Key()},
)
assert.ElementsMatch(
t,
[]string{svcbVal.Params["alpn"], svcbVal.Params["dohpath"]},
[]string{ans.Value[0].String(), ans.Value[1].String()},
)
assert.Equal(t, svcbVal.Target, ans.Target)
assert.Equal(t, svcbVal.Priority, ans.Priority)
})

View File

@@ -4,7 +4,6 @@ import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
@@ -39,14 +38,10 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
f := filtering.New(&filtering.Config{}, filters)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)

View File

@@ -5,12 +5,10 @@ import (
"fmt"
"net"
"net/http"
"sort"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
@@ -18,6 +16,8 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
type dnsConfig struct {
@@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (err error) {
}
// validate returns an error if any field of req is invalid.
func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) {
if req.Upstreams != nil {
err = ValidateUpstreams(*req.Upstreams)
if err != nil {
@@ -176,7 +176,7 @@ func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
}
if req.LocalPTRUpstreams != nil {
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, snd)
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
if err != nil {
return fmt.Errorf("validating private upstream servers: %w", err)
}
@@ -224,7 +224,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
return
}
err = req.validate(s.subnetDetector)
err = req.validate(s.privateNets)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -350,17 +350,6 @@ func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}
// LocalNetChecker is used to check if the IP address belongs to a local
// network.
type LocalNetChecker interface {
// IsLocallyServedNetwork returns true if ip is contained in any of address
// registries defined by RFC 6303.
IsLocallyServedNetwork(ip net.IP) (ok bool)
}
// type check
var _ LocalNetChecker = (*aghnet.SubnetDetector)(nil)
// newUpstreamConfig validates upstreams and returns an appropriate upstream
// configuration or nil if it can't be built.
//
@@ -375,6 +364,21 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
return nil, nil
}
for _, u := range upstreams {
var ups string
var domains []string
ups, domains, err = separateUpstream(u)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
_, err = validateUpstream(ups, domains)
if err != nil {
return nil, fmt.Errorf("validating upstream %q: %w", u, err)
}
}
conf, err = proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{Bootstrap: []string{}, Timeout: DefaultTimeout},
@@ -385,13 +389,6 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
return nil, errors.Error("no default upstreams specified")
}
for _, u := range upstreams {
_, err = validateUpstream(u)
if err != nil {
return nil, err
}
}
return conf, nil
}
@@ -405,25 +402,11 @@ func ValidateUpstreams(upstreams []string) (err error) {
return err
}
// stringKeysSorted returns the sorted slice of string keys of m.
//
// TODO(e.burkov): Use generics in Go 1.18. Move into golibs.
func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
sorted = make([]string, 0, len(m))
for s := range m {
sorted = append(sorted, s)
}
sort.Strings(sorted)
return sorted
}
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. lnc must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err error) {
// a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := newUpstreamConfig(upstreams)
if err != nil {
return err
@@ -433,9 +416,11 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
return nil
}
var errs []error
keys := maps.Keys(conf.DomainReservedUpstreams)
slices.Sort(keys)
for _, domain := range stringKeysSorted(conf.DomainReservedUpstreams) {
var errs []error
for _, domain := range keys {
var subnet *net.IPNet
subnet, err = netutil.SubnetFromReversedAddr(domain)
if err != nil {
@@ -444,7 +429,7 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
continue
}
if !lnc.IsLocallyServedNetwork(subnet.IP) {
if !privateNets.Contains(subnet.IP) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
@@ -461,16 +446,14 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
var protocols = []string{"udp://", "tcp://", "tls://", "https://", "sdns://", "quic://"}
func validateUpstream(u string) (useDefault bool, err error) {
// Check if the user tries to specify upstream for domain.
var isDomainSpec bool
u, isDomainSpec, err = separateUpstream(u)
if err != nil {
return !isDomainSpec, err
}
// validateUpstream returns an error if u alongside with domains is not a valid
// upstream configuration. useDefault is true if the upstream is
// domain-specific and is configured to point at the default upstream server
// which is validated separately. The upstream is considered domain-specific
// only if domains is at least not nil.
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
// The special server address '#' means that default server must be used.
if useDefault = !isDomainSpec; u == "#" && isDomainSpec {
if useDefault = u == "#" && domains != nil; useDefault {
return useDefault, nil
}
@@ -497,12 +480,14 @@ func validateUpstream(u string) (useDefault bool, err error) {
return useDefault, nil
}
// separateUpstream returns the upstream without the specified domains.
// isDomainSpec is true when the upstream is domains-specific.
func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, err error) {
// separateUpstream returns the upstream and the specified domains. domains is
// nil when the upstream is not domains-specific. Otherwise it may also be
// empty.
func separateUpstream(upstreamStr string) (ups string, domains []string, err error) {
if !strings.HasPrefix(upstreamStr, "[/") {
return upstreamStr, false, nil
return upstreamStr, nil, nil
}
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
parts := strings.Split(upstreamStr[2:], "/]")
@@ -510,39 +495,46 @@ func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, e
case 2:
// Go on.
case 1:
return "", false, errors.Error("missing separator")
return "", nil, errors.Error("missing separator")
default:
return "", true, errors.Error("duplicated separator")
return "", []string{}, errors.Error("duplicated separator")
}
var domains string
domains, upstream = parts[0], parts[1]
for i, host := range strings.Split(domains, "/") {
for i, host := range strings.Split(parts[0], "/") {
if host == "" {
continue
}
err = netutil.ValidateDomainName(host)
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
if err != nil {
return "", true, fmt.Errorf("domain at index %d: %w", i, err)
return "", domains, fmt.Errorf("domain at index %d: %w", i, err)
}
domains = append(domains, host)
}
return upstream, true, nil
return parts[1], domains, nil
}
// excFunc is a signature of function to check if upstream exchanges correctly.
type excFunc func(u upstream.Upstream) (err error)
// healthCheckFunc is a signature of function to check if upstream exchanges
// properly.
type healthCheckFunc func(u upstream.Upstream) (err error)
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
// testTLD is the special-use fully-qualified domain name for testing the
// DNS server reachability.
//
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
const testTLD = "test."
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: "google-public-dns-a.google.com.",
Name: testTLD,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
@@ -552,12 +544,8 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
reply, err = u.Exchange(req)
if err != nil {
return fmt.Errorf("couldn't communicate with upstream: %w", err)
}
if len(reply.Answer) != 1 {
return fmt.Errorf("wrong response")
} else if a, ok := reply.Answer[0].(*dns.A); !ok || !a.A.Equal(net.IP{8, 8, 8, 8}) {
return fmt.Errorf("wrong response")
} else if len(reply.Answer) != 0 {
return errors.Error("wrong response")
}
return nil
@@ -565,14 +553,22 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
// checkPrivateUpstreamExc checks if the upstream for resolving private
// addresses exchanges correctly.
//
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
// address resolution.
//
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
const inAddrArpaTLD = "in-addr.arpa."
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: "1.0.0.127.in-addr.arpa.",
Name: inAddrArpaTLD,
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
}},
@@ -585,46 +581,66 @@ func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
return nil
}
func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFunc) (err error) {
if IsCommentOrEmpty(input) {
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
// the tested upstream domain-specific and therefore consider its errors
// non-critical.
//
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
// warnings (non-critical errors) is desired.
type domainSpecificTestError struct {
error
}
// checkDNS checks the upstream server defined by upstreamConfigStr using
// healthCheck for actually exchange messages. It uses bootstrap to resolve the
// upstream's address.
func checkDNS(
upstreamConfigStr string,
bootstrap []string,
timeout time.Duration,
healthCheck healthCheckFunc,
) (err error) {
if IsCommentOrEmpty(upstreamConfigStr) {
return nil
}
// Separate upstream from domains list.
var useDefault bool
if useDefault, err = validateUpstream(input); err != nil {
upstreamAddr, domains, err := separateUpstream(upstreamConfigStr)
if err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
}
// No need to check this DNS server.
if !useDefault {
useDefault, err := validateUpstream(upstreamAddr, domains)
if err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
} else if useDefault {
return nil
}
if input, _, err = separateUpstream(input); err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
}
if len(bootstrap) == 0 {
bootstrap = defaultBootstrap
}
log.Debug("checking if upstream %s works", input)
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
var u upstream.Upstream
u, err = upstream.AddressToUpstream(input, &upstream.Options{
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
Bootstrap: bootstrap,
Timeout: timeout,
})
if err != nil {
return fmt.Errorf("failed to choose upstream for %q: %w", input, err)
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
}
if err = ef(u); err != nil {
return fmt.Errorf("upstream %q fails to exchange: %w", input, err)
if err = healthCheck(u); err != nil {
err = fmt.Errorf("upstream %q fails to exchange: %w", upstreamAddr, err)
if domains != nil {
return domainSpecificTestError{error: err}
}
return err
}
log.Debug("upstream %s is ok", input)
log.Debug("dnsforward: upstream %q is ok", upstreamAddr)
return nil
}
@@ -647,6 +663,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
if err != nil {
log.Info("%v", err)
result[host] = err.Error()
if _, ok := err.(domainSpecificTestError); ok {
result[host] = fmt.Sprintf("WARNING: %s", result[host])
}
continue
}
@@ -662,6 +681,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
// above, we rewriting the error for it. These cases should be
// handled properly instead.
result[host] = err.Error()
if _, ok := err.(domainSpecificTestError); ok {
result[host] = fmt.Sprintf("WARNING: %s", result[host])
}
continue
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -33,7 +34,7 @@ func (fsr *fakeSystemResolvers) Get() (rs []string) {
return nil
}
func loadTestData(t *testing.T, casesFileName string, cases interface{}) {
func loadTestData(t *testing.T, casesFileName string, cases any) {
t.Helper()
var f *os.File
@@ -184,7 +185,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
wantSet: "",
}, {
name: "upstream_dns_bad",
wantSet: `validating upstream servers: bad ipport address "!!!": ` +
wantSet: `validating upstream servers: ` +
`validating upstream "!!!": bad ipport address "!!!": ` +
`address !!!: missing port in address`,
}, {
name: "bootstraps_bad",
@@ -255,112 +257,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
}
}
func TestValidateUpstream(t *testing.T) {
testCases := []struct {
wantDef assert.BoolAssertionFunc
name string
upstream string
wantErr string
}{{
wantDef: assert.True,
name: "invalid",
upstream: "1.2.3.4.5",
wantErr: `bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "123.3.7m",
wantErr: `bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "htttps://google.com/dns-query",
wantErr: `wrong protocol`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "[/host.com]tls://dns.adguard.com",
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "[host.ru]#",
wantErr: `bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "1.1.1.1",
wantErr: ``,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "tls://1.1.1.1",
wantErr: ``,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "https://dns.adguard.com/dns-query",
wantErr: ``,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
wantErr: ``,
}, {
wantDef: assert.True,
name: "default_udp_host",
upstream: "udp://dns.google",
}, {
wantDef: assert.True,
name: "default_udp_ip",
upstream: "udp://8.8.8.8",
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/host.com/]1.1.1.1",
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[//]tls://1.1.1.1",
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/www.host.com/]#",
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/host.com/google.com/]8.8.8.8",
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
wantErr: ``,
}, {
wantDef: assert.False,
name: "idna",
upstream: "[/пример.рф/]8.8.8.8",
wantErr: ``,
}, {
wantDef: assert.False,
name: "bad_domain",
upstream: "[/!/]8.8.8.8",
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
defaultUpstream, err := validateUpstream(tc.upstream)
testutil.AssertErrorMsg(t, tc.wantErr, err)
tc.wantDef(t, defaultUpstream)
})
}
}
func TestValidateUpstreams(t *testing.T) {
testCases := []struct {
name string
@@ -375,7 +271,7 @@ func TestValidateUpstreams(t *testing.T) {
wantErr: ``,
set: []string{"# comment"},
}, {
name: "valid_no_default",
name: "no_default",
wantErr: `no default upstreams specified`,
set: []string{
"[/host.com/]1.1.1.1",
@@ -385,7 +281,7 @@ func TestValidateUpstreams(t *testing.T) {
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
},
}, {
name: "valid_with_default",
name: "with_default",
wantErr: ``,
set: []string{
"[/host.com/]1.1.1.1",
@@ -397,8 +293,46 @@ func TestValidateUpstreams(t *testing.T) {
},
}, {
name: "invalid",
wantErr: `cannot prepare the upstream dhcp://fake.dns ([]): unsupported url scheme: dhcp`,
wantErr: `validating upstream "dhcp://fake.dns": wrong protocol`,
set: []string{"dhcp://fake.dns"},
}, {
name: "invalid",
wantErr: `validating upstream "1.2.3.4.5": bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
set: []string{"1.2.3.4.5"},
}, {
name: "invalid",
wantErr: `validating upstream "123.3.7m": bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
set: []string{"123.3.7m"},
}, {
name: "invalid",
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
set: []string{"[/host.com]tls://dns.adguard.com"},
}, {
name: "invalid",
wantErr: `validating upstream "[host.ru]#": bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
set: []string{"[host.ru]#"},
}, {
name: "valid_default",
wantErr: ``,
set: []string{
"1.1.1.1",
"tls://1.1.1.1",
"https://dns.adguard.com/dns-query",
"sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
"udp://dns.google",
"udp://8.8.8.8",
"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
"[/пример.рф/]8.8.8.8",
},
}, {
name: "bad_domain",
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
set: []string{"[/!/]8.8.8.8"},
}}
for _, tc := range testCases {
@@ -410,8 +344,7 @@ func TestValidateUpstreams(t *testing.T) {
}
func TestValidateUpstreamsPrivate(t *testing.T) {
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
testCases := []struct {
name string
@@ -452,7 +385,7 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
set := []string{"192.168.0.1", tc.u}
t.Run(tc.name, func(t *testing.T) {
err = ValidateUpstreamsPrivate(set, snd)
err := ValidateUpstreamsPrivate(set, ss)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}

View File

@@ -83,7 +83,7 @@ func TestRecursionDetector_Suspect(t *testing.T) {
testCases := []struct {
name string
msg dns.Msg
want bool
want int
}{{
name: "simple",
msg: dns.Msg{
@@ -95,24 +95,18 @@ func TestRecursionDetector_Suspect(t *testing.T) {
Qtype: dns.TypeA,
}},
},
want: true,
want: 1,
}, {
name: "unencumbered",
msg: dns.Msg{},
want: false,
want: 0,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(rd.clear)
rd.add(tc.msg)
if tc.want {
assert.Equal(t, 1, rd.recentRequests.Stats().Count)
} else {
assert.Zero(t, rd.recentRequests.Stats().Count)
}
assert.Equal(t, tc.want, rd.recentRequests.Stats().Count)
})
}
}

View File

@@ -64,9 +64,9 @@ func (s *Server) logQuery(
Answer: pctx.Res,
OrigAnswer: dctx.origResp,
Result: dctx.result,
Elapsed: elapsed,
ClientID: dctx.clientID,
ClientIP: ip,
Elapsed: elapsed,
AuthenticatedData: dctx.responseAD,
}

View File

@@ -34,7 +34,7 @@ func (l *testQueryLog) Add(p *querylog.AddParams) {
type testStats struct {
// Stats is embedded here simply to make testStats a stats.Stats without
// actually implementing all methods.
stats.Stats
stats.Interface
lastEntry stats.Entry
}

View File

@@ -32,12 +32,16 @@ func (s *Server) genAnswerHTTPS(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTT
// github.com/miekg/dns module.
var strToSVCBKey = map[string]dns.SVCBKey{
"alpn": dns.SVCB_ALPN,
"echconfig": dns.SVCB_ECHCONFIG,
"ech": dns.SVCB_ECHCONFIG,
"ipv4hint": dns.SVCB_IPV4HINT,
"ipv6hint": dns.SVCB_IPV6HINT,
"mandatory": dns.SVCB_MANDATORY,
"no-default-alpn": dns.SVCB_NO_DEFAULT_ALPN,
"port": dns.SVCB_PORT,
// TODO(a.garipov): This is the previous name for the parameter that has
// since been changed. Remove this in v0.109.0.
"echconfig": dns.SVCB_ECHCONFIG,
}
// svcbKeyHandler is a handler for one SVCB parameter key.
@@ -51,10 +55,10 @@ var svcbKeyHandlers = map[string]svcbKeyHandler{
}
},
"echconfig": func(valStr string) (val dns.SVCBKeyValue) {
"ech": func(valStr string) (val dns.SVCBKeyValue) {
ech, err := base64.StdEncoding.DecodeString(valStr)
if err != nil {
log.Debug("can't parse svcb/https echconfig: %s; ignoring", err)
log.Debug("can't parse svcb/https ech: %s; ignoring", err)
return nil
}
@@ -119,6 +123,32 @@ var svcbKeyHandlers = map[string]svcbKeyHandler{
Port: uint16(port64),
}
},
// TODO(a.garipov): This is the previous name for the parameter that has
// since been changed. Remove this in v0.109.0.
"echconfig": func(valStr string) (val dns.SVCBKeyValue) {
log.Info(
`warning: svcb/https record parameter name "echconfig" is deprecated; ` +
`use "ech" instead`,
)
ech, err := base64.StdEncoding.DecodeString(valStr)
if err != nil {
log.Debug("can't parse svcb/https ech: %s; ignoring", err)
return nil
}
return &dns.SVCBECHConfig{
ECH: ech,
}
},
"dohpath": func(valStr string) (val dns.SVCBKeyValue) {
return &dns.SVCBDoHPath{
Template: valStr,
}
},
}
// genAnswerSVCB returns a properly initialized SVCB resource record.

View File

@@ -87,14 +87,18 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
svcb: dnssvcb("alpn", "h3"),
want: wantsvcb(&dns.SVCBAlpn{Alpn: []string{"h3"}}),
name: "alpn",
}, {
svcb: dnssvcb("ech", "AAAA"),
want: wantsvcb(&dns.SVCBECHConfig{ECH: []byte{0, 0, 0}}),
name: "ech",
}, {
svcb: dnssvcb("echconfig", "AAAA"),
want: wantsvcb(&dns.SVCBECHConfig{ECH: []byte{0, 0, 0}}),
name: "echconfig",
name: "ech_deprecated",
}, {
svcb: dnssvcb("echconfig", "%BAD%"),
want: wantsvcb(nil),
name: "echconfig_invalid",
name: "ech_invalid",
}, {
svcb: dnssvcb("ipv4hint", "127.0.0.1"),
want: wantsvcb(&dns.SVCBIPv4Hint{Hint: []net.IP{ip4}}),
@@ -123,6 +127,10 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
svcb: dnssvcb("no-default-alpn", ""),
want: wantsvcb(&dns.SVCBNoDefaultAlpn{}),
name: "no_default_alpn",
}, {
svcb: dnssvcb("dohpath", "/dns-query"),
want: wantsvcb(&dns.SVCBDoHPath{Template: "/dns-query"}),
name: "dohpath",
}, {
svcb: dnssvcb("port", "8080"),
want: wantsvcb(&dns.SVCBPort{Port: 8080}),