Pull request: all: allow clientid in access settings

Updates #2624.
Updates #3162.

Squashed commit of the following:

commit 68860da717a23a0bfeba14b7fe10b5e4ad38726d
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 15:41:33 2021 +0300

    all: imp types, names

commit ebd4ec26636853d0d58c4e331e6a78feede20813
Merge: 239eb721 16e5e09c
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 15:14:33 2021 +0300

    Merge branch 'master' into 2624-clientid-access

commit 239eb7215abc47e99a0300a0f4cf56002689b1a9
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 15:13:10 2021 +0300

    all: fix client blocking check

commit e6bece3ea8367b3cbe3d90702a3368c870ad4f13
Merge: 9935f2a3 9d1656b5
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jun 29 13:12:28 2021 +0300

    Merge branch 'master' into 2624-clientid-access

commit 9935f2a30bcfae2b853f3ef610c0ab7a56a8f448
Author: Ildar Kamalov <ik@adguard.com>
Date:   Tue Jun 29 11:26:51 2021 +0300

    client: show block button for client id

commit ed786a6a74a081cd89e9d67df3537a4fadd54831
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Jun 25 15:56:23 2021 +0300

    client: imp i18n

commit 4fed21c68473ad408960c08a7d87624cabce1911
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Jun 25 15:34:09 2021 +0300

    all: imp i18n, docs

commit 55e65c0d6b939560c53dcb834a4557eb3853d194
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Jun 25 13:34:01 2021 +0300

    all: fix cache, imp code, docs, tests

commit c1e5a83e76deb44b1f92729bb9ddfcc6a96ac4a8
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Jun 24 19:27:12 2021 +0300

    all: allow clientid in access settings
This commit is contained in:
Ainar Garipov
2021-06-29 15:53:28 +03:00
parent 16e5e09c2e
commit e08a64ebe4
33 changed files with 955 additions and 604 deletions

View File

