Pull request: all: allow clientid in access settings
Updates #2624. Updates #3162. Squashed commit of the following: commit 68860da717a23a0bfeba14b7fe10b5e4ad38726d Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:41:33 2021 +0300 all: imp types, names commit ebd4ec26636853d0d58c4e331e6a78feede20813 Merge: 239eb72116e5e09cAuthor: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:14:33 2021 +0300 Merge branch 'master' into 2624-clientid-access commit 239eb7215abc47e99a0300a0f4cf56002689b1a9 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:13:10 2021 +0300 all: fix client blocking check commit e6bece3ea8367b3cbe3d90702a3368c870ad4f13 Merge: 9935f2a39d1656b5Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 13:12:28 2021 +0300 Merge branch 'master' into 2624-clientid-access commit 9935f2a30bcfae2b853f3ef610c0ab7a56a8f448 Author: Ildar Kamalov <ik@adguard.com> Date: Tue Jun 29 11:26:51 2021 +0300 client: show block button for client id commit ed786a6a74a081cd89e9d67df3537a4fadd54831 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 15:56:23 2021 +0300 client: imp i18n commit 4fed21c68473ad408960c08a7d87624cabce1911 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 15:34:09 2021 +0300 all: imp i18n, docs commit 55e65c0d6b939560c53dcb834a4557eb3853d194 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 13:34:01 2021 +0300 all: fix cache, imp code, docs, tests commit c1e5a83e76deb44b1f92729bb9ddfcc6a96ac4a8 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu Jun 24 19:27:12 2021 +0300 all: allow clientid in access settings
This commit is contained in:
@@ -6,138 +6,163 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
)
|
||||
|
||||
// accessCtx controls IP and client blocking that takes place before all other
|
||||
// processing. An accessCtx is safe for concurrent use.
|
||||
type accessCtx struct {
|
||||
lock sync.Mutex
|
||||
allowedIPs *aghnet.IPMap
|
||||
blockedIPs *aghnet.IPMap
|
||||
|
||||
// allowedClients are the IP addresses of clients in the allowlist.
|
||||
allowedClients *aghstrings.Set
|
||||
allowedClientIDs *aghstrings.Set
|
||||
blockedClientIDs *aghstrings.Set
|
||||
|
||||
// disallowedClients are the IP addresses of clients in the blocklist.
|
||||
disallowedClients *aghstrings.Set
|
||||
blockedHostsEng *urlfilter.DNSEngine
|
||||
|
||||
allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients
|
||||
disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked
|
||||
|
||||
blockedHostsEngine *urlfilter.DNSEngine // finds hosts that should be blocked
|
||||
// TODO(a.garipov): Create a type for a set of IP networks.
|
||||
// aghnet.IPNetSet?
|
||||
allowedNets []*net.IPNet
|
||||
blockedNets []*net.IPNet
|
||||
}
|
||||
|
||||
func newAccessCtx(allowedClients, disallowedClients, blockedHosts []string) (a *accessCtx, err error) {
|
||||
a = &accessCtx{
|
||||
allowedClients: aghstrings.NewSet(),
|
||||
disallowedClients: aghstrings.NewSet(),
|
||||
}
|
||||
// unit is a convenient alias for struct{}
|
||||
type unit = struct{}
|
||||
|
||||
err = processIPCIDRArray(a.allowedClients, &a.allowedClientsIPNet, allowedClients)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("processing allowed clients: %w", err)
|
||||
}
|
||||
// processAccessClients is a helper for processing a list of client strings,
|
||||
// which may be an IP address, a CIDR, or a ClientID.
|
||||
func processAccessClients(
|
||||
clientStrs []string,
|
||||
ips *aghnet.IPMap,
|
||||
nets *[]*net.IPNet,
|
||||
clientIDs *aghstrings.Set,
|
||||
) (err error) {
|
||||
for i, s := range clientStrs {
|
||||
if ip := net.ParseIP(s); ip != nil {
|
||||
ips.Set(ip, unit{})
|
||||
} else if cidrIP, ipnet, cidrErr := net.ParseCIDR(s); cidrErr == nil {
|
||||
ipnet.IP = cidrIP
|
||||
*nets = append(*nets, ipnet)
|
||||
} else {
|
||||
idErr := ValidateClientID(s)
|
||||
if idErr != nil {
|
||||
return fmt.Errorf(
|
||||
"value %q at index %d: bad ip, cidr, or clientid",
|
||||
s,
|
||||
i,
|
||||
)
|
||||
}
|
||||
|
||||
err = processIPCIDRArray(a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("processing disallowed clients: %w", err)
|
||||
}
|
||||
|
||||
b := &strings.Builder{}
|
||||
for _, s := range blockedHosts {
|
||||
aghstrings.WriteToBuilder(b, strings.ToLower(s), "\n")
|
||||
}
|
||||
|
||||
listArray := []filterlist.RuleList{}
|
||||
list := &filterlist.StringRuleList{
|
||||
ID: int(0),
|
||||
RulesText: b.String(),
|
||||
IgnoreCosmetic: true,
|
||||
}
|
||||
listArray = append(listArray, list)
|
||||
rulesStorage, err := filterlist.NewRuleStorage(listArray)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err)
|
||||
}
|
||||
a.blockedHostsEngine = urlfilter.NewDNSEngine(rulesStorage)
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Split array of IP or CIDR into 2 containers for fast search
|
||||
func processIPCIDRArray(dst *aghstrings.Set, dstIPNet *[]net.IPNet, src []string) error {
|
||||
for _, s := range src {
|
||||
ip := net.ParseIP(s)
|
||||
if ip != nil {
|
||||
dst.Add(s)
|
||||
|
||||
continue
|
||||
clientIDs.Add(s)
|
||||
}
|
||||
|
||||
_, ipnet, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*dstIPNet = append(*dstIPNet, *ipnet)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsBlockedIP - return TRUE if this client should be blocked
|
||||
// Returns the item from the "disallowedClients" list that lead to blocking IP.
|
||||
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty,
|
||||
// but the ip does not belong to it.
|
||||
func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) {
|
||||
ipStr := ip.String()
|
||||
// newAccessCtx creates a new accessCtx.
|
||||
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) {
|
||||
a = &accessCtx{
|
||||
allowedIPs: aghnet.NewIPMap(0),
|
||||
blockedIPs: aghnet.NewIPMap(0),
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
if a.allowedClients.Len() != 0 || len(a.allowedClientsIPNet) != 0 {
|
||||
if a.allowedClients.Has(ipStr) {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if len(a.allowedClientsIPNet) != 0 {
|
||||
for _, ipnet := range a.allowedClientsIPNet {
|
||||
if ipnet.Contains(ip) {
|
||||
return false, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, ""
|
||||
allowedClientIDs: aghstrings.NewSet(),
|
||||
blockedClientIDs: aghstrings.NewSet(),
|
||||
}
|
||||
|
||||
if a.disallowedClients.Has(ipStr) {
|
||||
return true, ipStr
|
||||
err = processAccessClients(allowed, a.allowedIPs, &a.allowedNets, a.allowedClientIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding allowed: %w", err)
|
||||
}
|
||||
|
||||
if len(a.disallowedClientsIPNet) != 0 {
|
||||
for _, ipnet := range a.disallowedClientsIPNet {
|
||||
if ipnet.Contains(ip) {
|
||||
return true, ipnet.String()
|
||||
}
|
||||
}
|
||||
err = processAccessClients(blocked, a.blockedIPs, &a.blockedNets, a.blockedClientIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding blocked: %w", err)
|
||||
}
|
||||
|
||||
return false, ""
|
||||
b := &strings.Builder{}
|
||||
for _, h := range blockedHosts {
|
||||
aghstrings.WriteToBuilder(b, strings.ToLower(h), "\n")
|
||||
}
|
||||
|
||||
lists := []filterlist.RuleList{
|
||||
&filterlist.StringRuleList{
|
||||
ID: int(0),
|
||||
RulesText: b.String(),
|
||||
IgnoreCosmetic: true,
|
||||
},
|
||||
}
|
||||
|
||||
rulesStrg, err := filterlist.NewRuleStorage(lists)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding blocked hosts: %w", err)
|
||||
}
|
||||
|
||||
a.blockedHostsEng = urlfilter.NewDNSEngine(rulesStrg)
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// IsBlockedDomain - return TRUE if this domain should be blocked
|
||||
func (a *accessCtx) IsBlockedDomain(host string) (ok bool) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
// allowlistMode returns true if this *accessCtx is in the allowlist mode.
|
||||
func (a *accessCtx) allowlistMode() (ok bool) {
|
||||
return a.allowedIPs.Len() != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
|
||||
}
|
||||
|
||||
_, ok = a.blockedHostsEngine.Match(strings.ToLower(host))
|
||||
// isBlockedClientID returns true if the ClientID should be blocked.
|
||||
func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
|
||||
allowlistMode := a.allowlistMode()
|
||||
if id == "" {
|
||||
// In allowlist mode, consider requests without client IDs
|
||||
// blocked by default.
|
||||
return allowlistMode
|
||||
}
|
||||
|
||||
if allowlistMode {
|
||||
return !a.allowedClientIDs.Has(id)
|
||||
}
|
||||
|
||||
return a.blockedClientIDs.Has(id)
|
||||
}
|
||||
|
||||
// isBlockedHost returns true if host should be blocked.
|
||||
func (a *accessCtx) isBlockedHost(host string) (ok bool) {
|
||||
_, ok = a.blockedHostsEng.Match(strings.ToLower(host))
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// isBlockedIP returns the status of the IP address blocking as well as the rule
|
||||
// that blocked it.
|
||||
func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) {
|
||||
blocked = true
|
||||
ips := a.blockedIPs
|
||||
ipnets := a.blockedNets
|
||||
|
||||
if a.allowlistMode() {
|
||||
// Enable allowlist mode and use the allowlist sets.
|
||||
blocked = false
|
||||
ips = a.allowedIPs
|
||||
ipnets = a.allowedNets
|
||||
}
|
||||
|
||||
if _, ok := ips.Get(ip); ok {
|
||||
return blocked, ip.String()
|
||||
}
|
||||
|
||||
for _, ipnet := range ipnets {
|
||||
if ipnet.Contains(ip) {
|
||||
return blocked, ipnet.String()
|
||||
}
|
||||
}
|
||||
|
||||
return !blocked, ""
|
||||
}
|
||||
|
||||
type accessListJSON struct {
|
||||
AllowedClients []string `json:"allowed_clients"`
|
||||
DisallowedClients []string `json:"disallowed_clients"`
|
||||
@@ -161,62 +186,43 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(j)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func checkIPCIDRArray(src []string) error {
|
||||
for _, s := range src {
|
||||
ip := net.ParseIP(s)
|
||||
if ip != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, _, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
j := accessListJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&j)
|
||||
list := accessListJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&list)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
httpError(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||
|
||||
err = checkIPCIDRArray(j.AllowedClients)
|
||||
if err == nil {
|
||||
err = checkIPCIDRArray(j.DisallowedClients)
|
||||
}
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var a *accessCtx
|
||||
a, err = newAccessCtx(j.AllowedClients, j.DisallowedClients, j.BlockedHosts)
|
||||
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer log.Debug("Access: updated lists: %d, %d, %d",
|
||||
len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts))
|
||||
defer log.Debug(
|
||||
"access: updated lists: %d, %d, %d",
|
||||
len(list.AllowedClients),
|
||||
len(list.DisallowedClients),
|
||||
len(list.BlockedHosts),
|
||||
)
|
||||
|
||||
defer s.conf.ConfigModified()
|
||||
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
s.conf.AllowedClients = j.AllowedClients
|
||||
s.conf.DisallowedClients = j.DisallowedClients
|
||||
s.conf.BlockedHosts = j.BlockedHosts
|
||||
s.conf.AllowedClients = list.AllowedClients
|
||||
s.conf.DisallowedClients = list.DisallowedClients
|
||||
s.conf.BlockedHosts = list.BlockedHosts
|
||||
s.access = a
|
||||
}
|
||||
|
||||
@@ -8,99 +8,23 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsBlockedIP(t *testing.T) {
|
||||
const (
|
||||
ip int = iota
|
||||
cidr
|
||||
)
|
||||
func TestIsBlockedClientID(t *testing.T) {
|
||||
clientID := "client-1"
|
||||
clients := []string{clientID}
|
||||
|
||||
rules := []string{
|
||||
ip: "1.1.1.1",
|
||||
cidr: "2.2.0.0/16",
|
||||
}
|
||||
a, err := newAccessCtx(clients, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
allowed bool
|
||||
ip net.IP
|
||||
wantDis bool
|
||||
wantRule string
|
||||
}{{
|
||||
name: "allow_ip",
|
||||
allowed: true,
|
||||
ip: net.IPv4(1, 1, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "disallow_ip",
|
||||
allowed: true,
|
||||
ip: net.IPv4(1, 1, 1, 2),
|
||||
wantDis: true,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_cidr",
|
||||
allowed: true,
|
||||
ip: net.IPv4(2, 2, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "disallow_cidr",
|
||||
allowed: true,
|
||||
ip: net.IPv4(2, 3, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_ip",
|
||||
allowed: false,
|
||||
ip: net.IPv4(1, 1, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: rules[ip],
|
||||
}, {
|
||||
name: "disallow_ip",
|
||||
allowed: false,
|
||||
ip: net.IPv4(1, 1, 1, 2),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_cidr",
|
||||
allowed: false,
|
||||
ip: net.IPv4(2, 2, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: rules[cidr],
|
||||
}, {
|
||||
name: "disallow_cidr",
|
||||
allowed: false,
|
||||
ip: net.IPv4(2, 3, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}}
|
||||
assert.False(t, a.isBlockedClientID(clientID))
|
||||
|
||||
for _, tc := range testCases {
|
||||
prefix := "allowed_"
|
||||
if !tc.allowed {
|
||||
prefix = "disallowed_"
|
||||
}
|
||||
a, err = newAccessCtx(nil, clients, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run(prefix+tc.name, func(t *testing.T) {
|
||||
allowedRules := rules
|
||||
var disallowedRules []string
|
||||
|
||||
if !tc.allowed {
|
||||
allowedRules, disallowedRules = disallowedRules, allowedRules
|
||||
}
|
||||
|
||||
aCtx, err := newAccessCtx(allowedRules, disallowedRules, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
disallowed, rule := aCtx.IsBlockedIP(tc.ip)
|
||||
assert.Equal(t, tc.wantDis, disallowed)
|
||||
assert.Equal(t, tc.wantRule, rule)
|
||||
})
|
||||
}
|
||||
assert.True(t, a.isBlockedClientID(clientID))
|
||||
}
|
||||
|
||||
func TestIsBlockedDomain(t *testing.T) {
|
||||
aCtx, err := newAccessCtx(nil, nil, []string{
|
||||
func TestIsBlockedHost(t *testing.T) {
|
||||
a, err := newAccessCtx(nil, nil, []string{
|
||||
"host1",
|
||||
"*.host.com",
|
||||
"||host3.com^",
|
||||
@@ -108,50 +32,106 @@ func TestIsBlockedDomain(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
domain string
|
||||
want bool
|
||||
name string
|
||||
host string
|
||||
want bool
|
||||
}{{
|
||||
name: "plain_match",
|
||||
domain: "host1",
|
||||
want: true,
|
||||
name: "plain_match",
|
||||
host: "host1",
|
||||
want: true,
|
||||
}, {
|
||||
name: "plain_mismatch",
|
||||
domain: "host2",
|
||||
want: false,
|
||||
name: "plain_mismatch",
|
||||
host: "host2",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-1_match_short",
|
||||
domain: "asdf.host.com",
|
||||
want: true,
|
||||
name: "subdomain_match_short",
|
||||
host: "asdf.host.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-1_match_long",
|
||||
domain: "qwer.asdf.host.com",
|
||||
want: true,
|
||||
name: "subdomain_match_long",
|
||||
host: "qwer.asdf.host.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-1_mismatch_no-lead",
|
||||
domain: "host.com",
|
||||
want: false,
|
||||
name: "subdomain_mismatch_no_lead",
|
||||
host: "host.com",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-1_mismatch_bad-asterisk",
|
||||
domain: "asdf.zhost.com",
|
||||
want: false,
|
||||
name: "subdomain_mismatch_bad_asterisk",
|
||||
host: "asdf.zhost.com",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-2_match_simple",
|
||||
domain: "host3.com",
|
||||
want: true,
|
||||
name: "rule_match_simple",
|
||||
host: "host3.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-2_match_complex",
|
||||
domain: "asdf.host3.com",
|
||||
want: true,
|
||||
name: "rule_match_complex",
|
||||
host: "asdf.host3.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-2_mismatch",
|
||||
domain: ".host3.com",
|
||||
want: false,
|
||||
name: "rule_mismatch",
|
||||
host: ".host3.com",
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain))
|
||||
assert.Equal(t, tc.want, a.isBlockedHost(tc.host))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBlockedIP(t *testing.T) {
|
||||
clients := []string{
|
||||
"1.2.3.4",
|
||||
"5.6.7.8/24",
|
||||
}
|
||||
|
||||
allowCtx, err := newAccessCtx(clients, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
blockCtx, err := newAccessCtx(nil, clients, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantRule string
|
||||
ip net.IP
|
||||
wantBlocked bool
|
||||
}{{
|
||||
name: "match_ip",
|
||||
wantRule: "1.2.3.4",
|
||||
ip: net.IP{1, 2, 3, 4},
|
||||
wantBlocked: true,
|
||||
}, {
|
||||
name: "match_cidr",
|
||||
wantRule: "5.6.7.8/24",
|
||||
ip: net.IP{5, 6, 7, 100},
|
||||
wantBlocked: true,
|
||||
}, {
|
||||
name: "no_match_ip",
|
||||
wantRule: "",
|
||||
ip: net.IP{9, 2, 3, 4},
|
||||
wantBlocked: false,
|
||||
}, {
|
||||
name: "no_match_cidr",
|
||||
wantRule: "",
|
||||
ip: net.IP{9, 6, 7, 100},
|
||||
wantBlocked: false,
|
||||
}}
|
||||
|
||||
t.Run("allow", func(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
blocked, rule := allowCtx.isBlockedIP(tc.ip)
|
||||
assert.Equal(t, !tc.wantBlocked, blocked)
|
||||
assert.Equal(t, tc.wantRule, rule)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("block", func(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
blocked, rule := blockCtx.isBlockedIP(tc.ip)
|
||||
assert.Equal(t, tc.wantBlocked, blocked)
|
||||
assert.Equal(t, tc.wantRule, rule)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
@@ -50,15 +51,15 @@ func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// processClientIDHTTPS extracts the client's ID from the path of the
|
||||
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
|
||||
// client's DNS-over-HTTPS request.
|
||||
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
|
||||
pctx := ctx.proxyCtx
|
||||
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
r := pctx.HTTPRequest
|
||||
if r == nil {
|
||||
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx http request of proto %s is nil",
|
||||
pctx.Proto,
|
||||
)
|
||||
}
|
||||
|
||||
origPath := r.URL.Path
|
||||
@@ -68,34 +69,25 @@ func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
if len(parts) == 0 || parts[0] != "dns-query" {
|
||||
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: invalid path %q", origPath)
|
||||
}
|
||||
|
||||
clientID := ""
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
// Just /dns-query, no client ID.
|
||||
return resultCodeSuccess
|
||||
return "", nil
|
||||
case 2:
|
||||
clientID = parts[1]
|
||||
default:
|
||||
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
||||
}
|
||||
|
||||
err := ValidateClientID(clientID)
|
||||
err = ValidateClientID(clientID)
|
||||
if err != nil {
|
||||
ctx.err = fmt.Errorf("client id check: %w", err)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: %w", err)
|
||||
}
|
||||
|
||||
ctx.clientID = clientID
|
||||
|
||||
return resultCodeSuccess
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
|
||||
@@ -108,53 +100,73 @@ type quicSession interface {
|
||||
ConnectionState() (cs quic.ConnectionState)
|
||||
}
|
||||
|
||||
// processClientID extracts the client's ID from the server name of the client's
|
||||
// DoT or DoQ request or the path of the client's DoH.
|
||||
func processClientID(dctx *dnsContext) (rc resultCode) {
|
||||
pctx := dctx.proxyCtx
|
||||
// clientIDFromDNSContext extracts the client's ID from the server name of the
|
||||
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
|
||||
// is not one of these, clientID is an empty string and err is nil.
|
||||
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
proto := pctx.Proto
|
||||
if proto == proxy.ProtoHTTPS {
|
||||
return processClientIDHTTPS(dctx)
|
||||
return clientIDFromDNSContextHTTPS(pctx)
|
||||
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
||||
return resultCodeSuccess
|
||||
return "", nil
|
||||
}
|
||||
|
||||
srvConf := dctx.srv.conf
|
||||
hostSrvName := srvConf.TLSConfig.ServerName
|
||||
hostSrvName := s.conf.ServerName
|
||||
if hostSrvName == "" {
|
||||
return resultCodeSuccess
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cliSrvName := ""
|
||||
if proto == proxy.ProtoTLS {
|
||||
switch proto {
|
||||
case proxy.ProtoTLS:
|
||||
conn := pctx.Conn
|
||||
tc, ok := conn.(tlsConn)
|
||||
if !ok {
|
||||
dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx conn of proto %s is %T, want *tls.Conn",
|
||||
proto,
|
||||
conn,
|
||||
)
|
||||
}
|
||||
|
||||
cliSrvName = tc.ConnectionState().ServerName
|
||||
} else if proto == proxy.ProtoQUIC {
|
||||
case proxy.ProtoQUIC:
|
||||
qs, ok := pctx.QUICSession.(quicSession)
|
||||
if !ok {
|
||||
dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx quic session of proto %s is %T, want quic.Session",
|
||||
proto,
|
||||
pctx.QUICSession,
|
||||
)
|
||||
}
|
||||
|
||||
cliSrvName = qs.ConnectionState().TLS.ServerName
|
||||
}
|
||||
|
||||
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck)
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
hostSrvName,
|
||||
cliSrvName,
|
||||
s.conf.StrictSNICheck,
|
||||
)
|
||||
if err != nil {
|
||||
dctx.err = fmt.Errorf("client id check: %w", err)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: %w", err)
|
||||
}
|
||||
|
||||
dctx.clientID = clientID
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// processClientID puts the clientID into the DNS context, if there is one.
|
||||
func (s *Server) processClientID(dctx *dnsContext) (rc resultCode) {
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
var key [8]byte
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
clientIDData := s.clientIDCache.Get(key[:])
|
||||
if clientIDData == nil {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
dctx.clientID = string(clientIDData)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -45,15 +45,14 @@ func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
|
||||
return cs
|
||||
}
|
||||
|
||||
func TestProcessClientID(t *testing.T) {
|
||||
func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
proto proxy.Proto
|
||||
hostSrvName string
|
||||
cliSrvName string
|
||||
wantClientID string
|
||||
wantErrMsg string
|
||||
wantRes resultCode
|
||||
strictSNI bool
|
||||
}{{
|
||||
name: "udp",
|
||||
@@ -62,7 +61,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: false,
|
||||
}, {
|
||||
name: "tls_no_client_id",
|
||||
@@ -71,7 +69,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_no_client_server_name",
|
||||
@@ -81,7 +78,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: client server name "" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_no_client_server_name_no_strict",
|
||||
@@ -90,7 +86,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: false,
|
||||
}, {
|
||||
name: "tls_client_id",
|
||||
@@ -99,7 +94,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "cli.example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_hostname_error",
|
||||
@@ -109,7 +103,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: client server name "cli.example.net" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_invalid_client_id",
|
||||
@@ -119,7 +112,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||
`invalid char '!' at index 0`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_too_long",
|
||||
@@ -131,7 +123,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
|
||||
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
|
||||
`label is too long, max: 63`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "quic_client_id",
|
||||
@@ -140,7 +131,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "cli.example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}}
|
||||
|
||||
@@ -150,6 +140,7 @@ func TestProcessClientID(t *testing.T) {
|
||||
ServerName: tc.hostSrvName,
|
||||
StrictSNICheck: tc.strictSNI,
|
||||
}
|
||||
|
||||
srv := &Server{
|
||||
conf: ServerConfig{TLSConfig: tlsConf},
|
||||
}
|
||||
@@ -168,79 +159,68 @@ func TestProcessClientID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
srv: srv,
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICSession: qs,
|
||||
},
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICSession: qs,
|
||||
}
|
||||
|
||||
res := processClientID(dctx)
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
||||
clientID, err := srv.clientIDFromDNSContext(pctx)
|
||||
assert.Equal(t, tc.wantClientID, clientID)
|
||||
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, dctx.err)
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, dctx.err)
|
||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessClientID_https(t *testing.T) {
|
||||
func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
path string
|
||||
wantClientID string
|
||||
wantErrMsg string
|
||||
wantRes resultCode
|
||||
}{{
|
||||
name: "no_client_id",
|
||||
path: "/dns-query",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "no_client_id_slash",
|
||||
path: "/dns-query/",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "client_id",
|
||||
path: "/dns-query/cli",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "client_id_slash",
|
||||
path: "/dns-query/cli/",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "bad_url",
|
||||
path: "/foo",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid path "/foo"`,
|
||||
wantRes: resultCodeError,
|
||||
}, {
|
||||
name: "extra",
|
||||
path: "/dns-query/cli/foo",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
|
||||
wantRes: resultCodeError,
|
||||
}, {
|
||||
name: "invalid_client_id",
|
||||
path: "/dns-query/!!!",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||
`invalid char '!' at index 0`,
|
||||
wantRes: resultCodeError,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -251,23 +231,20 @@ func TestProcessClientID_https(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Proto: proxy.ProtoHTTPS,
|
||||
HTTPRequest: r,
|
||||
},
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoHTTPS,
|
||||
HTTPRequest: r,
|
||||
}
|
||||
|
||||
res := processClientID(dctx)
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
||||
clientID, err := clientIDFromDNSContextHTTPS(pctx)
|
||||
assert.Equal(t, tc.wantClientID, clientID)
|
||||
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, dctx.err)
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, dctx.err)
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -331,7 +331,7 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
upstreams = aghstrings.FilterOut(upstreams, aghstrings.IsCommentOrEmpty)
|
||||
upstreamConfig, err := proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
&upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
},
|
||||
@@ -342,10 +342,10 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
|
||||
if len(upstreamConfig.Upstreams) == 0 {
|
||||
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
|
||||
var uc proxy.UpstreamConfig
|
||||
var uc *proxy.UpstreamConfig
|
||||
uc, err = proxy.ParseUpstreamsConfig(
|
||||
defaultDNS,
|
||||
upstream.Options{
|
||||
&upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
},
|
||||
@@ -356,7 +356,8 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
upstreamConfig.Upstreams = uc.Upstreams
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig = &upstreamConfig
|
||||
s.conf.UpstreamConfig = upstreamConfig
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
s.processInternalHosts,
|
||||
s.processRestrictLocal,
|
||||
s.processInternalIPAddrs,
|
||||
processClientID,
|
||||
s.processClientID,
|
||||
processFilteringBeforeRequest,
|
||||
s.processLocalPTR,
|
||||
s.processUpstream,
|
||||
@@ -165,7 +165,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) {
|
||||
s.tableHostToIP = t
|
||||
}
|
||||
|
||||
func (s *Server) setTableIPToHost(t ipToHostTable) {
|
||||
func (s *Server) setTableIPToHost(t *aghnet.IPMap) {
|
||||
s.tableIPToHostLock.Lock()
|
||||
defer s.tableIPToHostLock.Unlock()
|
||||
|
||||
@@ -188,13 +188,13 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
}
|
||||
|
||||
var hostToIP hostToIPTable
|
||||
var ipToHost ipToHostTable
|
||||
var ipToHost *aghnet.IPMap
|
||||
if add {
|
||||
hostToIP = make(hostToIPTable)
|
||||
ipToHost = make(ipToHostTable)
|
||||
|
||||
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
|
||||
hostToIP = make(hostToIPTable, len(ll))
|
||||
ipToHost = aghnet.NewIPMap(len(ll))
|
||||
|
||||
for _, l := range ll {
|
||||
// TODO(a.garipov): Remove this after we're finished
|
||||
// with the client hostname validations in the DHCP
|
||||
@@ -210,14 +210,14 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
|
||||
lowhost := strings.ToLower(l.Hostname)
|
||||
|
||||
ipToHost[l.IP.String()] = lowhost
|
||||
ipToHost.Set(l.IP, lowhost)
|
||||
|
||||
ip := make(net.IP, 4)
|
||||
copy(ip, l.IP.To4())
|
||||
hostToIP[lowhost] = ip
|
||||
}
|
||||
|
||||
log.Debug("dns: added %d A/PTR entries from DHCP", len(ipToHost))
|
||||
log.Debug("dns: added %d A/PTR entries from DHCP", ipToHost.Len())
|
||||
}
|
||||
|
||||
s.setTableHostToIP(hostToIP)
|
||||
@@ -377,7 +377,15 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
host, ok = s.tableIPToHost[ip.String()]
|
||||
var v interface{}
|
||||
v, ok = s.tableIPToHost.Get(ip)
|
||||
|
||||
var typOK bool
|
||||
if host, typOK = v.(string); !typOK {
|
||||
log.Error("dns: bad type %T in tableIPToHost for %s", v, ip)
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
return host, ok
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
@@ -26,6 +27,11 @@ import (
|
||||
// DefaultTimeout is the default upstream timeout
|
||||
const DefaultTimeout = 10 * time.Second
|
||||
|
||||
// defaultClientIDCacheCount is the default count of items in the LRU client ID
|
||||
// cache. The assumption here is that there won't be more than this many
|
||||
// requests between the BeforeRequestHandler stage and the actual processing.
|
||||
const defaultClientIDCacheCount = 1024
|
||||
|
||||
const (
|
||||
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
parentalBlockHost = "family-block.dns.adguard.com"
|
||||
@@ -44,12 +50,6 @@ var webRegistered bool
|
||||
// hostToIPTable is an alias for the type of Server.tableHostToIP.
|
||||
type hostToIPTable = map[string]net.IP
|
||||
|
||||
// ipToHostTable is an alias for the type of Server.tableIPToHost.
|
||||
//
|
||||
// TODO(a.garipov): Define an IPMap type in aghnet and use here and in other
|
||||
// places?
|
||||
type ipToHostTable = map[string]string
|
||||
|
||||
// Server is the main way to start a DNS server.
|
||||
//
|
||||
// Example:
|
||||
@@ -81,9 +81,13 @@ type Server struct {
|
||||
tableHostToIP hostToIPTable
|
||||
tableHostToIPLock sync.Mutex
|
||||
|
||||
tableIPToHost ipToHostTable
|
||||
tableIPToHost *aghnet.IPMap
|
||||
tableIPToHostLock sync.Mutex
|
||||
|
||||
// clientIDCache is a temporary storage for clientIDs that were
|
||||
// extracted during the BeforeRequestHandler stage.
|
||||
clientIDCache cache.Cache
|
||||
|
||||
// DNS proxy instance for internal usage
|
||||
// We don't Start() it and so no listen port is required.
|
||||
internalProxy *proxy.Proxy
|
||||
@@ -152,6 +156,10 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
subnetDetector: p.SubnetDetector,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||
clientIDCache: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: defaultClientIDCacheCount,
|
||||
}),
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Enable the refresher after the actual implementation
|
||||
@@ -414,19 +422,22 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||
|
||||
log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs)
|
||||
|
||||
var upsConfig proxy.UpstreamConfig
|
||||
upsConfig, err = proxy.ParseUpstreamsConfig(localAddrs, upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's ceritificates?
|
||||
})
|
||||
var upsConfig *proxy.UpstreamConfig
|
||||
upsConfig, err = proxy.ParseUpstreamsConfig(
|
||||
localAddrs,
|
||||
&upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's ceritificates?
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing upstreams: %w", err)
|
||||
}
|
||||
|
||||
s.localResolvers = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UpstreamConfig: &upsConfig,
|
||||
UpstreamConfig: upsConfig,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -577,11 +588,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// IsBlockedIP - return TRUE if this client should be blocked
|
||||
func (s *Server) IsBlockedIP(ip net.IP) (bool, string) {
|
||||
if ip == nil {
|
||||
return false, ""
|
||||
// IsBlockedClient returns true if the client is blocked by the current access
|
||||
// settings.
|
||||
func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
allowlistMode := s.access.allowlistMode()
|
||||
blockedByIP, rule := s.access.isBlockedIP(ip)
|
||||
blockedByClientID := s.access.isBlockedClientID(clientID)
|
||||
|
||||
// Allow if at least one of the checks allows in allowlist mode, but
|
||||
// block if at least one of the checks blocks in blocklist mode.
|
||||
if allowlistMode && blockedByIP && blockedByClientID {
|
||||
log.Debug("client %s (id %q) is not in access allowlist", ip, clientID)
|
||||
|
||||
// Return now without substituting the empty rule for the
|
||||
// clientID because the rule can't be empty here.
|
||||
return true, rule
|
||||
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
|
||||
log.Debug("client %s (id %q) is in access blocklist", ip, clientID)
|
||||
|
||||
blocked = true
|
||||
}
|
||||
|
||||
return s.access.IsBlockedIP(ip)
|
||||
if rule == "" {
|
||||
rule = clientID
|
||||
}
|
||||
|
||||
return blocked, rule
|
||||
}
|
||||
|
||||
@@ -257,19 +257,22 @@ func TestServer(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
net string
|
||||
proto proxy.Proto
|
||||
}{{
|
||||
name: "message_over_udp",
|
||||
net: "",
|
||||
proto: proxy.ProtoUDP,
|
||||
}, {
|
||||
name: "message_over_tcp",
|
||||
net: "tcp",
|
||||
proto: proxy.ProtoTCP,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
addr := s.dnsProxy.Addr(tc.proto)
|
||||
client := dns.Client{Net: tc.proto}
|
||||
client := dns.Client{Net: tc.net}
|
||||
|
||||
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
|
||||
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||
@@ -324,7 +327,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
// Message over UDP.
|
||||
req := createGoogleATestMessage()
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
client := dns.Client{Net: proxy.ProtoUDP}
|
||||
client := &dns.Client{}
|
||||
|
||||
reply, _, err := client.Exchange(req, addr.String())
|
||||
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||
@@ -376,7 +379,7 @@ func TestDoQServer(t *testing.T) {
|
||||
|
||||
// Create a DNS-over-QUIC upstream.
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
|
||||
opts := upstream.Options{InsecureSkipVerify: true}
|
||||
opts := &upstream.Options{InsecureSkipVerify: true}
|
||||
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -420,7 +423,7 @@ func TestServerRace(t *testing.T) {
|
||||
|
||||
// Message over UDP.
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
conn, err := dns.Dial(proxy.ProtoUDP, addr.String())
|
||||
conn, err := dns.Dial("udp", addr.String())
|
||||
require.NoErrorf(t, err, "cannot connect to the proxy: %s", err)
|
||||
|
||||
sendTestMessagesAsync(t, conn)
|
||||
@@ -445,7 +448,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
client := dns.Client{Net: proxy.ProtoUDP}
|
||||
client := &dns.Client{}
|
||||
|
||||
yandexIP := net.IP{213, 180, 193, 56}
|
||||
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||
@@ -507,7 +510,6 @@ func TestInvalidRequest(t *testing.T) {
|
||||
|
||||
// Send a DNS request without question.
|
||||
_, _, err := (&dns.Client{
|
||||
Net: proxy.ProtoUDP,
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}).Exchange(&req, addr)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -11,23 +12,39 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
||||
ip := aghnet.IPFromAddr(d.Addr)
|
||||
disallowed, _ := s.access.IsBlockedIP(ip)
|
||||
if disallowed {
|
||||
log.Tracef("Client IP %s is blocked by settings", ip)
|
||||
// beforeRequestHandler is the handler that is called before any other
|
||||
// processing, including logs. It performs access checks and puts the client
|
||||
// ID, if there is one, into the server's cache.
|
||||
func (s *Server) beforeRequestHandler(
|
||||
_ *proxy.Proxy,
|
||||
pctx *proxy.DNSContext,
|
||||
) (reply bool, err error) {
|
||||
ip := aghnet.IPFromAddr(pctx.Addr)
|
||||
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("getting clientid: %w", err)
|
||||
}
|
||||
|
||||
blocked, _ := s.IsBlockedClient(ip, clientID)
|
||||
if blocked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if len(d.Req.Question) == 1 {
|
||||
host := strings.TrimSuffix(d.Req.Question[0].Name, ".")
|
||||
if s.access.IsBlockedDomain(host) {
|
||||
log.Tracef("domain %s is blocked by access settings", host)
|
||||
if len(pctx.Req.Question) == 1 {
|
||||
host := strings.TrimSuffix(pctx.Req.Question[0].Name, ".")
|
||||
if s.access.isBlockedHost(host) {
|
||||
log.Debug("host %s is in access blocklist", host)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if clientID != "" {
|
||||
key := [8]byte{}
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
s.clientIDCache.Set(key[:], []byte(clientID))
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (string, error) {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: empty")
|
||||
}
|
||||
|
||||
if _, err := upstream.NewResolver(boot, upstream.Options{Timeout: 0}); err != nil {
|
||||
if _, err := upstream.NewResolver(boot, nil); err != nil {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -348,7 +348,7 @@ func ValidateUpstreams(upstreams []string) (err error) {
|
||||
|
||||
_, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
&upstream.Options{
|
||||
Bootstrap: []string{},
|
||||
Timeout: DefaultTimeout,
|
||||
},
|
||||
@@ -546,7 +546,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
|
||||
|
||||
log.Debug("checking if dns server %q works...", input)
|
||||
var u upstream.Upstream
|
||||
u, err = upstream.AddressToUpstream(input, upstream.Options{
|
||||
u, err = upstream.AddressToUpstream(input, &upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: timeout,
|
||||
})
|
||||
|
||||
@@ -46,7 +46,7 @@ func (l *testStats) Update(e stats.Entry) {
|
||||
func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
proto proxy.Proto
|
||||
addr net.Addr
|
||||
clientID string
|
||||
wantLogProto querylog.ClientProto
|
||||
@@ -156,7 +156,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RParental,
|
||||
}}
|
||||
|
||||
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
|
||||
ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
Reference in New Issue
Block a user