Merge branch 'master' into 4403-internal-proxy
This commit is contained in:
@@ -214,7 +214,7 @@ func validateAccessSet(list *accessListJSON) (err error) {
|
||||
}
|
||||
|
||||
merged := allowed.Merge(disallowed)
|
||||
err = merged.Validate(aghalg.StringIsBefore)
|
||||
err = merged.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("items in allowed and disallowed clients intersect: %w", err)
|
||||
}
|
||||
@@ -223,13 +223,13 @@ func validateAccessSet(list *accessListJSON) (err error) {
|
||||
}
|
||||
|
||||
// validateStrUniq returns an informative error if clients are not unique.
|
||||
func validateStrUniq(clients []string) (uc aghalg.UniqChecker, err error) {
|
||||
uc = make(aghalg.UniqChecker, len(clients))
|
||||
func validateStrUniq(clients []string) (uc aghalg.UniqChecker[string], err error) {
|
||||
uc = make(aghalg.UniqChecker[string], len(clients))
|
||||
for _, c := range clients {
|
||||
uc.Add(c)
|
||||
}
|
||||
|
||||
return uc, uc.Validate(aghalg.StringIsBefore)
|
||||
return uc, uc.Validate()
|
||||
}
|
||||
|
||||
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -65,7 +65,7 @@ func clientIDFromClientServerName(
|
||||
return "", err
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
return strings.ToLower(clientID), nil
|
||||
}
|
||||
|
||||
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
|
||||
@@ -104,7 +104,7 @@ func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err e
|
||||
return "", fmt.Errorf("clientid check: %w", err)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
return strings.ToLower(clientID), nil
|
||||
}
|
||||
|
||||
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
|
||||
@@ -112,8 +112,8 @@ type tlsConn interface {
|
||||
ConnectionState() (cs tls.ConnectionState)
|
||||
}
|
||||
|
||||
// quicSession is a narrow interface for quic.Session to simplify testing.
|
||||
type quicSession interface {
|
||||
// quicConnection is a narrow interface for quic.Connection to simplify testing.
|
||||
type quicConnection interface {
|
||||
ConnectionState() (cs quic.ConnectionState)
|
||||
}
|
||||
|
||||
@@ -148,16 +148,16 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
|
||||
|
||||
cliSrvName = tc.ConnectionState().ServerName
|
||||
case proxy.ProtoQUIC:
|
||||
qs, ok := pctx.QUICSession.(quicSession)
|
||||
conn, ok := pctx.QUICConnection.(quicConnection)
|
||||
if !ok {
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx quic session of proto %s is %T, want quic.Session",
|
||||
"proxy ctx quic conn of proto %s is %T, want quic.Connection",
|
||||
proto,
|
||||
pctx.QUICSession,
|
||||
pctx.QUICConnection,
|
||||
)
|
||||
}
|
||||
|
||||
cliSrvName = qs.ConnectionState().TLS.ServerName
|
||||
cliSrvName = conn.ConnectionState().TLS.ServerName
|
||||
}
|
||||
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
|
||||
@@ -29,17 +29,18 @@ func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) {
|
||||
return cs
|
||||
}
|
||||
|
||||
// testQUICSession is a quicSession for tests.
|
||||
type testQUICSession struct {
|
||||
// Session is embedded here simply to make testQUICSession a quic.Session
|
||||
// without actually implementing all methods.
|
||||
quic.Session
|
||||
// testQUICConnection is a quicConnection for tests.
|
||||
type testQUICConnection struct {
|
||||
// Connection is embedded here simply to make testQUICConnection a
|
||||
// quic.Connection without actually implementing all methods.
|
||||
quic.Connection
|
||||
|
||||
serverName string
|
||||
}
|
||||
|
||||
// ConnectionState implements the quicSession interface for testQUICSession.
|
||||
func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
|
||||
// ConnectionState implements the quicConnection interface for
|
||||
// testQUICConnection.
|
||||
func (c testQUICConnection) ConnectionState() (cs quic.ConnectionState) {
|
||||
cs.TLS.ServerName = c.serverName
|
||||
|
||||
return cs
|
||||
@@ -143,6 +144,22 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantErrMsg: `clientid check: client server name "cli.myexample.com" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_case",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "InSeNsItIvE.example.com",
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "quic_case",
|
||||
proto: proxy.ProtoQUIC,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "InSeNsItIvE.example.com",
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
strictSNI: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -163,17 +180,17 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var qs quic.Session
|
||||
var qconn quic.Connection
|
||||
if tc.proto == proxy.ProtoQUIC {
|
||||
qs = testQUICSession{
|
||||
qconn = testQUICConnection{
|
||||
serverName: tc.cliSrvName,
|
||||
}
|
||||
}
|
||||
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICSession: qs,
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICConnection: qconn,
|
||||
}
|
||||
|
||||
clientID, err := srv.clientIDFromDNSContext(pctx)
|
||||
@@ -210,6 +227,11 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||
path: "/dns-query/cli/",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "clientid_case",
|
||||
path: "/dns-query/InSeNsItIvE",
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
}, {
|
||||
name: "bad_url",
|
||||
path: "/foo",
|
||||
|
||||
@@ -5,12 +5,12 @@ import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -122,6 +122,7 @@ type FilteringConfig struct {
|
||||
EnableDNSSEC bool `yaml:"enable_dnssec"` // Set AD flag in outcoming DNS request
|
||||
EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option
|
||||
MaxGoroutines uint32 `yaml:"max_goroutines"` // Max. number of parallel goroutines for processing incoming requests
|
||||
HandleDDR bool `yaml:"handle_ddr"` // Handle DDR requests
|
||||
|
||||
// IpsetList is the ipset configuration that allows AdGuard Home to add
|
||||
// IP addresses of the specified domain names to an ipset list. Syntax:
|
||||
@@ -133,8 +134,9 @@ type FilteringConfig struct {
|
||||
|
||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||
type TLSConfig struct {
|
||||
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
|
||||
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
|
||||
HTTPSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
|
||||
// Reject connection if the client uses server name (in SNI) that doesn't match the certificate
|
||||
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"`
|
||||
@@ -151,7 +153,7 @@ type TLSConfig struct {
|
||||
PrivateKeyData []byte `yaml:"-" json:"-"`
|
||||
|
||||
// ServerName is the hostname of the server. Currently, it is only being
|
||||
// used for ClientID checking.
|
||||
// used for ClientID checking and Discovery of Designated Resolvers (DDR).
|
||||
ServerName string `yaml:"-" json:"-"`
|
||||
|
||||
cert tls.Certificate
|
||||
@@ -191,7 +193,7 @@ type ServerConfig struct {
|
||||
ConfigModified func()
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
|
||||
HTTPRegister aghhttp.RegisterFunc
|
||||
|
||||
// ResolveClients signals if the RDNS should resolve clients' addresses.
|
||||
ResolveClients bool
|
||||
@@ -276,6 +278,11 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
|
||||
return proxyConfig, nil
|
||||
}
|
||||
|
||||
const (
|
||||
defaultSafeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
defaultParentalBlockHost = "family-block.dns.adguard.com"
|
||||
)
|
||||
|
||||
// initDefaultSettings initializes default settings if nothing
|
||||
// is configured
|
||||
func (s *Server) initDefaultSettings() {
|
||||
@@ -287,12 +294,12 @@ func (s *Server) initDefaultSettings() {
|
||||
s.conf.BootstrapDNS = defaultBootstrap
|
||||
}
|
||||
|
||||
if len(s.conf.ParentalBlockHost) == 0 {
|
||||
s.conf.ParentalBlockHost = parentalBlockHost
|
||||
if s.conf.ParentalBlockHost == "" {
|
||||
s.conf.ParentalBlockHost = defaultParentalBlockHost
|
||||
}
|
||||
|
||||
if len(s.conf.SafeBrowsingBlockHost) == 0 {
|
||||
s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost
|
||||
if s.conf.SafeBrowsingBlockHost == "" {
|
||||
s.conf.SafeBrowsingBlockHost = defaultSafeBrowsingBlockHost
|
||||
}
|
||||
|
||||
if s.conf.UDPListenAddrs == nil {
|
||||
|
||||
@@ -76,6 +76,10 @@ const (
|
||||
resultCodeError
|
||||
)
|
||||
|
||||
// ddrHostFQDN is the FQDN used in Discovery of Designated Resolvers (DDR) requests.
|
||||
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
|
||||
const ddrHostFQDN = "_dns.resolver.arpa."
|
||||
|
||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
ctx := &dnsContext{
|
||||
@@ -94,10 +98,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
mods := []modProcessFunc{
|
||||
s.processRecursion,
|
||||
s.processInitial,
|
||||
s.processDDRQuery,
|
||||
s.processDetermineLocal,
|
||||
s.processInternalHosts,
|
||||
s.processDHCPHosts,
|
||||
s.processRestrictLocal,
|
||||
s.processInternalIPAddrs,
|
||||
s.processDHCPAddrs,
|
||||
s.processFilteringBeforeRequest,
|
||||
s.processLocalPTR,
|
||||
s.processUpstream,
|
||||
@@ -135,7 +140,6 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
@@ -226,12 +230,10 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
)
|
||||
}
|
||||
|
||||
lowhost := strings.ToLower(l.Hostname)
|
||||
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
|
||||
ip := netutil.CloneIP(l.IP)
|
||||
|
||||
ipToHost.Set(l.IP, lowhost)
|
||||
|
||||
ip := make(net.IP, 4)
|
||||
copy(ip, l.IP.To4())
|
||||
ipToHost.Set(ip, lowhost)
|
||||
hostToIP[lowhost] = ip
|
||||
}
|
||||
|
||||
@@ -242,6 +244,98 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
s.setTableIPToHost(ipToHost)
|
||||
}
|
||||
|
||||
// processDDRQuery responds to SVCB query for a special use domain name
|
||||
// ‘_dns.resolver.arpa’. The result contains different types of encryption
|
||||
// supported by current user configuration.
|
||||
//
|
||||
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
|
||||
func (s *Server) processDDRQuery(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
question := d.Req.Question[0]
|
||||
|
||||
if !s.conf.HandleDDR {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
if question.Name == ddrHostFQDN {
|
||||
if s.dnsProxy.TLSListenAddr == nil && s.conf.HTTPSListenAddrs == nil &&
|
||||
s.dnsProxy.QUICListenAddr == nil || question.Qtype != dns.TypeSVCB {
|
||||
d.Res = s.makeResponse(d.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
d.Res = s.makeDDRResponse(d.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// makeDDRResponse creates DDR answer according to server configuration. The
|
||||
// contructed SVCB resource records have the priority of 1 for each entry,
|
||||
// similar to examples provided by https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
|
||||
//
|
||||
// TODO(a.meshkov): Consider setting the priority values based on the protocol.
|
||||
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.makeResponse(req)
|
||||
// TODO(e.burkov): Think about storing the FQDN version of the server's
|
||||
// name somewhere.
|
||||
domainName := dns.Fqdn(s.conf.ServerName)
|
||||
|
||||
for _, addr := range s.conf.HTTPSListenAddrs {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"h2"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
|
||||
}
|
||||
|
||||
ans := &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: 1,
|
||||
Target: domainName,
|
||||
Value: values,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
for _, addr := range s.dnsProxy.TLSListenAddr {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"dot"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
}
|
||||
|
||||
ans := &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: 1,
|
||||
Target: domainName,
|
||||
Value: values,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
for _, addr := range s.dnsProxy.QUICListenAddr {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"doq"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
}
|
||||
|
||||
ans := &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: 1,
|
||||
Target: domainName,
|
||||
Value: values,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// processDetermineLocal determines if the client's IP address is from
|
||||
// locally-served network and saves the result into the context.
|
||||
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
@@ -252,7 +346,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
return rc
|
||||
}
|
||||
|
||||
dctx.isLocalClient = s.subnetDetector.IsLocallyServedNetwork(ip)
|
||||
dctx.isLocalClient = s.privateNets.Contains(ip)
|
||||
|
||||
return rc
|
||||
}
|
||||
@@ -280,11 +374,11 @@ func (s *Server) hostToIP(host string) (ip net.IP, ok bool) {
|
||||
return ip, true
|
||||
}
|
||||
|
||||
// processInternalHosts respond to A requests if the target hostname is known to
|
||||
// processDHCPHosts respond to A requests if the target hostname is known to
|
||||
// the server.
|
||||
//
|
||||
// TODO(a.garipov): Adapt to AAAA as well.
|
||||
func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
if !s.dhcpServer.Enabled() {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
@@ -299,11 +393,10 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
reqHost := strings.ToLower(q.Name)
|
||||
reqHost := strings.ToLower(q.Name[:len(q.Name)-1])
|
||||
// TODO(a.garipov): Move everything related to DHCP local domain to the DHCP
|
||||
// server.
|
||||
host := strings.TrimSuffix(reqHost, s.localDomainSuffix)
|
||||
if host == reqHost {
|
||||
if !strings.HasSuffix(reqHost, s.localDomainSuffix) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -316,7 +409,7 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
ip, ok := s.hostToIP(host)
|
||||
ip, ok := s.hostToIP(reqHost)
|
||||
if !ok {
|
||||
// TODO(e.burkov): Inspect special cases when user want to apply some
|
||||
// rules handled by other processors to the hosts with TLD.
|
||||
@@ -373,8 +466,8 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
|
||||
|
||||
// Restrict an access to local addresses for external clients. We also
|
||||
// assume that all the DHCP leases we give are locally-served or at least
|
||||
// don't need to be inaccessible externally.
|
||||
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
// don't need to be accessible externally.
|
||||
if !s.privateNets.Contains(ip) {
|
||||
log.Debug("dns: addr %s is not from locally-served network", ip)
|
||||
|
||||
return resultCodeSuccess
|
||||
@@ -413,7 +506,7 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var v interface{}
|
||||
var v any
|
||||
v, ok = s.tableIPToHost.Get(ip)
|
||||
if !ok {
|
||||
return "", false
|
||||
@@ -430,7 +523,7 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
|
||||
|
||||
// Respond to PTR requests if the target IP is leased by our DHCP server and the
|
||||
// requestor is inside the local network.
|
||||
func (s *Server) processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
|
||||
func (s *Server) processDHCPAddrs(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
if d.Res != nil {
|
||||
return resultCodeSuccess
|
||||
@@ -481,7 +574,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
if !s.privateNets.Contains(ip) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
|
||||
@@ -4,35 +4,212 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
s := &Server{
|
||||
subnetDetector: snd,
|
||||
const (
|
||||
ddrTestDomainName = "dns.example.net"
|
||||
ddrTestFQDN = ddrTestDomainName + "."
|
||||
)
|
||||
|
||||
func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
dohSVCB := &dns.SVCB{
|
||||
Priority: 1,
|
||||
Target: ddrTestFQDN,
|
||||
Value: []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"h2"}},
|
||||
&dns.SVCBPort{Port: 8044},
|
||||
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
|
||||
},
|
||||
}
|
||||
|
||||
dotSVCB := &dns.SVCB{
|
||||
Priority: 1,
|
||||
Target: ddrTestFQDN,
|
||||
Value: []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"dot"}},
|
||||
&dns.SVCBPort{Port: 8043},
|
||||
},
|
||||
}
|
||||
|
||||
doqSVCB := &dns.SVCB{
|
||||
Priority: 1,
|
||||
Target: ddrTestFQDN,
|
||||
Value: []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"doq"}},
|
||||
&dns.SVCBPort{Port: 8042},
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want []*dns.SVCB
|
||||
wantRes resultCode
|
||||
portDoH int
|
||||
portDoT int
|
||||
portDoQ int
|
||||
qtype uint16
|
||||
ddrEnabled bool
|
||||
}{{
|
||||
name: "pass_host",
|
||||
wantRes: resultCodeSuccess,
|
||||
host: "example.net.",
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8043,
|
||||
}, {
|
||||
name: "pass_qtype",
|
||||
wantRes: resultCodeFinish,
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeA,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8043,
|
||||
}, {
|
||||
name: "pass_disabled_tls",
|
||||
wantRes: resultCodeFinish,
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
}, {
|
||||
name: "pass_disabled_ddr",
|
||||
wantRes: resultCodeSuccess,
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: false,
|
||||
portDoH: 8043,
|
||||
}, {
|
||||
name: "dot",
|
||||
wantRes: resultCodeFinish,
|
||||
want: []*dns.SVCB{dotSVCB},
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoT: 8043,
|
||||
}, {
|
||||
name: "doh",
|
||||
wantRes: resultCodeFinish,
|
||||
want: []*dns.SVCB{dohSVCB},
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8044,
|
||||
}, {
|
||||
name: "doq",
|
||||
wantRes: resultCodeFinish,
|
||||
want: []*dns.SVCB{doqSVCB},
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoQ: 8042,
|
||||
}, {
|
||||
name: "dot_doh",
|
||||
wantRes: resultCodeFinish,
|
||||
want: []*dns.SVCB{dotSVCB, dohSVCB},
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoT: 8043,
|
||||
portDoH: 8044,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled)
|
||||
|
||||
req := createTestMessageWithType(tc.host, tc.qtype)
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
},
|
||||
}
|
||||
|
||||
res := s.processDDRQuery(dctx)
|
||||
require.Equal(t, tc.wantRes, res)
|
||||
|
||||
if tc.wantRes != resultCodeFinish {
|
||||
return
|
||||
}
|
||||
|
||||
msg := dctx.proxyCtx.Res
|
||||
require.NotNil(t, msg)
|
||||
|
||||
for _, v := range tc.want {
|
||||
v.Hdr = s.hdr(req, dns.TypeSVCB)
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, tc.want, msg.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
|
||||
t.Helper()
|
||||
|
||||
proxyConf := proxy.Config{}
|
||||
|
||||
if portDoT > 0 {
|
||||
proxyConf.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
|
||||
}
|
||||
|
||||
if portDoQ > 0 {
|
||||
proxyConf.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
|
||||
}
|
||||
|
||||
s = &Server{
|
||||
dnsProxy: &proxy.Proxy{
|
||||
Config: proxyConf,
|
||||
},
|
||||
conf: ServerConfig{
|
||||
FilteringConfig: FilteringConfig{
|
||||
HandleDDR: ddrEnabled,
|
||||
},
|
||||
TLSConfig: TLSConfig{
|
||||
ServerName: ddrTestDomainName,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if portDoH > 0 {
|
||||
s.conf.TLSConfig.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
s := &Server{
|
||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
name string
|
||||
cliIP net.IP
|
||||
want bool
|
||||
}{{
|
||||
want: assert.True,
|
||||
name: "local",
|
||||
cliIP: net.IP{192, 168, 0, 1},
|
||||
want: true,
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "external",
|
||||
cliIP: net.IP{250, 249, 0, 1},
|
||||
want: false,
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "invalid",
|
||||
cliIP: net.IP{1, 2, 3, 4, 5},
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "nil",
|
||||
cliIP: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -47,12 +224,12 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
}
|
||||
s.processDetermineLocal(dctx)
|
||||
|
||||
assert.Equal(t, tc.want, dctx.isLocalClient)
|
||||
tc.want(t, dctx.isLocalClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
|
||||
testCases := []struct {
|
||||
@@ -93,7 +270,7 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
dhcpServer: &testDHCP{},
|
||||
localDomainSuffix: defaultLocalDomainSuffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example": knownIP,
|
||||
"example." + defaultLocalDomainSuffix: knownIP,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -115,7 +292,7 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
isLocalClient: tc.isLocalCli,
|
||||
}
|
||||
|
||||
res := s.processInternalHosts(dctx)
|
||||
res := s.processDHCPHosts(dctx)
|
||||
require.Equal(t, tc.wantRes, res)
|
||||
pctx := dctx.proxyCtx
|
||||
if tc.wantRes == resultCodeFinish {
|
||||
@@ -141,10 +318,10 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessInternalHosts(t *testing.T) {
|
||||
func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
const (
|
||||
examplecom = "example.com"
|
||||
examplelan = "example.lan"
|
||||
examplelan = "example." + defaultLocalDomainSuffix
|
||||
)
|
||||
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
@@ -193,41 +370,41 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
|
||||
}, {
|
||||
name: "success_custom_suffix",
|
||||
host: "example.custom",
|
||||
suffix: ".custom.",
|
||||
suffix: "custom",
|
||||
wantIP: knownIP,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s := &Server{
|
||||
dhcpServer: &testDHCP{},
|
||||
localDomainSuffix: tc.suffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example." + tc.suffix: knownIP,
|
||||
},
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: 1234,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: dns.Fqdn(tc.host),
|
||||
Qtype: tc.qtyp,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
},
|
||||
isLocalClient: true,
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := &Server{
|
||||
dhcpServer: &testDHCP{},
|
||||
localDomainSuffix: tc.suffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example": knownIP,
|
||||
},
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: 1234,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: dns.Fqdn(tc.host),
|
||||
Qtype: tc.qtyp,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
},
|
||||
isLocalClient: true,
|
||||
}
|
||||
|
||||
res := s.processInternalHosts(dctx)
|
||||
res := s.processDHCPHosts(dctx)
|
||||
pctx := dctx.proxyCtx
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
if tc.wantRes == resultCodeFinish {
|
||||
|
||||
@@ -33,11 +33,6 @@ const DefaultTimeout = 10 * time.Second
|
||||
// requests between the BeforeRequestHandler stage and the actual processing.
|
||||
const defaultClientIDCacheCount = 1024
|
||||
|
||||
const (
|
||||
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
parentalBlockHost = "family-block.dns.adguard.com"
|
||||
)
|
||||
|
||||
var defaultDNS = []string{
|
||||
"https://dns10.quad9.net/dns-query",
|
||||
}
|
||||
@@ -66,7 +61,7 @@ type Server struct {
|
||||
dnsFilter *filtering.DNSFilter // DNS filter instance
|
||||
dhcpServer dhcpd.ServerInterface // DHCP server instance (optional)
|
||||
queryLog querylog.QueryLog // Query log instance
|
||||
stats stats.Stats
|
||||
stats stats.Interface
|
||||
access *accessCtx
|
||||
|
||||
// localDomainSuffix is the suffix used to detect internal hosts. It
|
||||
@@ -74,7 +69,7 @@ type Server struct {
|
||||
localDomainSuffix string
|
||||
|
||||
ipset ipsetCtx
|
||||
subnetDetector *aghnet.SubnetDetector
|
||||
privateNets netutil.SubnetSet
|
||||
localResolvers *proxy.Proxy
|
||||
sysResolvers aghnet.SystemResolvers
|
||||
recDetector *recursionDetector
|
||||
@@ -107,28 +102,17 @@ type Server struct {
|
||||
// when no suffix is provided.
|
||||
//
|
||||
// See the documentation for Server.localDomainSuffix.
|
||||
const defaultLocalDomainSuffix = ".lan."
|
||||
const defaultLocalDomainSuffix = "lan"
|
||||
|
||||
// DNSCreateParams are parameters to create a new server.
|
||||
type DNSCreateParams struct {
|
||||
DNSFilter *filtering.DNSFilter
|
||||
Stats stats.Stats
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.ServerInterface
|
||||
SubnetDetector *aghnet.SubnetDetector
|
||||
Anonymizer *aghnet.IPMut
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
// domainNameToSuffix converts a domain name into a local domain suffix.
|
||||
func domainNameToSuffix(tld string) (suffix string) {
|
||||
l := len(tld) + 2
|
||||
b := make([]byte, l)
|
||||
b[0] = '.'
|
||||
copy(b[1:], tld)
|
||||
b[l-1] = '.'
|
||||
|
||||
return string(b)
|
||||
DNSFilter *filtering.DNSFilter
|
||||
Stats stats.Interface
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.ServerInterface
|
||||
PrivateNets netutil.SubnetSet
|
||||
Anonymizer *aghnet.IPMut
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -151,7 +135,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
return nil, fmt.Errorf("local domain: %w", err)
|
||||
}
|
||||
|
||||
localDomainSuffix = domainNameToSuffix(p.LocalDomain)
|
||||
localDomainSuffix = p.LocalDomain
|
||||
}
|
||||
|
||||
if p.Anonymizer == nil {
|
||||
@@ -161,7 +145,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
dnsFilter: p.DNSFilter,
|
||||
stats: p.Stats,
|
||||
queryLog: p.QueryLog,
|
||||
subnetDetector: p.SubnetDetector,
|
||||
privateNets: p.PrivateNets,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||
clientIDCache: cache.New(cache.Config{
|
||||
@@ -173,7 +157,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
|
||||
// TODO(e.burkov): Enable the refresher after the actual implementation
|
||||
// passes the public testing.
|
||||
s.sysResolvers, err = aghnet.NewSystemResolvers(0, nil)
|
||||
s.sysResolvers, err = aghnet.NewSystemResolvers(nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing system resolvers: %w", err)
|
||||
}
|
||||
@@ -314,14 +298,16 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
resolver := s.internalProxy
|
||||
if s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
var resolver *proxy.Proxy
|
||||
if s.privateNets.Contains(ip) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
resolver = s.localResolvers
|
||||
s.recDetector.add(*req)
|
||||
} else {
|
||||
resolver = s.internalProxy
|
||||
}
|
||||
|
||||
if err = resolver.Resolve(ctx); err != nil {
|
||||
|
||||
@@ -17,13 +17,14 @@ import (
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
@@ -69,14 +70,11 @@ func createTestServer(
|
||||
f := filtering.New(filterConf, filters)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var err error
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -770,16 +768,11 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
Data: []byte(rules),
|
||||
}}
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -860,10 +853,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
const hostname = "wmconvirus.narod.ru"
|
||||
|
||||
sbUps := &aghtest.TestBlockUpstream{
|
||||
Hostname: hostname,
|
||||
Block: true,
|
||||
}
|
||||
sbUps := aghtest.NewBlockUpstream(hostname, true)
|
||||
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
||||
|
||||
filterConf := &filtering.Config{
|
||||
@@ -913,15 +903,10 @@ func TestRewrite(t *testing.T) {
|
||||
f := filtering.New(c, nil)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1000,7 +985,7 @@ func TestRewrite(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func publicKey(priv interface{}) interface{} {
|
||||
func publicKey(priv any) any {
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
@@ -1028,36 +1013,33 @@ func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
||||
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
|
||||
|
||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
const localDomain = "lan"
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
||||
DHCPServer: &testDHCP{},
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
||||
DHCPServer: &testDHCP{},
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
LocalDomain: localDomain,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
|
||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.FilteringConfig.ProtectionEnabled = true
|
||||
s.conf.ProtectionEnabled = true
|
||||
|
||||
err = s.Prepare(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(s.Close)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR)
|
||||
|
||||
resp, err := dns.Exchange(req, addr.String())
|
||||
require.NoError(t, err)
|
||||
require.NoErrorf(t, err, "%s", addr)
|
||||
|
||||
require.Len(t, resp.Answer, 1)
|
||||
|
||||
@@ -1066,7 +1048,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
|
||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "myhost.", ptr.Ptr)
|
||||
assert.Equal(t, dns.Fqdn("myhost."+localDomain), ptr.Ptr)
|
||||
}
|
||||
|
||||
func TestPTRResponseFromHosts(t *testing.T) {
|
||||
@@ -1105,16 +1087,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
}, nil)
|
||||
flt.SetEnabled(true)
|
||||
|
||||
var snd *aghnet.SubnetDetector
|
||||
snd, err = aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: flt,
|
||||
SubnetDetector: snd,
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: flt,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1197,25 +1174,48 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_Exchange(t *testing.T) {
|
||||
extUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
|
||||
const (
|
||||
onesHost = "one.one.one.one"
|
||||
localDomainHost = "local.domain"
|
||||
)
|
||||
|
||||
var (
|
||||
onesIP = net.IP{1, 1, 1, 1}
|
||||
localIP = net.IP{192, 168, 1, 1}
|
||||
)
|
||||
|
||||
revExtIPv4, err := netutil.IPToReversedAddr(onesIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
extUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "external.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revExtIPv4, onesHost),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
locUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.168.192.in-addr.arpa.": {"local.domain"},
|
||||
"2.1.168.192.in-addr.arpa.": {},
|
||||
|
||||
revLocIPv4, err := netutil.IPToReversedAddr(localIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
locUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "local.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revLocIPv4, localDomainHost),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
upstreamErr := errors.Error("upstream error")
|
||||
errUpstream := &aghtest.TestErrUpstream{
|
||||
Err: upstreamErr,
|
||||
}
|
||||
nonPtrUpstream := &aghtest.TestBlockUpstream{
|
||||
Hostname: "some-host",
|
||||
Block: true,
|
||||
}
|
||||
|
||||
errUpstream := aghtest.NewErrorUpstream()
|
||||
nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
|
||||
|
||||
srv := NewCustomServer(&proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
@@ -1227,11 +1227,8 @@ func TestServer_Exchange(t *testing.T) {
|
||||
srv.conf.ResolveClients = true
|
||||
srv.conf.UsePrivateRDNS = true
|
||||
|
||||
var err error
|
||||
srv.subnetDetector, err = aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
localIP := net.IP{192, 168, 1, 1}
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
@@ -1240,20 +1237,20 @@ func TestServer_Exchange(t *testing.T) {
|
||||
req net.IP
|
||||
}{{
|
||||
name: "external_good",
|
||||
want: "one.one.one.one",
|
||||
want: onesHost,
|
||||
wantErr: nil,
|
||||
locUpstream: nil,
|
||||
req: net.IP{1, 1, 1, 1},
|
||||
req: onesIP,
|
||||
}, {
|
||||
name: "local_good",
|
||||
want: "local.domain",
|
||||
want: localDomainHost,
|
||||
wantErr: nil,
|
||||
locUpstream: locUpstream,
|
||||
req: localIP,
|
||||
}, {
|
||||
name: "upstream_error",
|
||||
want: "",
|
||||
wantErr: upstreamErr,
|
||||
wantErr: aghtest.ErrUpstream,
|
||||
locUpstream: errUpstream,
|
||||
req: localIP,
|
||||
}, {
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
Preference: 32,
|
||||
}
|
||||
svcbVal := &rules.DNSSVCB{
|
||||
Params: map[string]string{"alpn": "h3"},
|
||||
Params: map[string]string{"alpn": "h3", "dohpath": "/dns-query"},
|
||||
Target: dns.Fqdn(domain),
|
||||
Priority: 32,
|
||||
}
|
||||
@@ -164,10 +164,20 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
ans, ok := d.Res.Answer[0].(*dns.SVCB)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||
assert.Equal(t, svcbVal.Params["alpn"], ans.Value[0].String())
|
||||
require.True(t, ok)
|
||||
require.Len(t, ans.Value, 2)
|
||||
|
||||
assert.ElementsMatch(
|
||||
t,
|
||||
[]dns.SVCBKey{dns.SVCB_ALPN, dns.SVCB_DOHPATH},
|
||||
[]dns.SVCBKey{ans.Value[0].Key(), ans.Value[1].Key()},
|
||||
)
|
||||
assert.ElementsMatch(
|
||||
t,
|
||||
[]string{svcbVal.Params["alpn"], svcbVal.Params["dohpath"]},
|
||||
[]string{ans.Value[0].String(), ans.Value[1].String()},
|
||||
)
|
||||
assert.Equal(t, svcbVal.Target, ans.Target)
|
||||
assert.Equal(t, svcbVal.Priority, ans.Priority)
|
||||
})
|
||||
@@ -186,8 +196,18 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
ans, ok := d.Res.Answer[0].(*dns.HTTPS)
|
||||
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||
assert.Equal(t, svcbVal.Params["alpn"], ans.Value[0].String())
|
||||
require.Len(t, ans.Value, 2)
|
||||
|
||||
assert.ElementsMatch(
|
||||
t,
|
||||
[]dns.SVCBKey{dns.SVCB_ALPN, dns.SVCB_DOHPATH},
|
||||
[]dns.SVCBKey{ans.Value[0].Key(), ans.Value[1].Key()},
|
||||
)
|
||||
assert.ElementsMatch(
|
||||
t,
|
||||
[]string{svcbVal.Params["alpn"], svcbVal.Params["dohpath"]},
|
||||
[]string{ans.Value[0].String(), ans.Value[1].String()},
|
||||
)
|
||||
assert.Equal(t, svcbVal.Target, ans.Target)
|
||||
assert.Equal(t, svcbVal.Priority, ans.Priority)
|
||||
})
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -39,14 +38,10 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -5,12 +5,10 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -18,6 +16,8 @@ import (
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type dnsConfig struct {
|
||||
@@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (err error) {
|
||||
}
|
||||
|
||||
// validate returns an error if any field of req is invalid.
|
||||
func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
|
||||
func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
if req.Upstreams != nil {
|
||||
err = ValidateUpstreams(*req.Upstreams)
|
||||
if err != nil {
|
||||
@@ -176,7 +176,7 @@ func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
|
||||
}
|
||||
|
||||
if req.LocalPTRUpstreams != nil {
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, snd)
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating private upstream servers: %w", err)
|
||||
}
|
||||
@@ -224,7 +224,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = req.validate(s.subnetDetector)
|
||||
err = req.validate(s.privateNets)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -350,17 +350,6 @@ func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// LocalNetChecker is used to check if the IP address belongs to a local
|
||||
// network.
|
||||
type LocalNetChecker interface {
|
||||
// IsLocallyServedNetwork returns true if ip is contained in any of address
|
||||
// registries defined by RFC 6303.
|
||||
IsLocallyServedNetwork(ip net.IP) (ok bool)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ LocalNetChecker = (*aghnet.SubnetDetector)(nil)
|
||||
|
||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
||||
// configuration or nil if it can't be built.
|
||||
//
|
||||
@@ -375,6 +364,21 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
var ups string
|
||||
var domains []string
|
||||
ups, domains, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = validateUpstream(ups, domains)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating upstream %q: %w", u, err)
|
||||
}
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{Bootstrap: []string{}, Timeout: DefaultTimeout},
|
||||
@@ -385,13 +389,6 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
return nil, errors.Error("no default upstreams specified")
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
_, err = validateUpstream(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
@@ -405,25 +402,11 @@ func ValidateUpstreams(upstreams []string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// stringKeysSorted returns the sorted slice of string keys of m.
|
||||
//
|
||||
// TODO(e.burkov): Use generics in Go 1.18. Move into golibs.
|
||||
func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
|
||||
sorted = make([]string, 0, len(m))
|
||||
for s := range m {
|
||||
sorted = append(sorted, s)
|
||||
}
|
||||
|
||||
sort.Strings(sorted)
|
||||
|
||||
return sorted
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. lnc must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err error) {
|
||||
// a locally-served network. privateNets must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||
conf, err := newUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -433,9 +416,11 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
keys := maps.Keys(conf.DomainReservedUpstreams)
|
||||
slices.Sort(keys)
|
||||
|
||||
for _, domain := range stringKeysSorted(conf.DomainReservedUpstreams) {
|
||||
var errs []error
|
||||
for _, domain := range keys {
|
||||
var subnet *net.IPNet
|
||||
subnet, err = netutil.SubnetFromReversedAddr(domain)
|
||||
if err != nil {
|
||||
@@ -444,7 +429,7 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
|
||||
continue
|
||||
}
|
||||
|
||||
if !lnc.IsLocallyServedNetwork(subnet.IP) {
|
||||
if !privateNets.Contains(subnet.IP) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
@@ -461,16 +446,14 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
|
||||
|
||||
var protocols = []string{"udp://", "tcp://", "tls://", "https://", "sdns://", "quic://"}
|
||||
|
||||
func validateUpstream(u string) (useDefault bool, err error) {
|
||||
// Check if the user tries to specify upstream for domain.
|
||||
var isDomainSpec bool
|
||||
u, isDomainSpec, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
return !isDomainSpec, err
|
||||
}
|
||||
|
||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||
// upstream configuration. useDefault is true if the upstream is
|
||||
// domain-specific and is configured to point at the default upstream server
|
||||
// which is validated separately. The upstream is considered domain-specific
|
||||
// only if domains is at least not nil.
|
||||
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
||||
// The special server address '#' means that default server must be used.
|
||||
if useDefault = !isDomainSpec; u == "#" && isDomainSpec {
|
||||
if useDefault = u == "#" && domains != nil; useDefault {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
@@ -497,12 +480,14 @@ func validateUpstream(u string) (useDefault bool, err error) {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
// separateUpstream returns the upstream without the specified domains.
|
||||
// isDomainSpec is true when the upstream is domains-specific.
|
||||
func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, err error) {
|
||||
// separateUpstream returns the upstream and the specified domains. domains is
|
||||
// nil when the upstream is not domains-specific. Otherwise it may also be
|
||||
// empty.
|
||||
func separateUpstream(upstreamStr string) (ups string, domains []string, err error) {
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return upstreamStr, false, nil
|
||||
return upstreamStr, nil, nil
|
||||
}
|
||||
|
||||
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
|
||||
|
||||
parts := strings.Split(upstreamStr[2:], "/]")
|
||||
@@ -510,39 +495,46 @@ func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, e
|
||||
case 2:
|
||||
// Go on.
|
||||
case 1:
|
||||
return "", false, errors.Error("missing separator")
|
||||
return "", nil, errors.Error("missing separator")
|
||||
default:
|
||||
return "", true, errors.Error("duplicated separator")
|
||||
return "", []string{}, errors.Error("duplicated separator")
|
||||
}
|
||||
|
||||
var domains string
|
||||
domains, upstream = parts[0], parts[1]
|
||||
for i, host := range strings.Split(domains, "/") {
|
||||
for i, host := range strings.Split(parts[0], "/") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
err = netutil.ValidateDomainName(host)
|
||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
||||
if err != nil {
|
||||
return "", true, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
return "", domains, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
domains = append(domains, host)
|
||||
}
|
||||
|
||||
return upstream, true, nil
|
||||
return parts[1], domains, nil
|
||||
}
|
||||
|
||||
// excFunc is a signature of function to check if upstream exchanges correctly.
|
||||
type excFunc func(u upstream.Upstream) (err error)
|
||||
// healthCheckFunc is a signature of function to check if upstream exchanges
|
||||
// properly.
|
||||
type healthCheckFunc func(u upstream.Upstream) (err error)
|
||||
|
||||
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
|
||||
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// testTLD is the special-use fully-qualified domain name for testing the
|
||||
// DNS server reachability.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
|
||||
const testTLD = "test."
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: "google-public-dns-a.google.com.",
|
||||
Name: testTLD,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
@@ -552,12 +544,8 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
reply, err = u.Exchange(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
}
|
||||
|
||||
if len(reply.Answer) != 1 {
|
||||
return fmt.Errorf("wrong response")
|
||||
} else if a, ok := reply.Answer[0].(*dns.A); !ok || !a.A.Equal(net.IP{8, 8, 8, 8}) {
|
||||
return fmt.Errorf("wrong response")
|
||||
} else if len(reply.Answer) != 0 {
|
||||
return errors.Error("wrong response")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -565,14 +553,22 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
|
||||
// checkPrivateUpstreamExc checks if the upstream for resolving private
|
||||
// addresses exchanges correctly.
|
||||
//
|
||||
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
|
||||
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
|
||||
// address resolution.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
|
||||
const inAddrArpaTLD = "in-addr.arpa."
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: "1.0.0.127.in-addr.arpa.",
|
||||
Name: inAddrArpaTLD,
|
||||
Qtype: dns.TypePTR,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
@@ -585,46 +581,66 @@ func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFunc) (err error) {
|
||||
if IsCommentOrEmpty(input) {
|
||||
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
||||
// the tested upstream domain-specific and therefore consider its errors
|
||||
// non-critical.
|
||||
//
|
||||
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
|
||||
// warnings (non-critical errors) is desired.
|
||||
type domainSpecificTestError struct {
|
||||
error
|
||||
}
|
||||
|
||||
// checkDNS checks the upstream server defined by upstreamConfigStr using
|
||||
// healthCheck for actually exchange messages. It uses bootstrap to resolve the
|
||||
// upstream's address.
|
||||
func checkDNS(
|
||||
upstreamConfigStr string,
|
||||
bootstrap []string,
|
||||
timeout time.Duration,
|
||||
healthCheck healthCheckFunc,
|
||||
) (err error) {
|
||||
if IsCommentOrEmpty(upstreamConfigStr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Separate upstream from domains list.
|
||||
var useDefault bool
|
||||
if useDefault, err = validateUpstream(input); err != nil {
|
||||
upstreamAddr, domains, err := separateUpstream(upstreamConfigStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
// No need to check this DNS server.
|
||||
if !useDefault {
|
||||
useDefault, err := validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
} else if useDefault {
|
||||
return nil
|
||||
}
|
||||
|
||||
if input, _, err = separateUpstream(input); err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
if len(bootstrap) == 0 {
|
||||
bootstrap = defaultBootstrap
|
||||
}
|
||||
|
||||
log.Debug("checking if upstream %s works", input)
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
var u upstream.Upstream
|
||||
u, err = upstream.AddressToUpstream(input, &upstream.Options{
|
||||
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: timeout,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to choose upstream for %q: %w", input, err)
|
||||
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
|
||||
}
|
||||
|
||||
if err = ef(u); err != nil {
|
||||
return fmt.Errorf("upstream %q fails to exchange: %w", input, err)
|
||||
if err = healthCheck(u); err != nil {
|
||||
err = fmt.Errorf("upstream %q fails to exchange: %w", upstreamAddr, err)
|
||||
if domains != nil {
|
||||
return domainSpecificTestError{error: err}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("upstream %s is ok", input)
|
||||
log.Debug("dnsforward: upstream %q is ok", upstreamAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -647,6 +663,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
log.Info("%v", err)
|
||||
result[host] = err.Error()
|
||||
if _, ok := err.(domainSpecificTestError); ok {
|
||||
result[host] = fmt.Sprintf("WARNING: %s", result[host])
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
@@ -662,6 +681,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
// above, we rewriting the error for it. These cases should be
|
||||
// handled properly instead.
|
||||
result[host] = err.Error()
|
||||
if _, ok := err.(domainSpecificTestError); ok {
|
||||
result[host] = fmt.Sprintf("WARNING: %s", result[host])
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -33,7 +34,7 @@ func (fsr *fakeSystemResolvers) Get() (rs []string) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadTestData(t *testing.T, casesFileName string, cases interface{}) {
|
||||
func loadTestData(t *testing.T, casesFileName string, cases any) {
|
||||
t.Helper()
|
||||
|
||||
var f *os.File
|
||||
@@ -184,7 +185,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "upstream_dns_bad",
|
||||
wantSet: `validating upstream servers: bad ipport address "!!!": ` +
|
||||
wantSet: `validating upstream servers: ` +
|
||||
`validating upstream "!!!": bad ipport address "!!!": ` +
|
||||
`address !!!: missing port in address`,
|
||||
}, {
|
||||
name: "bootstraps_bad",
|
||||
@@ -255,112 +257,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstream(t *testing.T) {
|
||||
testCases := []struct {
|
||||
wantDef assert.BoolAssertionFunc
|
||||
name string
|
||||
upstream string
|
||||
wantErr string
|
||||
}{{
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "1.2.3.4.5",
|
||||
wantErr: `bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "123.3.7m",
|
||||
wantErr: `bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "htttps://google.com/dns-query",
|
||||
wantErr: `wrong protocol`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "[/host.com]tls://dns.adguard.com",
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "[host.ru]#",
|
||||
wantErr: `bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "tls://1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "https://dns.adguard.com/dns-query",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "default_udp_host",
|
||||
upstream: "udp://dns.google",
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "default_udp_ip",
|
||||
upstream: "udp://8.8.8.8",
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host.com/]1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[//]tls://1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/www.host.com/]#",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host.com/google.com/]8.8.8.8",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "idna",
|
||||
upstream: "[/пример.рф/]8.8.8.8",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "bad_domain",
|
||||
upstream: "[/!/]8.8.8.8",
|
||||
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
|
||||
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
defaultUpstream, err := validateUpstream(tc.upstream)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
tc.wantDef(t, defaultUpstream)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreams(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -375,7 +271,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
wantErr: ``,
|
||||
set: []string{"# comment"},
|
||||
}, {
|
||||
name: "valid_no_default",
|
||||
name: "no_default",
|
||||
wantErr: `no default upstreams specified`,
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
@@ -385,7 +281,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
},
|
||||
}, {
|
||||
name: "valid_with_default",
|
||||
name: "with_default",
|
||||
wantErr: ``,
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
@@ -397,8 +293,46 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `cannot prepare the upstream dhcp://fake.dns ([]): unsupported url scheme: dhcp`,
|
||||
wantErr: `validating upstream "dhcp://fake.dns": wrong protocol`,
|
||||
set: []string{"dhcp://fake.dns"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "1.2.3.4.5": bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
|
||||
set: []string{"1.2.3.4.5"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "123.3.7m": bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
|
||||
set: []string{"123.3.7m"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
|
||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "[host.ru]#": bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
|
||||
set: []string{"[host.ru]#"},
|
||||
}, {
|
||||
name: "valid_default",
|
||||
wantErr: ``,
|
||||
set: []string{
|
||||
"1.1.1.1",
|
||||
"tls://1.1.1.1",
|
||||
"https://dns.adguard.com/dns-query",
|
||||
"sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"udp://dns.google",
|
||||
"udp://8.8.8.8",
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"[/пример.рф/]8.8.8.8",
|
||||
},
|
||||
}, {
|
||||
name: "bad_domain",
|
||||
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
|
||||
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
|
||||
set: []string{"[/!/]8.8.8.8"},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -410,8 +344,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -452,7 +385,7 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
set := []string{"192.168.0.1", tc.u}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = ValidateUpstreamsPrivate(set, snd)
|
||||
err := ValidateUpstreamsPrivate(set, ss)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ func TestRecursionDetector_Suspect(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
msg dns.Msg
|
||||
want bool
|
||||
want int
|
||||
}{{
|
||||
name: "simple",
|
||||
msg: dns.Msg{
|
||||
@@ -95,24 +95,18 @@ func TestRecursionDetector_Suspect(t *testing.T) {
|
||||
Qtype: dns.TypeA,
|
||||
}},
|
||||
},
|
||||
want: true,
|
||||
want: 1,
|
||||
}, {
|
||||
name: "unencumbered",
|
||||
msg: dns.Msg{},
|
||||
want: false,
|
||||
want: 0,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(rd.clear)
|
||||
|
||||
rd.add(tc.msg)
|
||||
|
||||
if tc.want {
|
||||
assert.Equal(t, 1, rd.recentRequests.Stats().Count)
|
||||
} else {
|
||||
assert.Zero(t, rd.recentRequests.Stats().Count)
|
||||
}
|
||||
assert.Equal(t, tc.want, rd.recentRequests.Stats().Count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,9 +64,9 @@ func (s *Server) logQuery(
|
||||
Answer: pctx.Res,
|
||||
OrigAnswer: dctx.origResp,
|
||||
Result: dctx.result,
|
||||
Elapsed: elapsed,
|
||||
ClientID: dctx.clientID,
|
||||
ClientIP: ip,
|
||||
Elapsed: elapsed,
|
||||
AuthenticatedData: dctx.responseAD,
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func (l *testQueryLog) Add(p *querylog.AddParams) {
|
||||
type testStats struct {
|
||||
// Stats is embedded here simply to make testStats a stats.Stats without
|
||||
// actually implementing all methods.
|
||||
stats.Stats
|
||||
stats.Interface
|
||||
|
||||
lastEntry stats.Entry
|
||||
}
|
||||
|
||||
@@ -32,12 +32,16 @@ func (s *Server) genAnswerHTTPS(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTT
|
||||
// github.com/miekg/dns module.
|
||||
var strToSVCBKey = map[string]dns.SVCBKey{
|
||||
"alpn": dns.SVCB_ALPN,
|
||||
"echconfig": dns.SVCB_ECHCONFIG,
|
||||
"ech": dns.SVCB_ECHCONFIG,
|
||||
"ipv4hint": dns.SVCB_IPV4HINT,
|
||||
"ipv6hint": dns.SVCB_IPV6HINT,
|
||||
"mandatory": dns.SVCB_MANDATORY,
|
||||
"no-default-alpn": dns.SVCB_NO_DEFAULT_ALPN,
|
||||
"port": dns.SVCB_PORT,
|
||||
|
||||
// TODO(a.garipov): This is the previous name for the parameter that has
|
||||
// since been changed. Remove this in v0.109.0.
|
||||
"echconfig": dns.SVCB_ECHCONFIG,
|
||||
}
|
||||
|
||||
// svcbKeyHandler is a handler for one SVCB parameter key.
|
||||
@@ -51,10 +55,10 @@ var svcbKeyHandlers = map[string]svcbKeyHandler{
|
||||
}
|
||||
},
|
||||
|
||||
"echconfig": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
"ech": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
ech, err := base64.StdEncoding.DecodeString(valStr)
|
||||
if err != nil {
|
||||
log.Debug("can't parse svcb/https echconfig: %s; ignoring", err)
|
||||
log.Debug("can't parse svcb/https ech: %s; ignoring", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -119,6 +123,32 @@ var svcbKeyHandlers = map[string]svcbKeyHandler{
|
||||
Port: uint16(port64),
|
||||
}
|
||||
},
|
||||
|
||||
// TODO(a.garipov): This is the previous name for the parameter that has
|
||||
// since been changed. Remove this in v0.109.0.
|
||||
"echconfig": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
log.Info(
|
||||
`warning: svcb/https record parameter name "echconfig" is deprecated; ` +
|
||||
`use "ech" instead`,
|
||||
)
|
||||
|
||||
ech, err := base64.StdEncoding.DecodeString(valStr)
|
||||
if err != nil {
|
||||
log.Debug("can't parse svcb/https ech: %s; ignoring", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return &dns.SVCBECHConfig{
|
||||
ECH: ech,
|
||||
}
|
||||
},
|
||||
|
||||
"dohpath": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
return &dns.SVCBDoHPath{
|
||||
Template: valStr,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// genAnswerSVCB returns a properly initialized SVCB resource record.
|
||||
|
||||
@@ -87,14 +87,18 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
svcb: dnssvcb("alpn", "h3"),
|
||||
want: wantsvcb(&dns.SVCBAlpn{Alpn: []string{"h3"}}),
|
||||
name: "alpn",
|
||||
}, {
|
||||
svcb: dnssvcb("ech", "AAAA"),
|
||||
want: wantsvcb(&dns.SVCBECHConfig{ECH: []byte{0, 0, 0}}),
|
||||
name: "ech",
|
||||
}, {
|
||||
svcb: dnssvcb("echconfig", "AAAA"),
|
||||
want: wantsvcb(&dns.SVCBECHConfig{ECH: []byte{0, 0, 0}}),
|
||||
name: "echconfig",
|
||||
name: "ech_deprecated",
|
||||
}, {
|
||||
svcb: dnssvcb("echconfig", "%BAD%"),
|
||||
want: wantsvcb(nil),
|
||||
name: "echconfig_invalid",
|
||||
name: "ech_invalid",
|
||||
}, {
|
||||
svcb: dnssvcb("ipv4hint", "127.0.0.1"),
|
||||
want: wantsvcb(&dns.SVCBIPv4Hint{Hint: []net.IP{ip4}}),
|
||||
@@ -123,6 +127,10 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
svcb: dnssvcb("no-default-alpn", ""),
|
||||
want: wantsvcb(&dns.SVCBNoDefaultAlpn{}),
|
||||
name: "no_default_alpn",
|
||||
}, {
|
||||
svcb: dnssvcb("dohpath", "/dns-query"),
|
||||
want: wantsvcb(&dns.SVCBDoHPath{Template: "/dns-query"}),
|
||||
name: "dohpath",
|
||||
}, {
|
||||
svcb: dnssvcb("port", "8080"),
|
||||
want: wantsvcb(&dns.SVCBPort{Port: 8080}),
|
||||
|
||||
Reference in New Issue
Block a user