@@ -6,138 +6,163 @@ import (
"net"
"net/http"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist"
)
// accessCtx controls IP and client blocking that takes place before all other
// processing. An accessCtx is safe for concurrent use.
type accessCtx struct {
lock sync.Mutex
allowedIPs *aghnet.IPMap
blockedIPs *aghnet.IPMap
// allowedClients are the IP addresses of clients in the allowlist.
allowedClients *aghstrings.Set
allowedClientIDs *aghstrings.Set
blockedClientIDs *aghstrings.Set
// disallowedClients are the IP addresses of clients in the blocklist.
disallowedClients *aghstrings.Set
blockedHostsEng *urlfilter.DNSEngine
allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients
disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked
blockedHostsEngine *urlfilter.DNSEngine // finds hosts that should be blocked
// TODO(a.garipov): Create a type for a set of IP networks.
// aghnet.IPNetSet?
allowedNets []*net.IPNet
blockedNets []*net.IPNet
}
func newAccessCtx(allowedClients, disallowedClients, blockedHosts []string) (a *accessCtx, err error) {
a = &accessCtx{
allowedClients: aghstrings.NewSet(),
disallowedClients: aghstrings.NewSet(),
}
// unit is a convenient alias for struct{}
type unit = struct{}
err = processIPCIDRArray(a.allowedClients, &a.allowedClientsIPNet, allowedClients)
if err != nil {
return nil, fmt.Errorf("processing allowed clients: %w", err)
}
// processAccessClients is a helper for processing a list of client strings,
// which may be an IP address, a CIDR, or a ClientID.
func processAccessClients(
clientStrs []string,
ips *aghnet.IPMap,
nets *[]*net.IPNet,
clientIDs *aghstrings.Set,
) (err error) {
for i, s := range clientStrs {
if ip := net.ParseIP(s); ip != nil {
ips.Set(ip, unit{})
} else if cidrIP, ipnet, cidrErr := net.ParseCIDR(s); cidrErr == nil {
ipnet.IP = cidrIP
*nets = append(*nets, ipnet)
} else {
idErr := ValidateClientID(s)
if idErr != nil {
return fmt.Errorf(
"value %q at index %d: bad ip, cidr, or clientid",
s,
i,
)
}
err = processIPCIDRArray(a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients)
if err != nil {
return nil, fmt.Errorf("processing disallowed clients: %w", err)
}
b := &strings.Builder{}
for _, s := range blockedHosts {
aghstrings.WriteToBuilder(b, strings.ToLower(s), "\n")
}
listArray := []filterlist.RuleList{}
list := &filterlist.StringRuleList{
ID: int(0),
RulesText: b.String(),
IgnoreCosmetic: true,
}
listArray = append(listArray, list)
rulesStorage, err := filterlist.NewRuleStorage(listArray)
if err != nil {
return nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err)
}
a.blockedHostsEngine = urlfilter.NewDNSEngine(rulesStorage)
return a, nil
}
// Split array of IP or CIDR into 2 containers for fast search
func processIPCIDRArray(dst *aghstrings.Set, dstIPNet *[]net.IPNet, src []string) error {
for _, s := range src {
ip := net.ParseIP(s)
if ip != nil {
dst.Add(s)
continue
clientIDs.Add(s)
}
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
return err
}
*dstIPNet = append(*dstIPNet, *ipnet)
}
return nil
}
// IsBlockedIP - return TRUE if this client should be blocked
// Returns the item from the "disallowedClients" list that lead to blocking IP.
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty,
// but the ip does not belong to it.
func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) {
ipStr := ip.String()
// newAccessCtx creates a new accessCtx.
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) {
a = &accessCtx{
allowedIPs: aghnet.NewIPMap(0),
blockedIPs: aghnet.NewIPMap(0),
a.lock.Lock()
defer a.lock.Unlock()
if a.allowedClients.Len() != 0 || len(a.allowedClientsIPNet) != 0 {
if a.allowedClients.Has(ipStr) {
return false, ""
}
if len(a.allowedClientsIPNet) != 0 {
for _, ipnet := range a.allowedClientsIPNet {
if ipnet.Contains(ip) {
return false, ""
}
}
}
return true, ""
allowedClientIDs: aghstrings.NewSet(),
blockedClientIDs: aghstrings.NewSet(),
}
if a.disallowedClients.Has(ipStr) {
return true, ipStr
err = processAccessClients(allowed, a.allowedIPs, &a.allowedNets, a.allowedClientIDs)
if err != nil {
return nil, fmt.Errorf("adding allowed: %w", err)
}
if len(a.disallowedClientsIPNet) != 0 {
for _, ipnet := range a.disallowedClientsIPNet {
if ipnet.Contains(ip) {
return true, ipnet.String()
}
}
err = processAccessClients(blocked, a.blockedIPs, &a.blockedNets, a.blockedClientIDs)
if err != nil {
return nil, fmt.Errorf("adding blocked: %w", err)
}
return false, ""
b := &strings.Builder{}
for _, h := range blockedHosts {
aghstrings.WriteToBuilder(b, strings.ToLower(h), "\n")
}
lists := []filterlist.RuleList{
&filterlist.StringRuleList{
ID: int(0),
RulesText: b.String(),
IgnoreCosmetic: true,
},
}
rulesStrg, err := filterlist.NewRuleStorage(lists)
if err != nil {
return nil, fmt.Errorf("adding blocked hosts: %w", err)
}
a.blockedHostsEng = urlfilter.NewDNSEngine(rulesStrg)
return a, nil
}
// IsBlockedDomain - return TRUE if this domain should be blocked
func (a *accessCtx) IsBlockedDomain(host string) (ok bool) {
a.lock.Lock()
defer a.lock.Unlock()
// allowlistMode returns true if this *accessCtx is in the allowlist mode.
func (a *accessCtx) allowlistMode() (ok bool) {
return a.allowedIPs.Len() != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
}
_, ok = a.blockedHostsEngine.Match(strings.ToLower(host))
// isBlockedClientID returns true if the ClientID should be blocked.
func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
allowlistMode := a.allowlistMode()
if id == "" {
// In allowlist mode, consider requests without client IDs
// blocked by default.
return allowlistMode
}
if allowlistMode {
return !a.allowedClientIDs.Has(id)
}
return a.blockedClientIDs.Has(id)
}
// isBlockedHost returns true if host should be blocked.
func (a *accessCtx) isBlockedHost(host string) (ok bool) {
_, ok = a.blockedHostsEng.Match(strings.ToLower(host))
return ok
}
// isBlockedIP returns the status of the IP address blocking as well as the rule
// that blocked it.
func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) {
blocked = true
ips := a.blockedIPs
ipnets := a.blockedNets
if a.allowlistMode() {
// Enable allowlist mode and use the allowlist sets.
blocked = false
ips = a.allowedIPs
ipnets = a.allowedNets
}
if _, ok := ips.Get(ip); ok {
return blocked, ip.String()
}
for _, ipnet := range ipnets {
if ipnet.Contains(ip) {
return blocked, ipnet.String()
}
}
return !blocked, ""
}
type accessListJSON struct {
AllowedClients []string `json:"allowed_clients"`
DisallowedClients []string `json:"disallowed_clients"`
@@ -161,62 +186,43 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(j)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err)
return
}
}
func checkIPCIDRArray(src []string) error {
for _, s := range src {
ip := net.ParseIP(s)
if ip != nil {
continue
}
_, _, err := net.ParseCIDR(s)
if err != nil {
return err
}
}
return nil
}
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
j := accessListJSON{}
err := json.NewDecoder(r.Body).Decode(&j)
list := accessListJSON{}
err := json.NewDecoder(r.Body).Decode(&list)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
httpError(r, w, http.StatusBadRequest, "decoding request: %s", err)
err = checkIPCIDRArray(j.AllowedClients)
if err == nil {
err = checkIPCIDRArray(j.DisallowedClients)
}
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
var a *accessCtx
a, err = newAccessCtx(j.AllowedClients, j.DisallowedClients, j.BlockedHosts)
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
if err != nil {
httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
return
}
defer log.Debug("Access: updated lists: %d, %d, %d",
len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts))
defer log.Debug(
"access: updated lists: %d, %d, %d",
len(list.AllowedClients),
len(list.DisallowedClients),
len(list.BlockedHosts),
)
defer s.conf.ConfigModified()
s.serverLock.Lock()
defer s.serverLock.Unlock()
s.conf.AllowedClients = j.AllowedClients
s.conf.DisallowedClients = j.DisallowedClients
s.conf.BlockedHosts = j.BlockedHosts
s.conf.AllowedClients = list.AllowedClients
s.conf.DisallowedClients = list.DisallowedClients
s.conf.BlockedHosts = list.BlockedHosts
s.access = a
}

