all: sync with master
This commit is contained in:
@@ -3,24 +3,26 @@ package dnsforward
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
)
|
||||
|
||||
// accessCtx controls IP and client blocking that takes place before all other
|
||||
// processing. An accessCtx is safe for concurrent use.
|
||||
type accessCtx struct {
|
||||
allowedIPs *netutil.IPMap
|
||||
blockedIPs *netutil.IPMap
|
||||
// unit is a convenient alias for struct{}
|
||||
type unit = struct{}
|
||||
|
||||
// accessManager controls IP and client blocking that takes place before all
|
||||
// other processing. An accessManager is safe for concurrent use.
|
||||
type accessManager struct {
|
||||
allowedIPs map[netip.Addr]unit
|
||||
blockedIPs map[netip.Addr]unit
|
||||
|
||||
allowedClientIDs *stringutil.Set
|
||||
blockedClientIDs *stringutil.Set
|
||||
@@ -28,36 +30,29 @@ type accessCtx struct {
|
||||
blockedHostsEng *urlfilter.DNSEngine
|
||||
|
||||
// TODO(a.garipov): Create a type for a set of IP networks.
|
||||
// netutil.IPNetSet?
|
||||
allowedNets []*net.IPNet
|
||||
blockedNets []*net.IPNet
|
||||
allowedNets []netip.Prefix
|
||||
blockedNets []netip.Prefix
|
||||
}
|
||||
|
||||
// unit is a convenient alias for struct{}
|
||||
type unit = struct{}
|
||||
|
||||
// processAccessClients is a helper for processing a list of client strings,
|
||||
// which may be an IP address, a CIDR, or a ClientID.
|
||||
func processAccessClients(
|
||||
clientStrs []string,
|
||||
ips *netutil.IPMap,
|
||||
nets *[]*net.IPNet,
|
||||
ips map[netip.Addr]unit,
|
||||
nets *[]netip.Prefix,
|
||||
clientIDs *stringutil.Set,
|
||||
) (err error) {
|
||||
for i, s := range clientStrs {
|
||||
if ip := net.ParseIP(s); ip != nil {
|
||||
ips.Set(ip, unit{})
|
||||
} else if cidrIP, ipnet, cidrErr := net.ParseCIDR(s); cidrErr == nil {
|
||||
ipnet.IP = cidrIP
|
||||
var ip netip.Addr
|
||||
var ipnet netip.Prefix
|
||||
if ip, err = netip.ParseAddr(s); err == nil {
|
||||
ips[ip] = unit{}
|
||||
} else if ipnet, err = netip.ParsePrefix(s); err == nil {
|
||||
*nets = append(*nets, ipnet)
|
||||
} else {
|
||||
idErr := ValidateClientID(s)
|
||||
if idErr != nil {
|
||||
return fmt.Errorf(
|
||||
"value %q at index %d: bad ip, cidr, or clientid",
|
||||
s,
|
||||
i,
|
||||
)
|
||||
err = ValidateClientID(s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("value %q at index %d: bad ip, cidr, or clientid", s, i)
|
||||
}
|
||||
|
||||
clientIDs.Add(s)
|
||||
@@ -68,10 +63,10 @@ func processAccessClients(
|
||||
}
|
||||
|
||||
// newAccessCtx creates a new accessCtx.
|
||||
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) {
|
||||
a = &accessCtx{
|
||||
allowedIPs: netutil.NewIPMap(0),
|
||||
blockedIPs: netutil.NewIPMap(0),
|
||||
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessManager, err error) {
|
||||
a = &accessManager{
|
||||
allowedIPs: map[netip.Addr]unit{},
|
||||
blockedIPs: map[netip.Addr]unit{},
|
||||
|
||||
allowedClientIDs: stringutil.NewSet(),
|
||||
blockedClientIDs: stringutil.NewSet(),
|
||||
@@ -111,12 +106,12 @@ func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err er
|
||||
}
|
||||
|
||||
// allowlistMode returns true if this *accessCtx is in the allowlist mode.
|
||||
func (a *accessCtx) allowlistMode() (ok bool) {
|
||||
return a.allowedIPs.Len() != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
|
||||
func (a *accessManager) allowlistMode() (ok bool) {
|
||||
return len(a.allowedIPs) != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
|
||||
}
|
||||
|
||||
// isBlockedClientID returns true if the ClientID should be blocked.
|
||||
func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
|
||||
func (a *accessManager) isBlockedClientID(id string) (ok bool) {
|
||||
allowlistMode := a.allowlistMode()
|
||||
if id == "" {
|
||||
// In allowlist mode, consider requests without ClientIDs blocked by
|
||||
@@ -132,7 +127,7 @@ func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
|
||||
}
|
||||
|
||||
// isBlockedHost returns true if host should be blocked.
|
||||
func (a *accessCtx) isBlockedHost(host string) (ok bool) {
|
||||
func (a *accessManager) isBlockedHost(host string) (ok bool) {
|
||||
_, ok = a.blockedHostsEng.Match(strings.ToLower(host))
|
||||
|
||||
return ok
|
||||
@@ -140,7 +135,7 @@ func (a *accessCtx) isBlockedHost(host string) (ok bool) {
|
||||
|
||||
// isBlockedIP returns the status of the IP address blocking as well as the rule
|
||||
// that blocked it.
|
||||
func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) {
|
||||
func (a *accessManager) isBlockedIP(ip netip.Addr) (blocked bool, rule string) {
|
||||
blocked = true
|
||||
ips := a.blockedIPs
|
||||
ipnets := a.blockedNets
|
||||
@@ -152,7 +147,7 @@ func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) {
|
||||
ipnets = a.allowedNets
|
||||
}
|
||||
|
||||
if _, ok := ips.Get(ip); ok {
|
||||
if _, ok := ips[ip]; ok {
|
||||
return blocked, ip.String()
|
||||
}
|
||||
|
||||
@@ -240,7 +235,7 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var a *accessCtx
|
||||
var a *accessManager
|
||||
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -95,27 +95,27 @@ func TestIsBlockedIP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantRule string
|
||||
ip net.IP
|
||||
ip netip.Addr
|
||||
wantBlocked bool
|
||||
}{{
|
||||
name: "match_ip",
|
||||
wantRule: "1.2.3.4",
|
||||
ip: net.IP{1, 2, 3, 4},
|
||||
ip: netip.MustParseAddr("1.2.3.4"),
|
||||
wantBlocked: true,
|
||||
}, {
|
||||
name: "match_cidr",
|
||||
wantRule: "5.6.7.8/24",
|
||||
ip: net.IP{5, 6, 7, 100},
|
||||
ip: netip.MustParseAddr("5.6.7.100"),
|
||||
wantBlocked: true,
|
||||
}, {
|
||||
name: "no_match_ip",
|
||||
wantRule: "",
|
||||
ip: net.IP{9, 2, 3, 4},
|
||||
ip: netip.MustParseAddr("9.2.3.4"),
|
||||
wantBlocked: false,
|
||||
}, {
|
||||
name: "no_match_cidr",
|
||||
wantRule: "",
|
||||
ip: net.IP{9, 6, 7, 100},
|
||||
ip: netip.MustParseAddr("9.6.7.100"),
|
||||
wantBlocked: false,
|
||||
}}
|
||||
|
||||
|
||||
@@ -123,7 +123,14 @@ type quicConnection interface {
|
||||
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
proto := pctx.Proto
|
||||
if proto == proxy.ProtoHTTPS {
|
||||
return clientIDFromDNSContextHTTPS(pctx)
|
||||
clientID, err = clientIDFromDNSContextHTTPS(pctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("checking url: %w", err)
|
||||
} else if clientID != "" {
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// Go on and check the domain name as well.
|
||||
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
||||
return "", nil
|
||||
}
|
||||
@@ -133,31 +140,9 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cliSrvName := ""
|
||||
switch proto {
|
||||
case proxy.ProtoTLS:
|
||||
conn := pctx.Conn
|
||||
tc, ok := conn.(tlsConn)
|
||||
if !ok {
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx conn of proto %s is %T, want *tls.Conn",
|
||||
proto,
|
||||
conn,
|
||||
)
|
||||
}
|
||||
|
||||
cliSrvName = tc.ConnectionState().ServerName
|
||||
case proxy.ProtoQUIC:
|
||||
conn, ok := pctx.QUICConnection.(quicConnection)
|
||||
if !ok {
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx quic conn of proto %s is %T, want quic.Connection",
|
||||
proto,
|
||||
pctx.QUICConnection,
|
||||
)
|
||||
}
|
||||
|
||||
cliSrvName = conn.ConnectionState().TLS.ServerName
|
||||
cliSrvName, err := clientServerName(pctx, proto)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
@@ -171,3 +156,47 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
|
||||
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// clientServerName returns the TLS server name based on the protocol.
|
||||
func clientServerName(pctx *proxy.DNSContext, proto proxy.Proto) (srvName string, err error) {
|
||||
switch proto {
|
||||
case proxy.ProtoHTTPS:
|
||||
// github.com/lucas-clemente/quic-go seems to not populate the TLS
|
||||
// field. So, if the request comes over HTTP/3, use the Host header
|
||||
// value as the server name.
|
||||
//
|
||||
// See https://github.com/lucas-clemente/quic-go/issues/2879.
|
||||
//
|
||||
// TODO(a.garipov): Remove this crutch once they fix it.
|
||||
r := pctx.HTTPRequest
|
||||
if r.ProtoAtLeast(3, 0) {
|
||||
var host string
|
||||
host, err = netutil.SplitHost(r.Host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parsing host: %w", err)
|
||||
}
|
||||
|
||||
srvName = host
|
||||
} else if connState := r.TLS; connState != nil {
|
||||
srvName = r.TLS.ServerName
|
||||
}
|
||||
case proxy.ProtoQUIC:
|
||||
qConn := pctx.QUICConnection
|
||||
conn, ok := qConn.(quicConnection)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("pctx conn of proto %s is %T, want quic.Connection", proto, qConn)
|
||||
}
|
||||
|
||||
srvName = conn.ConnectionState().TLS.ServerName
|
||||
case proxy.ProtoTLS:
|
||||
conn := pctx.Conn
|
||||
tc, ok := conn.(tlsConn)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("pctx conn of proto %s is %T, want *tls.Conn", proto, conn)
|
||||
}
|
||||
|
||||
srvName = tc.ConnectionState().ServerName
|
||||
}
|
||||
|
||||
return srvName, nil
|
||||
}
|
||||
|
||||
@@ -47,8 +47,6 @@ func (c testQUICConnection) ConnectionState() (cs quic.ConnectionState) {
|
||||
}
|
||||
|
||||
func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
// TODO(a.garipov): Consider moving away from the text-based error
|
||||
// checks and onto a more structured approach.
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto proxy.Proto
|
||||
@@ -57,6 +55,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantClientID string
|
||||
wantErrMsg string
|
||||
strictSNI bool
|
||||
useHTTP3 bool
|
||||
}{{
|
||||
name: "udp",
|
||||
proto: proxy.ProtoUDP,
|
||||
@@ -65,6 +64,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
strictSNI: false,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "tls_no_clientid",
|
||||
proto: proxy.ProtoTLS,
|
||||
@@ -73,6 +73,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "tls_no_client_server_name",
|
||||
proto: proxy.ProtoTLS,
|
||||
@@ -82,6 +83,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantErrMsg: `clientid check: client server name "" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "tls_no_client_server_name_no_strict",
|
||||
proto: proxy.ProtoTLS,
|
||||
@@ -90,6 +92,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
strictSNI: false,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "tls_clientid",
|
||||
proto: proxy.ProtoTLS,
|
||||
@@ -98,6 +101,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "tls_clientid_hostname_error",
|
||||
proto: proxy.ProtoTLS,
|
||||
@@ -107,6 +111,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantErrMsg: `clientid check: client server name "cli.example.net" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "tls_invalid_clientid",
|
||||
proto: proxy.ProtoTLS,
|
||||
@@ -116,6 +121,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantErrMsg: `clientid check: invalid clientid "!!!": ` +
|
||||
`bad domain name label rune '!'`,
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "tls_clientid_too_long",
|
||||
proto: proxy.ProtoTLS,
|
||||
@@ -127,6 +133,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
|
||||
`domain name label is too long: got 72, max 63`,
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "quic_clientid",
|
||||
proto: proxy.ProtoQUIC,
|
||||
@@ -135,6 +142,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "tls_clientid_issue3437",
|
||||
proto: proxy.ProtoTLS,
|
||||
@@ -144,6 +152,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantErrMsg: `clientid check: client server name "cli.myexample.com" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "tls_case",
|
||||
proto: proxy.ProtoTLS,
|
||||
@@ -152,6 +161,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "quic_case",
|
||||
proto: proxy.ProtoQUIC,
|
||||
@@ -160,6 +170,34 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "https_no_clientid",
|
||||
proto: proxy.ProtoHTTPS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "https_clientid",
|
||||
proto: proxy.ProtoHTTPS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "cli.example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
strictSNI: true,
|
||||
useHTTP3: false,
|
||||
}, {
|
||||
name: "https_clientid_quic",
|
||||
proto: proxy.ProtoHTTPS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "cli.example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
strictSNI: true,
|
||||
useHTTP3: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -173,16 +211,21 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
conf: ServerConfig{TLSConfig: tlsConf},
|
||||
}
|
||||
|
||||
var conn net.Conn
|
||||
if tc.proto == proxy.ProtoTLS {
|
||||
conn = testTLSConn{
|
||||
var (
|
||||
conn net.Conn
|
||||
qconn quic.Connection
|
||||
httpReq *http.Request
|
||||
)
|
||||
|
||||
switch tc.proto {
|
||||
case proxy.ProtoHTTPS:
|
||||
httpReq = newHTTPReq(tc.cliSrvName, tc.useHTTP3)
|
||||
case proxy.ProtoQUIC:
|
||||
qconn = testQUICConnection{
|
||||
serverName: tc.cliSrvName,
|
||||
}
|
||||
}
|
||||
|
||||
var qconn quic.Connection
|
||||
if tc.proto == proxy.ProtoQUIC {
|
||||
qconn = testQUICConnection{
|
||||
case proxy.ProtoTLS:
|
||||
conn = testTLSConn{
|
||||
serverName: tc.cliSrvName,
|
||||
}
|
||||
}
|
||||
@@ -190,6 +233,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
HTTPRequest: httpReq,
|
||||
QUICConnection: qconn,
|
||||
}
|
||||
|
||||
@@ -201,60 +245,107 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// newHTTPReq is a helper to create HTTP requests for tests.
|
||||
func newHTTPReq(cliSrvName string, useHTTP3 bool) (r *http.Request) {
|
||||
u := &url.URL{
|
||||
Path: "/dns-query",
|
||||
}
|
||||
|
||||
if useHTTP3 {
|
||||
return &http.Request{
|
||||
ProtoMajor: 3,
|
||||
ProtoMinor: 0,
|
||||
URL: u,
|
||||
Host: cliSrvName,
|
||||
TLS: &tls.ConnectionState{},
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Request{
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
URL: u,
|
||||
Host: cliSrvName,
|
||||
TLS: &tls.ConnectionState{
|
||||
ServerName: cliSrvName,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
path string
|
||||
cliSrvName string
|
||||
wantClientID string
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "no_clientid",
|
||||
path: "/dns-query",
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "no_clientid_slash",
|
||||
path: "/dns-query/",
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "clientid",
|
||||
path: "/dns-query/cli",
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "clientid_slash",
|
||||
path: "/dns-query/cli/",
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "clientid_case",
|
||||
path: "/dns-query/InSeNsItIvE",
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
}, {
|
||||
name: "bad_url",
|
||||
path: "/foo",
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `clientid check: invalid path "/foo"`,
|
||||
}, {
|
||||
name: "extra",
|
||||
path: "/dns-query/cli/foo",
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `clientid check: invalid path "/dns-query/cli/foo": extra parts`,
|
||||
}, {
|
||||
name: "invalid_clientid",
|
||||
path: "/dns-query/!!!",
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `clientid check: invalid clientid "!!!": bad domain name label rune '!'`,
|
||||
}, {
|
||||
name: "both_ids",
|
||||
path: "/dns-query/right",
|
||||
cliSrvName: "wrong.example.com",
|
||||
wantClientID: "right",
|
||||
wantErrMsg: "",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
connState := &tls.ConnectionState{
|
||||
ServerName: tc.cliSrvName,
|
||||
}
|
||||
|
||||
r := &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: tc.path,
|
||||
},
|
||||
TLS: connState,
|
||||
}
|
||||
|
||||
pctx := &proxy.DNSContext{
|
||||
|
||||
@@ -97,9 +97,16 @@ type FilteringConfig struct {
|
||||
// Access settings
|
||||
// --
|
||||
|
||||
AllowedClients []string `yaml:"allowed_clients"` // IP addresses of whitelist clients
|
||||
DisallowedClients []string `yaml:"disallowed_clients"` // IP addresses of clients that should be blocked
|
||||
BlockedHosts []string `yaml:"blocked_hosts"` // hosts that should be blocked
|
||||
// AllowedClients is the slice of IP addresses, CIDR networks, and ClientIDs
|
||||
// of allowed clients. If not empty, only these clients are allowed, and
|
||||
// [FilteringConfig.DisallowedClients] are ignored.
|
||||
AllowedClients []string `yaml:"allowed_clients"`
|
||||
|
||||
// DisallowedClients is the slice of IP addresses, CIDR networks, and
|
||||
// ClientIDs of disallowed clients.
|
||||
DisallowedClients []string `yaml:"disallowed_clients"`
|
||||
|
||||
BlockedHosts []string `yaml:"blocked_hosts"` // hosts that should be blocked
|
||||
// TrustedProxies is the list of IP addresses and CIDR networks to detect
|
||||
// proxy servers addresses the DoH requests from which should be handled.
|
||||
// The value of nil or an empty slice for this field makes Proxy not trust
|
||||
@@ -140,13 +147,12 @@ type FilteringConfig struct {
|
||||
|
||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||
type TLSConfig struct {
|
||||
cert tls.Certificate
|
||||
|
||||
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
|
||||
HTTPSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
|
||||
// Reject connection if the client uses server name (in SNI) that doesn't match the certificate
|
||||
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"`
|
||||
|
||||
// PEM-encoded certificates chain
|
||||
CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"`
|
||||
// PEM-encoded private key
|
||||
@@ -162,9 +168,20 @@ type TLSConfig struct {
|
||||
// used for ClientID checking and Discovery of Designated Resolvers (DDR).
|
||||
ServerName string `yaml:"-" json:"-"`
|
||||
|
||||
cert tls.Certificate
|
||||
// DNS names from certificate (SAN) or CN value from Subject
|
||||
dnsNames []string
|
||||
|
||||
// OverrideTLSCiphers, when set, contains the names of the cipher suites to
|
||||
// use. If the slice is empty, the default safe suites are used.
|
||||
OverrideTLSCiphers []string `yaml:"override_tls_ciphers,omitempty" json:"-"`
|
||||
|
||||
// StrictSNICheck controls if the connections with SNI mismatching the
|
||||
// certificate's ones should be rejected.
|
||||
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"`
|
||||
|
||||
// hasIPAddrs is set during the certificate parsing and is true if the
|
||||
// configured certificate contains at least a single IP address.
|
||||
hasIPAddrs bool
|
||||
}
|
||||
|
||||
// DNSCryptConfig is the DNSCrypt server configuration struct.
|
||||
@@ -193,7 +210,9 @@ type ServerConfig struct {
|
||||
UpstreamTimeout time.Duration
|
||||
|
||||
TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2
|
||||
TLSCiphers []uint16 // list of TLS ciphers to use
|
||||
|
||||
// TLSCiphers are the IDs of TLS cipher suites to use.
|
||||
TLSCiphers []uint16
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func()
|
||||
@@ -348,17 +367,13 @@ func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||
|
||||
// prepareUpstreamSettings - prepares upstream DNS server settings
|
||||
func (s *Server) prepareUpstreamSettings() error {
|
||||
// We're setting a customized set of RootCAs
|
||||
// The reason is that Go default mechanism of loading TLS roots
|
||||
// does not always work properly on some routers so we're
|
||||
// loading roots manually and pass it here.
|
||||
// See "util.LoadSystemRootCAs"
|
||||
// We're setting a customized set of RootCAs. The reason is that Go default
|
||||
// mechanism of loading TLS roots does not always work properly on some
|
||||
// routers so we're loading roots manually and pass it here.
|
||||
//
|
||||
// See [aghtls.SystemRootCAs].
|
||||
upstream.RootCAs = s.conf.TLSv12Roots
|
||||
|
||||
// See util.InitTLSCiphers -- removed unsafe ciphers
|
||||
if len(s.conf.TLSCiphers) > 0 {
|
||||
upstream.CipherSuites = s.conf.TLSCiphers
|
||||
}
|
||||
upstream.CipherSuites = s.conf.TLSCiphers
|
||||
|
||||
// Load upstreams either from the file, or from the settings
|
||||
var upstreams []string
|
||||
@@ -451,7 +466,7 @@ func (s *Server) prepareIpsetListSettings() (err error) {
|
||||
}
|
||||
|
||||
// prepareTLS - prepares TLS configuration for the DNS proxy
|
||||
func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
|
||||
func (s *Server) prepareTLS(proxyConfig *proxy.Config) (err error) {
|
||||
if len(s.conf.CertificateChainData) == 0 || len(s.conf.PrivateKeyData) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -470,31 +485,32 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
|
||||
proxyConfig.QUICListenAddr,
|
||||
)
|
||||
|
||||
var err error
|
||||
s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse TLS keypair: %w", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(s.conf.cert.Certificate[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("x509.ParseCertificate(): %w", err)
|
||||
}
|
||||
|
||||
s.conf.hasIPAddrs = aghtls.CertificateHasIP(cert)
|
||||
|
||||
if s.conf.StrictSNICheck {
|
||||
var x *x509.Certificate
|
||||
x, err = x509.ParseCertificate(s.conf.cert.Certificate[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("x509.ParseCertificate(): %w", err)
|
||||
}
|
||||
if len(x.DNSNames) != 0 {
|
||||
s.conf.dnsNames = x.DNSNames
|
||||
log.Debug("dns: using DNS names from certificate's SAN: %v", x.DNSNames)
|
||||
if len(cert.DNSNames) != 0 {
|
||||
s.conf.dnsNames = cert.DNSNames
|
||||
log.Debug("dnsforward: using certificate's SAN as DNS names: %v", cert.DNSNames)
|
||||
sort.Strings(s.conf.dnsNames)
|
||||
} else {
|
||||
s.conf.dnsNames = append(s.conf.dnsNames, x.Subject.CommonName)
|
||||
log.Debug("dns: using DNS name from certificate's CN: %s", x.Subject.CommonName)
|
||||
s.conf.dnsNames = append(s.conf.dnsNames, cert.Subject.CommonName)
|
||||
log.Debug("dnsforward: using certificate's CN as DNS name: %s", cert.Subject.CommonName)
|
||||
}
|
||||
}
|
||||
|
||||
proxyConfig.TLSConfig = &tls.Config{
|
||||
GetCertificate: s.onGetCertificate,
|
||||
CipherSuites: aghtls.SaferCipherSuites(),
|
||||
CipherSuites: s.conf.TLSCiphers,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package dnsforward
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -194,7 +195,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) {
|
||||
s.tableHostToIP = t
|
||||
}
|
||||
|
||||
func (s *Server) setTableIPToHost(t *netutil.IPMap) {
|
||||
func (s *Server) setTableIPToHost(t ipToHostTable) {
|
||||
s.tableIPToHostLock.Lock()
|
||||
defer s.tableIPToHostLock.Unlock()
|
||||
|
||||
@@ -202,52 +203,54 @@ func (s *Server) setTableIPToHost(t *netutil.IPMap) {
|
||||
}
|
||||
|
||||
func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
var err error
|
||||
|
||||
add := true
|
||||
switch flags {
|
||||
case dhcpd.LeaseChangedAdded,
|
||||
dhcpd.LeaseChangedAddedStatic,
|
||||
dhcpd.LeaseChangedRemovedStatic:
|
||||
// Go on.
|
||||
case dhcpd.LeaseChangedRemovedAll:
|
||||
add = false
|
||||
s.setTableHostToIP(nil)
|
||||
s.setTableIPToHost(nil)
|
||||
|
||||
return
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
var hostToIP hostToIPTable
|
||||
var ipToHost *netutil.IPMap
|
||||
if add {
|
||||
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
hostToIP := make(hostToIPTable, len(ll))
|
||||
ipToHost := make(ipToHostTable, len(ll))
|
||||
|
||||
hostToIP = make(hostToIPTable, len(ll))
|
||||
ipToHost = netutil.NewIPMap(len(ll))
|
||||
for _, l := range ll {
|
||||
// TODO(a.garipov): Remove this after we're finished with the client
|
||||
// hostname validations in the DHCP server code.
|
||||
err := netutil.ValidateDomainName(l.Hostname)
|
||||
if err != nil {
|
||||
log.Debug("dnsforward: skipping invalid hostname %q from dhcp: %s", l.Hostname, err)
|
||||
|
||||
for _, l := range ll {
|
||||
// TODO(a.garipov): Remove this after we're finished with the client
|
||||
// hostname validations in the DHCP server code.
|
||||
err = netutil.ValidateDomainName(l.Hostname)
|
||||
if err != nil {
|
||||
log.Debug(
|
||||
"dns: skipping invalid hostname %q from dhcp: %s",
|
||||
l.Hostname,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
|
||||
ip := netutil.CloneIP(l.IP)
|
||||
|
||||
ipToHost.Set(ip, lowhost)
|
||||
hostToIP[lowhost] = ip
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("dns: added %d A/PTR entries from DHCP", ipToHost.Len())
|
||||
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
|
||||
|
||||
// Assume that we only process IPv4 now.
|
||||
//
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ip, err := netutil.IPToAddr(l.IP, netutil.AddrFamilyIPv4)
|
||||
if err != nil {
|
||||
log.Debug("dnsforward: skipping invalid ip %v from dhcp: %s", l.IP, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ipToHost[ip] = lowhost
|
||||
hostToIP[lowhost] = ip
|
||||
}
|
||||
|
||||
s.setTableHostToIP(hostToIP)
|
||||
s.setTableIPToHost(ipToHost)
|
||||
|
||||
log.Debug("dnsforward: added %d a and ptr entries from dhcp", len(ipToHost))
|
||||
}
|
||||
|
||||
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB
|
||||
@@ -256,21 +259,13 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
//
|
||||
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
|
||||
func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
|
||||
pctx := dctx.proxyCtx
|
||||
q := pctx.Req.Question[0]
|
||||
|
||||
if !s.conf.HandleDDR {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
q := pctx.Req.Question[0]
|
||||
if q.Name == ddrHostFQDN {
|
||||
if s.dnsProxy.TLSListenAddr == nil && s.conf.HTTPSListenAddrs == nil &&
|
||||
s.dnsProxy.QUICListenAddr == nil || q.Qtype != dns.TypeSVCB {
|
||||
pctx.Res = s.makeResponse(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
pctx.Res = s.makeDDRResponse(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
@@ -288,6 +283,10 @@ func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
|
||||
// [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
|
||||
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.makeResponse(req)
|
||||
if req.Question[0].Qtype != dns.TypeSVCB {
|
||||
return resp
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Think about storing the FQDN version of the server's
|
||||
// name somewhere.
|
||||
domainName := dns.Fqdn(s.conf.ServerName)
|
||||
@@ -309,20 +308,26 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
for _, addr := range s.dnsProxy.TLSListenAddr {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"dot"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
}
|
||||
if s.conf.hasIPAddrs {
|
||||
// Only add DNS-over-TLS resolvers in case the certificate contains IP
|
||||
// addresses.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4927.
|
||||
for _, addr := range s.dnsProxy.TLSListenAddr {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"dot"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
}
|
||||
|
||||
ans := &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: 1,
|
||||
Target: domainName,
|
||||
Value: values,
|
||||
}
|
||||
ans := &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: 1,
|
||||
Target: domainName,
|
||||
Value: values,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
}
|
||||
|
||||
for _, addr := range s.dnsProxy.QUICListenAddr {
|
||||
@@ -362,24 +367,13 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
// dhcpHostToIP tries to get an IP leased by DHCP and returns the copy of
|
||||
// address since the data inside the internal table may be changed while request
|
||||
// processing. It's safe for concurrent use.
|
||||
func (s *Server) dhcpHostToIP(host string) (ip net.IP, ok bool) {
|
||||
func (s *Server) dhcpHostToIP(host string) (ip netip.Addr, ok bool) {
|
||||
s.tableHostToIPLock.Lock()
|
||||
defer s.tableHostToIPLock.Unlock()
|
||||
|
||||
if s.tableHostToIP == nil {
|
||||
return nil, false
|
||||
}
|
||||
ip, ok = s.tableHostToIP[host]
|
||||
|
||||
var ipFromTable net.IP
|
||||
ipFromTable, ok = s.tableHostToIP[host]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ip = make(net.IP, len(ipFromTable))
|
||||
copy(ip, ipFromTable)
|
||||
|
||||
return ip, true
|
||||
return ip, ok
|
||||
}
|
||||
|
||||
// processDHCPHosts respond to A requests if the target hostname is known to
|
||||
@@ -396,7 +390,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
if !dctx.isLocalClient {
|
||||
log.Debug("dns: %q requests for dhcp host %q", pctx.Addr, reqHost)
|
||||
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, reqHost)
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
|
||||
// Do not even put into query log.
|
||||
@@ -407,18 +401,18 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
if !ok {
|
||||
// Go on and process them with filters, including dnsrewrite ones, and
|
||||
// possibly route them to a domain-specific upstream.
|
||||
log.Debug("dns: no dhcp record for %q", reqHost)
|
||||
log.Debug("dnsforward: no dhcp record for %q", reqHost)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dns: dhcp record for %q is %s", reqHost, ip)
|
||||
log.Debug("dnsforward: dhcp record for %q is %s", reqHost, ip)
|
||||
|
||||
resp := s.makeResponse(req)
|
||||
if q.Qtype == dns.TypeA {
|
||||
a := &dns.A{
|
||||
Hdr: s.hdr(req, dns.TypeA),
|
||||
A: ip,
|
||||
A: ip.AsSlice(),
|
||||
}
|
||||
resp.Answer = append(resp.Answer, a)
|
||||
}
|
||||
@@ -440,7 +434,7 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
ip, err := netutil.IPFromReversedAddr(q.Name)
|
||||
if err != nil {
|
||||
log.Debug("dns: parsing reversed addr: %s", err)
|
||||
log.Debug("dnsforward: parsing reversed addr: %s", err)
|
||||
|
||||
// DNS-Based Service Discovery uses PTR records having not an ARPA
|
||||
// format of the domain name in question. Those shouldn't be
|
||||
@@ -448,12 +442,12 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
// RFC 2782.
|
||||
name := strings.TrimSuffix(q.Name, ".")
|
||||
if err = netutil.ValidateSRVDomainName(name); err != nil {
|
||||
log.Debug("dns: validating service domain: %s", err)
|
||||
log.Debug("dnsforward: validating service domain: %s", err)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
log.Debug("dns: request is for a service domain")
|
||||
log.Debug("dnsforward: request is for a service domain")
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
@@ -462,13 +456,13 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
// assume that all the DHCP leases we give are locally-served or at least
|
||||
// don't need to be accessible externally.
|
||||
if !s.privateNets.Contains(ip) {
|
||||
log.Debug("dns: addr %s is not from locally-served network", ip)
|
||||
log.Debug("dnsforward: addr %s is not from locally-served network", ip)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
if !dctx.isLocalClient {
|
||||
log.Debug("dns: %q requests an internal ip", pctx.Addr)
|
||||
log.Debug("dnsforward: %q requests an internal ip", pctx.Addr)
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
|
||||
// Do not even put into query log.
|
||||
@@ -492,27 +486,13 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
// ipToDHCPHost tries to get a hostname leased by DHCP. It's safe for
|
||||
// concurrent use.
|
||||
func (s *Server) ipToDHCPHost(ip net.IP) (host string, ok bool) {
|
||||
func (s *Server) ipToDHCPHost(ip netip.Addr) (host string, ok bool) {
|
||||
s.tableIPToHostLock.Lock()
|
||||
defer s.tableIPToHostLock.Unlock()
|
||||
|
||||
if s.tableIPToHost == nil {
|
||||
return "", false
|
||||
}
|
||||
host, ok = s.tableIPToHost[ip]
|
||||
|
||||
var v any
|
||||
v, ok = s.tableIPToHost.Get(ip)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if host, ok = v.(string); !ok {
|
||||
log.Error("dns: bad type %T in tableIPToHost for %s", v, ip)
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
return host, true
|
||||
return host, ok
|
||||
}
|
||||
|
||||
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
|
||||
@@ -528,12 +508,20 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
host, ok := s.ipToDHCPHost(ip)
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
||||
if err != nil {
|
||||
log.Debug("dnsforward: bad reverse ip %v from dhcp: %s", ip, err)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
host, ok := s.ipToDHCPHost(ipAddr)
|
||||
if !ok {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dns: dhcp reverse record for %s is %q", ip, host)
|
||||
log.Debug("dnsforward: dhcp reverse record for %s is %q", ip, host)
|
||||
|
||||
req := pctx.Req
|
||||
resp := s.makeResponse(req)
|
||||
@@ -638,7 +626,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||
//
|
||||
// TODO(a.garipov): Route such queries to a custom upstream for the
|
||||
// local domain name if there is one.
|
||||
log.Debug("dns: dhcp client hostname %q was not filtered", reqHost)
|
||||
log.Debug("dnsforward: dhcp client hostname %q was not filtered", reqHost)
|
||||
pctx.Res = s.genNXDomain(req)
|
||||
|
||||
return resultCodeFinish
|
||||
@@ -711,13 +699,13 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
id := stringutil.Coalesce(clientID, ipStringFromAddr(pctx.Addr))
|
||||
upsConf, err := customUpsByClient(id)
|
||||
if err != nil {
|
||||
log.Error("dns: getting custom upstreams for client %s: %s", id, err)
|
||||
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if upsConf != nil {
|
||||
log.Debug("dns: using custom upstreams for client %s", id)
|
||||
log.Debug("dnsforward: using custom upstreams for client %s", id)
|
||||
}
|
||||
|
||||
pctx.CustomUpstreamConfig = upsConf
|
||||
|
||||
@@ -2,13 +2,16 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -154,19 +157,9 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
|
||||
t.Helper()
|
||||
|
||||
proxyConf := proxy.Config{}
|
||||
|
||||
if portDoT > 0 {
|
||||
proxyConf.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
|
||||
}
|
||||
|
||||
if portDoQ > 0 {
|
||||
proxyConf.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
|
||||
}
|
||||
|
||||
s = &Server{
|
||||
dnsProxy: &proxy.Proxy{
|
||||
Config: proxyConf,
|
||||
Config: proxy.Config{},
|
||||
},
|
||||
conf: ServerConfig{
|
||||
FilteringConfig: FilteringConfig{
|
||||
@@ -178,8 +171,17 @@ func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled b
|
||||
},
|
||||
}
|
||||
|
||||
if portDoT > 0 {
|
||||
s.dnsProxy.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
|
||||
s.conf.hasIPAddrs = true
|
||||
}
|
||||
|
||||
if portDoQ > 0 {
|
||||
s.dnsProxy.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
|
||||
}
|
||||
|
||||
if portDoH > 0 {
|
||||
s.conf.TLSConfig.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
|
||||
s.conf.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -230,12 +232,11 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
|
||||
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
wantIP net.IP
|
||||
wantIP netip.Addr
|
||||
wantRes resultCode
|
||||
isLocalCli bool
|
||||
}{{
|
||||
@@ -247,19 +248,19 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
}, {
|
||||
name: "local_client_unknown_host",
|
||||
host: "wronghost.lan",
|
||||
wantIP: nil,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
isLocalCli: true,
|
||||
}, {
|
||||
name: "external_client_known_host",
|
||||
host: "example.lan",
|
||||
wantIP: nil,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeFinish,
|
||||
isLocalCli: false,
|
||||
}, {
|
||||
name: "external_client_unknown_host",
|
||||
host: "wronghost.lan",
|
||||
wantIP: nil,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeFinish,
|
||||
isLocalCli: false,
|
||||
}}
|
||||
@@ -304,7 +305,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if tc.wantIP == nil {
|
||||
if tc.wantIP == (netip.Addr{}) {
|
||||
assert.Nil(t, pctx.Res)
|
||||
} else {
|
||||
require.NotNil(t, pctx.Res)
|
||||
@@ -312,7 +313,12 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
ans := pctx.Res.Answer
|
||||
require.Len(t, ans, 1)
|
||||
|
||||
assert.Equal(t, tc.wantIP, ans[0].(*dns.A).A)
|
||||
a := testutil.RequireTypeAssert[*dns.A](t, ans[0])
|
||||
|
||||
ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantIP, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -324,26 +330,26 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
examplelan = "example." + defaultLocalDomainSuffix
|
||||
)
|
||||
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
suffix string
|
||||
wantIP net.IP
|
||||
wantIP netip.Addr
|
||||
wantRes resultCode
|
||||
qtyp uint16
|
||||
}{{
|
||||
name: "success_external",
|
||||
host: examplecom,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: nil,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "success_external_non_a",
|
||||
host: examplecom,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: nil,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeCNAME,
|
||||
}, {
|
||||
@@ -357,14 +363,14 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
name: "success_internal_unknown",
|
||||
host: "example-new.lan",
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: nil,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "success_internal_aaaa",
|
||||
host: examplelan,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: nil,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
@@ -423,7 +429,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
|
||||
ans := pctx.Res.Answer
|
||||
require.Len(t, ans, 0)
|
||||
} else if tc.wantIP == nil {
|
||||
} else if tc.wantIP == (netip.Addr{}) {
|
||||
assert.Nil(t, pctx.Res)
|
||||
} else {
|
||||
require.NotNil(t, pctx.Res)
|
||||
@@ -431,19 +437,33 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
ans := pctx.Res.Answer
|
||||
require.Len(t, ans, 1)
|
||||
|
||||
assert.Equal(t, tc.wantIP, ans[0].(*dns.A).A)
|
||||
a := testutil.RequireTypeAssert[*dns.A](t, ans[0])
|
||||
|
||||
ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantIP, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
ups := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"251.252.253.254.in-addr.arpa.": {"host1.example.net."},
|
||||
"1.1.168.192.in-addr.arpa.": {"some.local-client."},
|
||||
},
|
||||
}
|
||||
const (
|
||||
extPTRQuestion = "251.252.253.254.in-addr.arpa."
|
||||
extPTRAnswer = "host1.example.net."
|
||||
intPTRQuestion = "1.1.168.192.in-addr.arpa."
|
||||
intPTRAnswer = "some.local-client."
|
||||
)
|
||||
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
})
|
||||
|
||||
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
@@ -513,14 +533,20 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
const locDomain = "some.local."
|
||||
const reqAddr = "1.1.168.192.in-addr.arpa."
|
||||
|
||||
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
}, &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
reqAddr: {locDomain},
|
||||
s := createTestServer(
|
||||
t,
|
||||
&filtering.Config{},
|
||||
ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
},
|
||||
})
|
||||
aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
}),
|
||||
)
|
||||
|
||||
var proxyCtx *proxy.DNSContext
|
||||
var dnsCtx *dnsContext
|
||||
|
||||
@@ -5,11 +5,13 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -43,8 +45,13 @@ var defaultBlockedHosts = []string{"version.bind", "id.server", "hostname.bind"}
|
||||
|
||||
var webRegistered bool
|
||||
|
||||
// hostToIPTable is an alias for the type of Server.tableHostToIP.
|
||||
type hostToIPTable = map[string]net.IP
|
||||
// hostToIPTable is a convenient type alias for tables of host names to an IP
|
||||
// address.
|
||||
type hostToIPTable = map[string]netip.Addr
|
||||
|
||||
// ipToHostTable is a convenient type alias for tables of IP addresses to their
|
||||
// host names. For example, for use with PTR queries.
|
||||
type ipToHostTable = map[netip.Addr]string
|
||||
|
||||
// Server is the main way to start a DNS server.
|
||||
//
|
||||
@@ -63,7 +70,7 @@ type Server struct {
|
||||
dhcpServer dhcpd.Interface // DHCP server instance (optional)
|
||||
queryLog querylog.QueryLog // Query log instance
|
||||
stats stats.Interface
|
||||
access *accessCtx
|
||||
access *accessManager
|
||||
|
||||
// localDomainSuffix is the suffix used to detect internal hosts. It
|
||||
// must be a valid domain name plus dots on each side.
|
||||
@@ -81,7 +88,7 @@ type Server struct {
|
||||
tableHostToIP hostToIPTable
|
||||
tableHostToIPLock sync.Mutex
|
||||
|
||||
tableIPToHost *netutil.IPMap
|
||||
tableIPToHost ipToHostTable
|
||||
tableIPToHostLock sync.Mutex
|
||||
|
||||
// clientIDCache is a temporary storage for ClientIDs that were extracted
|
||||
@@ -517,7 +524,7 @@ func validateBlockingMode(mode BlockingMode, blockingIPv4, blockingIPv6 net.IP)
|
||||
}
|
||||
|
||||
// prepareInternalProxy initializes the DNS proxy that is used for internal DNS
|
||||
// queries, such at client PTR resolving and updater hostname resolving.
|
||||
// queries, such as public clients PTR resolving and updater hostname resolving.
|
||||
func (s *Server) prepareInternalProxy() (err error) {
|
||||
conf := &proxy.Config{
|
||||
CacheEnabled: true,
|
||||
@@ -557,16 +564,49 @@ func (s *Server) Stop() error {
|
||||
return s.stopLocked()
|
||||
}
|
||||
|
||||
// stopLocked stops the DNS server without locking. For internal use only.
|
||||
func (s *Server) stopLocked() error {
|
||||
// stopLocked stops the DNS server without locking. For internal use only.
|
||||
func (s *Server) stopLocked() (err error) {
|
||||
if s.dnsProxy != nil {
|
||||
err := s.dnsProxy.Stop()
|
||||
err = s.dnsProxy.Stop()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not stop the DNS server properly: %w", err)
|
||||
return fmt.Errorf("closing primary resolvers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.isRunning = false
|
||||
var errs []error
|
||||
|
||||
if upsConf := s.internalProxy.UpstreamConfig; upsConf != nil {
|
||||
const action = "closing internal resolvers"
|
||||
|
||||
err = upsConf.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
log.Debug("dnsforward: %s: %s", action, err)
|
||||
} else {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", action, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if upsConf := s.localResolvers.UpstreamConfig; upsConf != nil {
|
||||
const action = "closing local resolvers"
|
||||
|
||||
err = upsConf.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
log.Debug("dnsforward: %s: %s", action, err)
|
||||
} else {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", action, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errors.List("stopping dns server", errs...)
|
||||
} else {
|
||||
s.isRunning = false
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -639,27 +679,35 @@ func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
blockedByIP := false
|
||||
if ip != nil {
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
||||
if err != nil {
|
||||
log.Error("dnsforward: bad client ip %v: %s", ip, err)
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
blockedByIP, rule = s.access.isBlockedIP(ipAddr)
|
||||
}
|
||||
|
||||
allowlistMode := s.access.allowlistMode()
|
||||
blockedByIP, rule := s.access.isBlockedIP(ip)
|
||||
blockedByClientID := s.access.isBlockedClientID(clientID)
|
||||
|
||||
// Allow if at least one of the checks allows in allowlist mode, but
|
||||
// block if at least one of the checks blocks in blocklist mode.
|
||||
// Allow if at least one of the checks allows in allowlist mode, but block
|
||||
// if at least one of the checks blocks in blocklist mode.
|
||||
if allowlistMode && blockedByIP && blockedByClientID {
|
||||
log.Debug("client %s (id %q) is not in access allowlist", ip, clientID)
|
||||
log.Debug("client %v (id %q) is not in access allowlist", ip, clientID)
|
||||
|
||||
// Return now without substituting the empty rule for the
|
||||
// clientID because the rule can't be empty here.
|
||||
return true, rule
|
||||
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
|
||||
log.Debug("client %s (id %q) is in access blocklist", ip, clientID)
|
||||
log.Debug("client %v (id %q) is in access blocklist", ip, clientID)
|
||||
|
||||
blocked = true
|
||||
}
|
||||
|
||||
if rule == "" {
|
||||
rule = clientID
|
||||
}
|
||||
|
||||
return blocked, rule
|
||||
return blocked, aghalg.Coalesce(rule, clientID)
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ import (
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -161,8 +161,23 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
|
||||
return s, certPem
|
||||
}
|
||||
|
||||
const googleDomainName = "google-public-dns-a.google.com."
|
||||
|
||||
func createGoogleATestMessage() *dns.Msg {
|
||||
return createTestMessage("google-public-dns-a.google.com.")
|
||||
return createTestMessage(googleDomainName)
|
||||
}
|
||||
|
||||
func newGoogleUpstream() (u upstream.Upstream) {
|
||||
return &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "google.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypeA, googleDomainName, "8.8.8.8"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
},
|
||||
OnClose: func() (err error) { return nil },
|
||||
}
|
||||
}
|
||||
|
||||
func createTestMessage(host string) *dns.Msg {
|
||||
@@ -247,13 +262,7 @@ func TestServer(t *testing.T) {
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
}, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
startDeferStop(t, s)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -320,13 +329,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
}, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
startDeferStop(t, s)
|
||||
|
||||
// Message over UDP.
|
||||
@@ -343,13 +346,7 @@ func TestDoTServer(t *testing.T) {
|
||||
s, certPem := createTestTLS(t, TLSConfig{
|
||||
TLSListenAddrs: []*net.TCPAddr{{}},
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
startDeferStop(t, s)
|
||||
|
||||
// Add our self-signed generated config to roots.
|
||||
@@ -373,13 +370,7 @@ func TestDoQServer(t *testing.T) {
|
||||
s, _ := createTestTLS(t, TLSConfig{
|
||||
QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}},
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
startDeferStop(t, s)
|
||||
|
||||
// Create a DNS-over-QUIC upstream.
|
||||
@@ -417,13 +408,7 @@ func TestServerRace(t *testing.T) {
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
startDeferStop(t, s)
|
||||
|
||||
// Message over UDP.
|
||||
@@ -557,11 +542,12 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
|
||||
ups := &aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"host.": {{192, 168, 0, 1}},
|
||||
},
|
||||
}
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
})
|
||||
|
||||
return &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{ups},
|
||||
@@ -604,7 +590,6 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
testUpstm := &aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
IPv4: testIPv4,
|
||||
IPv6: nil,
|
||||
}
|
||||
s.conf.ProtectionEnabled = false
|
||||
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
|
||||
@@ -931,16 +916,13 @@ func TestRewrite(t *testing.T) {
|
||||
},
|
||||
}))
|
||||
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
CName: map[string][]string{
|
||||
"example.org": {"somename"},
|
||||
},
|
||||
IPv4: map[string][]net.IP{
|
||||
"example.org.": {{4, 3, 2, 1}},
|
||||
},
|
||||
},
|
||||
}
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypeA, "example.org", "4.3.2.1"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups}
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
@@ -1061,11 +1043,12 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
|
||||
require.Len(t, resp.Answer, 1)
|
||||
|
||||
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
||||
assert.Equal(t, "34.12.168.192.in-addr.arpa.", resp.Answer[0].Header().Name)
|
||||
ans := resp.Answer[0]
|
||||
assert.Equal(t, dns.TypePTR, ans.Header().Rrtype)
|
||||
assert.Equal(t, "34.12.168.192.in-addr.arpa.", ans.Header().Name)
|
||||
|
||||
ptr := testutil.RequireTypeAssert[*dns.PTR](t, ans)
|
||||
|
||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, dns.Fqdn("myhost."+localDomain), ptr.Ptr)
|
||||
}
|
||||
|
||||
@@ -1211,12 +1194,10 @@ func TestServer_Exchange(t *testing.T) {
|
||||
extUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "external.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revExtIPv4, onesHost),
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, revExtIPv4, onesHost),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
), nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1226,12 +1207,10 @@ func TestServer_Exchange(t *testing.T) {
|
||||
locUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "local.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revLocIPv4, localDomainHost),
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, localDomainHost),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
), nil
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -20,19 +20,13 @@ func (s *Server) filterDNSRewriteResponse(
|
||||
v rules.RRValue,
|
||||
) (ans dns.RR, err error) {
|
||||
switch rr {
|
||||
case
|
||||
dns.TypeA,
|
||||
dns.TypeAAAA:
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
return s.ansFromDNSRewriteIP(v, rr, req)
|
||||
case
|
||||
dns.TypePTR,
|
||||
dns.TypeTXT:
|
||||
case dns.TypePTR, dns.TypeTXT:
|
||||
return s.ansFromDNSRewriteText(v, rr, req)
|
||||
case dns.TypeMX:
|
||||
return s.ansFromDNSRewriteMX(v, rr, req)
|
||||
case
|
||||
dns.TypeHTTPS,
|
||||
dns.TypeSVCB:
|
||||
case dns.TypeHTTPS, dns.TypeSVCB:
|
||||
return s.ansFromDNSRewriteSVCB(v, rr, req)
|
||||
case dns.TypeSRV:
|
||||
return s.ansFromDNSRewriteSRV(v, rr, req)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -435,22 +436,22 @@ func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
||||
// TODO(e.burkov): Validate the domain name.
|
||||
for _, proto := range protocols {
|
||||
if strings.HasPrefix(u, proto) {
|
||||
return useDefault, nil
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(u, "://") {
|
||||
return useDefault, errors.Error("wrong protocol")
|
||||
if proto, _, ok := strings.Cut(u, "://"); ok {
|
||||
return false, fmt.Errorf("bad protocol %q", proto)
|
||||
}
|
||||
|
||||
// Check if upstream is either an IP or IP with port.
|
||||
if net.ParseIP(u) != nil {
|
||||
return useDefault, nil
|
||||
} else if _, err = netutil.ParseIPPort(u); err != nil {
|
||||
return useDefault, err
|
||||
if _, err = netip.ParseAddr(u); err == nil {
|
||||
return false, nil
|
||||
} else if _, err = netip.ParseAddrPort(u); err == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return useDefault, nil
|
||||
return false, err
|
||||
}
|
||||
|
||||
// separateUpstream returns the upstream and the specified domains. domains is
|
||||
@@ -603,6 +604,7 @@ func checkDNS(
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||
|
||||
if err = healthCheck(u); err != nil {
|
||||
err = fmt.Errorf("upstream %q fails to exchange: %w", upstreamAddr, err)
|
||||
|
||||
@@ -188,10 +188,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
name: "upstream_mode_fastest_addr",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "upstream_dns_bad",
|
||||
wantSet: `validating upstream servers: ` +
|
||||
`validating upstream "!!!": bad ipport address "!!!": ` +
|
||||
`address !!!: missing port in address`,
|
||||
name: "upstream_dns_bad",
|
||||
wantSet: `validating upstream servers: validating upstream "!!!": not an ip:port`,
|
||||
}, {
|
||||
name: "bootstraps_bad",
|
||||
wantSet: `checking bootstrap a: invalid address: ` +
|
||||
@@ -297,15 +295,15 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "dhcp://fake.dns": wrong protocol`,
|
||||
wantErr: `validating upstream "dhcp://fake.dns": bad protocol "dhcp"`,
|
||||
set: []string{"dhcp://fake.dns"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "1.2.3.4.5": bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
|
||||
wantErr: `validating upstream "1.2.3.4.5": not an ip:port`,
|
||||
set: []string{"1.2.3.4.5"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "123.3.7m": bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
|
||||
wantErr: `validating upstream "123.3.7m": not an ip:port`,
|
||||
set: []string{"123.3.7m"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
@@ -313,7 +311,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "[host.ru]#": bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
|
||||
wantErr: `validating upstream "[host.ru]#": not an ip:port`,
|
||||
set: []string{"[host.ru]#"},
|
||||
}, {
|
||||
name: "valid_default",
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Write Stats data and logs
|
||||
@@ -28,7 +29,7 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
||||
ip = netutil.CloneIP(ip)
|
||||
ip = slices.Clone(ip)
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
Reference in New Issue
Block a user