View File

@@ -8,99 +8,23 @@ import (
"github.com/stretchr/testify/require"
)
func TestIsBlockedIP(t *testing.T) {
const (
ip int = iota
cidr
)
func TestIsBlockedClientID(t *testing.T) {
clientID := "client-1"
clients := []string{clientID}
rules := []string{
ip: "1.1.1.1",
cidr: "2.2.0.0/16",
}
a, err := newAccessCtx(clients, nil, nil)
require.NoError(t, err)
testCases := []struct {
name string
allowed bool
ip net.IP
wantDis bool
wantRule string
}{{
name: "allow_ip",
allowed: true,
ip: net.IPv4(1, 1, 1, 1),
wantDis: false,
wantRule: "",
}, {
name: "disallow_ip",
allowed: true,
ip: net.IPv4(1, 1, 1, 2),
wantDis: true,
wantRule: "",
}, {
name: "allow_cidr",
allowed: true,
ip: net.IPv4(2, 2, 1, 1),
wantDis: false,
wantRule: "",
}, {
name: "disallow_cidr",
allowed: true,
ip: net.IPv4(2, 3, 1, 1),
wantDis: true,
wantRule: "",
}, {
name: "allow_ip",
allowed: false,
ip: net.IPv4(1, 1, 1, 1),
wantDis: true,
wantRule: rules[ip],
}, {
name: "disallow_ip",
allowed: false,
ip: net.IPv4(1, 1, 1, 2),
wantDis: false,
wantRule: "",
}, {
name: "allow_cidr",
allowed: false,
ip: net.IPv4(2, 2, 1, 1),
wantDis: true,
wantRule: rules[cidr],
}, {
name: "disallow_cidr",
allowed: false,
ip: net.IPv4(2, 3, 1, 1),
wantDis: false,
wantRule: "",
}}
assert.False(t, a.isBlockedClientID(clientID))
for _, tc := range testCases {
prefix := "allowed_"
if !tc.allowed {
prefix = "disallowed_"
}
a, err = newAccessCtx(nil, clients, nil)
require.NoError(t, err)
t.Run(prefix+tc.name, func(t *testing.T) {
allowedRules := rules
var disallowedRules []string
if !tc.allowed {
allowedRules, disallowedRules = disallowedRules, allowedRules
}
aCtx, err := newAccessCtx(allowedRules, disallowedRules, nil)
require.NoError(t, err)
disallowed, rule := aCtx.IsBlockedIP(tc.ip)
assert.Equal(t, tc.wantDis, disallowed)
assert.Equal(t, tc.wantRule, rule)
})
}
assert.True(t, a.isBlockedClientID(clientID))
}
func TestIsBlockedDomain(t *testing.T) {
aCtx, err := newAccessCtx(nil, nil, []string{
func TestIsBlockedHost(t *testing.T) {
a, err := newAccessCtx(nil, nil, []string{
"host1",
"*.host.com",
"||host3.com^",
@@ -108,50 +32,106 @@ func TestIsBlockedDomain(t *testing.T) {
require.NoError(t, err)
testCases := []struct {
name string
domain string
want bool
name string
host string
want bool
}{{
name: "plain_match",
domain: "host1",
want: true,
name: "plain_match",
host: "host1",
want: true,
}, {
name: "plain_mismatch",
domain: "host2",
want: false,
name: "plain_mismatch",
host: "host2",
want: false,
}, {
name: "wildcard_type-1_match_short",
domain: "asdf.host.com",
want: true,
name: "subdomain_match_short",
host: "asdf.host.com",
want: true,
}, {
name: "wildcard_type-1_match_long",
domain: "qwer.asdf.host.com",
want: true,
name: "subdomain_match_long",
host: "qwer.asdf.host.com",
want: true,
}, {
name: "wildcard_type-1_mismatch_no-lead",
domain: "host.com",
want: false,
name: "subdomain_mismatch_no_lead",
host: "host.com",
want: false,
}, {
name: "wildcard_type-1_mismatch_bad-asterisk",
domain: "asdf.zhost.com",
want: false,
name: "subdomain_mismatch_bad_asterisk",
host: "asdf.zhost.com",
want: false,
}, {
name: "wildcard_type-2_match_simple",
domain: "host3.com",
want: true,
name: "rule_match_simple",
host: "host3.com",
want: true,
}, {
name: "wildcard_type-2_match_complex",
domain: "asdf.host3.com",
want: true,
name: "rule_match_complex",
host: "asdf.host3.com",
want: true,
}, {
name: "wildcard_type-2_mismatch",
domain: ".host3.com",
want: false,
name: "rule_mismatch",
host: ".host3.com",
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain))
assert.Equal(t, tc.want, a.isBlockedHost(tc.host))
})
}
}
func TestIsBlockedIP(t *testing.T) {
clients := []string{
"1.2.3.4",
"5.6.7.8/24",
}
allowCtx, err := newAccessCtx(clients, nil, nil)
require.NoError(t, err)
blockCtx, err := newAccessCtx(nil, clients, nil)
require.NoError(t, err)
testCases := []struct {
name string
wantRule string
ip net.IP
wantBlocked bool
}{{
name: "match_ip",
wantRule: "1.2.3.4",
ip: net.IP{1, 2, 3, 4},
wantBlocked: true,
}, {
name: "match_cidr",
wantRule: "5.6.7.8/24",
ip: net.IP{5, 6, 7, 100},
wantBlocked: true,
}, {
name: "no_match_ip",
wantRule: "",
ip: net.IP{9, 2, 3, 4},
wantBlocked: false,
}, {
name: "no_match_cidr",
wantRule: "",
ip: net.IP{9, 6, 7, 100},
wantBlocked: false,
}}
t.Run("allow", func(t *testing.T) {
for _, tc := range testCases {
blocked, rule := allowCtx.isBlockedIP(tc.ip)
assert.Equal(t, !tc.wantBlocked, blocked)
assert.Equal(t, tc.wantRule, rule)
}
})
t.Run("block", func(t *testing.T) {
for _, tc := range testCases {
blocked, rule := blockCtx.isBlockedIP(tc.ip)
assert.Equal(t, tc.wantBlocked, blocked)
assert.Equal(t, tc.wantRule, rule)
}
})
}

View File

@@ -2,6 +2,7 @@ package dnsforward
import (
"crypto/tls"
"encoding/binary"
"fmt"
"path"
"strings"
@@ -50,15 +51,15 @@ func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (
return clientID, nil
}
// processClientIDHTTPS extracts the client's ID from the path of the
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
// client's DNS-over-HTTPS request.
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
pctx := ctx.proxyCtx
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
r := pctx.HTTPRequest
if r == nil {
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
return resultCodeError
return "", fmt.Errorf(
"proxy ctx http request of proto %s is nil",
pctx.Proto,
)
}
origPath := r.URL.Path
@@ -68,34 +69,25 @@ func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
}
if len(parts) == 0 || parts[0] != "dns-query" {
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
return resultCodeError
return "", fmt.Errorf("client id check: invalid path %q", origPath)
}
clientID := ""
switch len(parts) {
case 1:
// Just /dns-query, no client ID.
return resultCodeSuccess
return "", nil
case 2:
clientID = parts[1]
default:
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
return resultCodeError
return "", fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
}
err := ValidateClientID(clientID)
err = ValidateClientID(clientID)
if err != nil {
ctx.err = fmt.Errorf("client id check: %w", err)
return resultCodeError
return "", fmt.Errorf("client id check: %w", err)
}
ctx.clientID = clientID
return resultCodeSuccess
return clientID, nil
}
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
@@ -108,53 +100,73 @@ type quicSession interface {
ConnectionState() (cs quic.ConnectionState)
}
// processClientID extracts the client's ID from the server name of the client's
// DoT or DoQ request or the path of the client's DoH.
func processClientID(dctx *dnsContext) (rc resultCode) {
pctx := dctx.proxyCtx
// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
return processClientIDHTTPS(dctx)
return clientIDFromDNSContextHTTPS(pctx)
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return resultCodeSuccess
return "", nil
}
srvConf := dctx.srv.conf
hostSrvName := srvConf.TLSConfig.ServerName
hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return resultCodeSuccess
return "", nil
}
cliSrvName := ""
if proto == proxy.ProtoTLS {
switch proto {
case proxy.ProtoTLS:
conn := pctx.Conn
tc, ok := conn.(tlsConn)
if !ok {
dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
return resultCodeError
return "", fmt.Errorf(
"proxy ctx conn of proto %s is %T, want *tls.Conn",
proto,
conn,
)
}
cliSrvName = tc.ConnectionState().ServerName
} else if proto == proxy.ProtoQUIC {
case proxy.ProtoQUIC:
qs, ok := pctx.QUICSession.(quicSession)
if !ok {
dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
return resultCodeError
return "", fmt.Errorf(
"proxy ctx quic session of proto %s is %T, want quic.Session",
proto,
pctx.QUICSession,
)
}
cliSrvName = qs.ConnectionState().TLS.ServerName
}
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck)
clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
dctx.err = fmt.Errorf("client id check: %w", err)
return resultCodeError
return "", fmt.Errorf("client id check: %w", err)
}
dctx.clientID = clientID
return clientID, nil
}
// processClientID puts the clientID into the DNS context, if there is one.
func (s *Server) processClientID(dctx *dnsContext) (rc resultCode) {
pctx := dctx.proxyCtx
var key [8]byte
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
clientIDData := s.clientIDCache.Get(key[:])
if clientIDData == nil {
return resultCodeSuccess
}
dctx.clientID = string(clientIDData)
return resultCodeSuccess
}

View File

@@ -45,15 +45,14 @@ func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
return cs
}
func TestProcessClientID(t *testing.T) {
func TestServer_clientIDFromDNSContext(t *testing.T) {
testCases := []struct {
name string
proto string
proto proxy.Proto
hostSrvName string
cliSrvName string
wantClientID string
wantErrMsg string
wantRes resultCode
strictSNI bool
}{{
name: "udp",
@@ -62,7 +61,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: false,
}, {
name: "tls_no_client_id",
@@ -71,7 +69,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "example.com",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true,
}, {
name: "tls_no_client_server_name",
@@ -81,7 +78,6 @@ func TestProcessClientID(t *testing.T) {
wantClientID: "",
wantErrMsg: `client id check: client server name "" ` +
`doesn't match host server name "example.com"`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "tls_no_client_server_name_no_strict",
@@ -90,7 +86,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: false,
}, {
name: "tls_client_id",
@@ -99,7 +94,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true,
}, {
name: "tls_client_id_hostname_error",
@@ -109,7 +103,6 @@ func TestProcessClientID(t *testing.T) {
wantClientID: "",
wantErrMsg: `client id check: client server name "cli.example.net" ` +
`doesn't match host server name "example.com"`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "tls_invalid_client_id",
@@ -119,7 +112,6 @@ func TestProcessClientID(t *testing.T) {
wantClientID: "",
wantErrMsg: `client id check: invalid client id "!!!": ` +
`invalid char '!' at index 0`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "tls_client_id_too_long",
@@ -131,7 +123,6 @@ func TestProcessClientID(t *testing.T) {
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
`label is too long, max: 63`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "quic_client_id",
@@ -140,7 +131,6 @@ func TestProcessClientID(t *testing.T) {
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true,
}}
@@ -150,6 +140,7 @@ func TestProcessClientID(t *testing.T) {
ServerName: tc.hostSrvName,
StrictSNICheck: tc.strictSNI,
}
srv := &Server{
conf: ServerConfig{TLSConfig: tlsConf},
}
@@ -168,79 +159,68 @@ func TestProcessClientID(t *testing.T) {
}
}
dctx := &dnsContext{
srv: srv,
proxyCtx: &proxy.DNSContext{
Proto: tc.proto,
Conn: conn,
QUICSession: qs,
},
pctx := &proxy.DNSContext{
Proto: tc.proto,
Conn: conn,
QUICSession: qs,
}
res := processClientID(dctx)
assert.Equal(t, tc.wantRes, res)
assert.Equal(t, tc.wantClientID, dctx.clientID)
clientID, err := srv.clientIDFromDNSContext(pctx)
assert.Equal(t, tc.wantClientID, clientID)
if tc.wantErrMsg == "" {
assert.NoError(t, dctx.err)
assert.NoError(t, err)
} else {
require.Error(t, dctx.err)
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
})
}
}
func TestProcessClientID_https(t *testing.T) {
func TestClientIDFromDNSContextHTTPS(t *testing.T) {
testCases := []struct {
name string
path string
wantClientID string
wantErrMsg string
wantRes resultCode
}{{
name: "no_client_id",
path: "/dns-query",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "no_client_id_slash",
path: "/dns-query/",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "client_id",
path: "/dns-query/cli",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "client_id_slash",
path: "/dns-query/cli/",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "bad_url",
path: "/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/foo"`,
wantRes: resultCodeError,
}, {
name: "extra",
path: "/dns-query/cli/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
wantRes: resultCodeError,
}, {
name: "invalid_client_id",
path: "/dns-query/!!!",
wantClientID: "",
wantErrMsg: `client id check: invalid client id "!!!": ` +
`invalid char '!' at index 0`,
wantRes: resultCodeError,
}}
for _, tc := range testCases {
@@ -251,23 +231,20 @@ func TestProcessClientID_https(t *testing.T) {
},
}
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Proto: proxy.ProtoHTTPS,
HTTPRequest: r,
},
pctx := &proxy.DNSContext{
Proto: proxy.ProtoHTTPS,
HTTPRequest: r,
}
res := processClientID(dctx)
assert.Equal(t, tc.wantRes, res)
assert.Equal(t, tc.wantClientID, dctx.clientID)
clientID, err := clientIDFromDNSContextHTTPS(pctx)
assert.Equal(t, tc.wantClientID, clientID)
if tc.wantErrMsg == "" {
assert.NoError(t, dctx.err)
assert.NoError(t, err)
} else {
require.Error(t, dctx.err)
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
assert.Equal(t, tc.wantErrMsg, err.Error())
}
})
}

View File

@@ -331,7 +331,7 @@ func (s *Server) prepareUpstreamSettings() error {
upstreams = aghstrings.FilterOut(upstreams, aghstrings.IsCommentOrEmpty)
upstreamConfig, err := proxy.ParseUpstreamsConfig(
upstreams,
upstream.Options{
&upstream.Options{
Bootstrap: s.conf.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout,
},
@@ -342,10 +342,10 @@ func (s *Server) prepareUpstreamSettings() error {
if len(upstreamConfig.Upstreams) == 0 {
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
var uc proxy.UpstreamConfig
var uc *proxy.UpstreamConfig
uc, err = proxy.ParseUpstreamsConfig(
defaultDNS,
upstream.Options{
&upstream.Options{
Bootstrap: s.conf.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout,
},
@@ -356,7 +356,8 @@ func (s *Server) prepareUpstreamSettings() error {
upstreamConfig.Upstreams = uc.Upstreams
}
s.conf.UpstreamConfig = &upstreamConfig
s.conf.UpstreamConfig = upstreamConfig
return nil
}

View File

@@ -89,7 +89,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
s.processInternalHosts,
s.processRestrictLocal,
s.processInternalIPAddrs,
processClientID,
s.processClientID,
processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream,
@@ -165,7 +165,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) {
s.tableHostToIP = t
}
func (s *Server) setTableIPToHost(t ipToHostTable) {
func (s *Server) setTableIPToHost(t *aghnet.IPMap) {
s.tableIPToHostLock.Lock()
defer s.tableIPToHostLock.Unlock()
@@ -188,13 +188,13 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
}
var hostToIP hostToIPTable
var ipToHost ipToHostTable
var ipToHost *aghnet.IPMap
if add {
hostToIP = make(hostToIPTable)
ipToHost = make(ipToHostTable)
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
hostToIP = make(hostToIPTable, len(ll))
ipToHost = aghnet.NewIPMap(len(ll))
for _, l := range ll {
// TODO(a.garipov): Remove this after we're finished
// with the client hostname validations in the DHCP
@@ -210,14 +210,14 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
lowhost := strings.ToLower(l.Hostname)
ipToHost[l.IP.String()] = lowhost
ipToHost.Set(l.IP, lowhost)
ip := make(net.IP, 4)
copy(ip, l.IP.To4())
hostToIP[lowhost] = ip
}
log.Debug("dns: added %d A/PTR entries from DHCP", len(ipToHost))
log.Debug("dns: added %d A/PTR entries from DHCP", ipToHost.Len())
}
s.setTableHostToIP(hostToIP)
@@ -377,7 +377,15 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
return "", false
}
host, ok = s.tableIPToHost[ip.String()]
var v interface{}
v, ok = s.tableIPToHost.Get(ip)
var typOK bool
if host, typOK = v.(string); !typOK {
log.Error("dns: bad type %T in tableIPToHost for %s", v, ip)
return "", false
}
return host, ok
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
@@ -26,6 +27,11 @@ import (
// DefaultTimeout is the default upstream timeout
const DefaultTimeout = 10 * time.Second
// defaultClientIDCacheCount is the default count of items in the LRU client ID
// cache. The assumption here is that there won't be more than this many
// requests between the BeforeRequestHandler stage and the actual processing.
const defaultClientIDCacheCount = 1024
const (
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
parentalBlockHost = "family-block.dns.adguard.com"
@@ -44,12 +50,6 @@ var webRegistered bool
// hostToIPTable is an alias for the type of Server.tableHostToIP.
type hostToIPTable = map[string]net.IP
// ipToHostTable is an alias for the type of Server.tableIPToHost.
//
// TODO(a.garipov): Define an IPMap type in aghnet and use here and in other
// places?
type ipToHostTable = map[string]string
// Server is the main way to start a DNS server.
//
// Example:
@@ -81,9 +81,13 @@ type Server struct {
tableHostToIP hostToIPTable
tableHostToIPLock sync.Mutex
tableIPToHost ipToHostTable
tableIPToHost *aghnet.IPMap
tableIPToHostLock sync.Mutex
// clientIDCache is a temporary storage for clientIDs that were
// extracted during the BeforeRequestHandler stage.
clientIDCache cache.Cache
// DNS proxy instance for internal usage
// We don't Start() it and so no listen port is required.
internalProxy *proxy.Proxy
@@ -152,6 +156,10 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
subnetDetector: p.SubnetDetector,
localDomainSuffix: localDomainSuffix,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{
EnableLRU: true,
MaxCount: defaultClientIDCacheCount,
}),
}
// TODO(e.burkov): Enable the refresher after the actual implementation
@@ -414,19 +422,22 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs)
var upsConfig proxy.UpstreamConfig
upsConfig, err = proxy.ParseUpstreamsConfig(localAddrs, upstream.Options{
Bootstrap: bootstraps,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's ceritificates?
})
var upsConfig *proxy.UpstreamConfig
upsConfig, err = proxy.ParseUpstreamsConfig(
localAddrs,
&upstream.Options{
Bootstrap: bootstraps,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's ceritificates?
},
)
if err != nil {
return fmt.Errorf("parsing upstreams: %w", err)
}
s.localResolvers = &proxy.Proxy{
Config: proxy.Config{
UpstreamConfig: &upsConfig,
UpstreamConfig: upsConfig,
},
}
@@ -577,11 +588,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
// IsBlockedIP - return TRUE if this client should be blocked
func (s *Server) IsBlockedIP(ip net.IP) (bool, string) {
if ip == nil {
return false, ""
// IsBlockedClient returns true if the client is blocked by the current access
// settings.
func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
allowlistMode := s.access.allowlistMode()
blockedByIP, rule := s.access.isBlockedIP(ip)
blockedByClientID := s.access.isBlockedClientID(clientID)
// Allow if at least one of the checks allows in allowlist mode, but
// block if at least one of the checks blocks in blocklist mode.
if allowlistMode && blockedByIP && blockedByClientID {
log.Debug("client %s (id %q) is not in access allowlist", ip, clientID)
// Return now without substituting the empty rule for the
// clientID because the rule can't be empty here.
return true, rule
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
log.Debug("client %s (id %q) is in access blocklist", ip, clientID)
blocked = true
}
return s.access.IsBlockedIP(ip)
if rule == "" {
rule = clientID
}
return blocked, rule
}

View File

@@ -257,19 +257,22 @@ func TestServer(t *testing.T) {
testCases := []struct {
name string
proto string
net string
proto proxy.Proto
}{{
name: "message_over_udp",
net: "",
proto: proxy.ProtoUDP,
}, {
name: "message_over_tcp",
net: "tcp",
proto: proxy.ProtoTCP,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
addr := s.dnsProxy.Addr(tc.proto)
client := dns.Client{Net: tc.proto}
client := dns.Client{Net: tc.net}
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
@@ -324,7 +327,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
// Message over UDP.
req := createGoogleATestMessage()
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
client := dns.Client{Net: proxy.ProtoUDP}
client := &dns.Client{}
reply, _, err := client.Exchange(req, addr.String())
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
@@ -376,7 +379,7 @@ func TestDoQServer(t *testing.T) {
// Create a DNS-over-QUIC upstream.
addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
opts := upstream.Options{InsecureSkipVerify: true}
opts := &upstream.Options{InsecureSkipVerify: true}
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
require.NoError(t, err)
@@ -420,7 +423,7 @@ func TestServerRace(t *testing.T) {
// Message over UDP.
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
conn, err := dns.Dial(proxy.ProtoUDP, addr.String())
conn, err := dns.Dial("udp", addr.String())
require.NoErrorf(t, err, "cannot connect to the proxy: %s", err)
sendTestMessagesAsync(t, conn)
@@ -445,7 +448,7 @@ func TestSafeSearch(t *testing.T) {
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
client := dns.Client{Net: proxy.ProtoUDP}
client := &dns.Client{}
yandexIP := net.IP{213, 180, 193, 56}
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
@@ -507,7 +510,6 @@ func TestInvalidRequest(t *testing.T) {
// Send a DNS request without question.
_, _, err := (&dns.Client{
Net: proxy.ProtoUDP,
Timeout: 500 * time.Millisecond,
}).Exchange(&req, addr)

View File

@@ -1,6 +1,7 @@
package dnsforward
import (
"encoding/binary"
"fmt"
"strings"
@@ -11,23 +12,39 @@ import (
"github.com/miekg/dns"
)
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := aghnet.IPFromAddr(d.Addr)
disallowed, _ := s.access.IsBlockedIP(ip)
if disallowed {
log.Tracef("Client IP %s is blocked by settings", ip)
// beforeRequestHandler is the handler that is called before any other
// processing, including logs. It performs access checks and puts the client
// ID, if there is one, into the server's cache.
func (s *Server) beforeRequestHandler(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (reply bool, err error) {
ip := aghnet.IPFromAddr(pctx.Addr)
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return false, fmt.Errorf("getting clientid: %w", err)
}
blocked, _ := s.IsBlockedClient(ip, clientID)
if blocked {
return false, nil
}
if len(d.Req.Question) == 1 {
host := strings.TrimSuffix(d.Req.Question[0].Name, ".")
if s.access.IsBlockedDomain(host) {
log.Tracef("domain %s is blocked by access settings", host)
if len(pctx.Req.Question) == 1 {
host := strings.TrimSuffix(pctx.Req.Question[0].Name, ".")
if s.access.isBlockedHost(host) {
log.Debug("host %s is in access blocklist", host)
return false, nil
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return true, nil
}

View File

@@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (string, error) {
return boot, fmt.Errorf("invalid bootstrap server address: empty")
}
if _, err := upstream.NewResolver(boot, upstream.Options{Timeout: 0}); err != nil {
if _, err := upstream.NewResolver(boot, nil); err != nil {
return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
}
}
@@ -348,7 +348,7 @@ func ValidateUpstreams(upstreams []string) (err error) {
_, err = proxy.ParseUpstreamsConfig(
upstreams,
upstream.Options{
&upstream.Options{
Bootstrap: []string{},
Timeout: DefaultTimeout,
},
@@ -546,7 +546,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
log.Debug("checking if dns server %q works...", input)
var u upstream.Upstream
u, err = upstream.AddressToUpstream(input, upstream.Options{
u, err = upstream.AddressToUpstream(input, &upstream.Options{
Bootstrap: bootstrap,
Timeout: timeout,
})

View File

@@ -46,7 +46,7 @@ func (l *testStats) Update(e stats.Entry) {
func TestProcessQueryLogsAndStats(t *testing.T) {
testCases := []struct {
name string
proto string
proto proxy.Proto
addr net.Addr
clientID string
wantLogProto querylog.ClientProto
@@ -156,7 +156,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RParental,
}}
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
require.Nil(t, err)
for _, tc := range testCases {