Pull request: 5035-netip-maps-access
Updates #5035. Squashed commit of the following: commit 0c9f80761419dc50d89e0e82f68cdb462569417d Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Mon Oct 24 16:11:03 2022 +0300 dnsforward: fix access check commit df981acb4816cfba11bf6bbe4ef7796a6e365ea9 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Mon Oct 24 15:27:45 2022 +0300 dnsforward: mv access to netip.Addr
This commit is contained in:
@@ -3,25 +3,26 @@ package dnsforward
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
)
|
||||
|
||||
// accessCtx controls IP and client blocking that takes place before all other
|
||||
// processing. An accessCtx is safe for concurrent use.
|
||||
type accessCtx struct {
|
||||
// TODO(e.burkov): Use map[netip.Addr]struct{} instead.
|
||||
allowedIPs *netutil.IPMap
|
||||
blockedIPs *netutil.IPMap
|
||||
// unit is a convenient alias for struct{}
|
||||
type unit = struct{}
|
||||
|
||||
// accessManager controls IP and client blocking that takes place before all
|
||||
// other processing. An accessManager is safe for concurrent use.
|
||||
type accessManager struct {
|
||||
allowedIPs map[netip.Addr]unit
|
||||
blockedIPs map[netip.Addr]unit
|
||||
|
||||
allowedClientIDs *stringutil.Set
|
||||
blockedClientIDs *stringutil.Set
|
||||
@@ -29,36 +30,29 @@ type accessCtx struct {
|
||||
blockedHostsEng *urlfilter.DNSEngine
|
||||
|
||||
// TODO(a.garipov): Create a type for a set of IP networks.
|
||||
// netutil.IPNetSet?
|
||||
allowedNets []*net.IPNet
|
||||
blockedNets []*net.IPNet
|
||||
allowedNets []netip.Prefix
|
||||
blockedNets []netip.Prefix
|
||||
}
|
||||
|
||||
// unit is a convenient alias for struct{}
|
||||
type unit = struct{}
|
||||
|
||||
// processAccessClients is a helper for processing a list of client strings,
|
||||
// which may be an IP address, a CIDR, or a ClientID.
|
||||
func processAccessClients(
|
||||
clientStrs []string,
|
||||
ips *netutil.IPMap,
|
||||
nets *[]*net.IPNet,
|
||||
ips map[netip.Addr]unit,
|
||||
nets *[]netip.Prefix,
|
||||
clientIDs *stringutil.Set,
|
||||
) (err error) {
|
||||
for i, s := range clientStrs {
|
||||
if ip := net.ParseIP(s); ip != nil {
|
||||
ips.Set(ip, unit{})
|
||||
} else if cidrIP, ipnet, cidrErr := net.ParseCIDR(s); cidrErr == nil {
|
||||
ipnet.IP = cidrIP
|
||||
var ip netip.Addr
|
||||
var ipnet netip.Prefix
|
||||
if ip, err = netip.ParseAddr(s); err == nil {
|
||||
ips[ip] = unit{}
|
||||
} else if ipnet, err = netip.ParsePrefix(s); err == nil {
|
||||
*nets = append(*nets, ipnet)
|
||||
} else {
|
||||
idErr := ValidateClientID(s)
|
||||
if idErr != nil {
|
||||
return fmt.Errorf(
|
||||
"value %q at index %d: bad ip, cidr, or clientid",
|
||||
s,
|
||||
i,
|
||||
)
|
||||
err = ValidateClientID(s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("value %q at index %d: bad ip, cidr, or clientid", s, i)
|
||||
}
|
||||
|
||||
clientIDs.Add(s)
|
||||
@@ -69,10 +63,10 @@ func processAccessClients(
|
||||
}
|
||||
|
||||
// newAccessCtx creates a new accessCtx.
|
||||
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) {
|
||||
a = &accessCtx{
|
||||
allowedIPs: netutil.NewIPMap(0),
|
||||
blockedIPs: netutil.NewIPMap(0),
|
||||
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessManager, err error) {
|
||||
a = &accessManager{
|
||||
allowedIPs: map[netip.Addr]unit{},
|
||||
blockedIPs: map[netip.Addr]unit{},
|
||||
|
||||
allowedClientIDs: stringutil.NewSet(),
|
||||
blockedClientIDs: stringutil.NewSet(),
|
||||
@@ -112,12 +106,12 @@ func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err er
|
||||
}
|
||||
|
||||
// allowlistMode returns true if this *accessCtx is in the allowlist mode.
|
||||
func (a *accessCtx) allowlistMode() (ok bool) {
|
||||
return a.allowedIPs.Len() != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
|
||||
func (a *accessManager) allowlistMode() (ok bool) {
|
||||
return len(a.allowedIPs) != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
|
||||
}
|
||||
|
||||
// isBlockedClientID returns true if the ClientID should be blocked.
|
||||
func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
|
||||
func (a *accessManager) isBlockedClientID(id string) (ok bool) {
|
||||
allowlistMode := a.allowlistMode()
|
||||
if id == "" {
|
||||
// In allowlist mode, consider requests without ClientIDs blocked by
|
||||
@@ -133,7 +127,7 @@ func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
|
||||
}
|
||||
|
||||
// isBlockedHost returns true if host should be blocked.
|
||||
func (a *accessCtx) isBlockedHost(host string) (ok bool) {
|
||||
func (a *accessManager) isBlockedHost(host string) (ok bool) {
|
||||
_, ok = a.blockedHostsEng.Match(strings.ToLower(host))
|
||||
|
||||
return ok
|
||||
@@ -141,7 +135,7 @@ func (a *accessCtx) isBlockedHost(host string) (ok bool) {
|
||||
|
||||
// isBlockedIP returns the status of the IP address blocking as well as the rule
|
||||
// that blocked it.
|
||||
func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) {
|
||||
func (a *accessManager) isBlockedIP(ip netip.Addr) (blocked bool, rule string) {
|
||||
blocked = true
|
||||
ips := a.blockedIPs
|
||||
ipnets := a.blockedNets
|
||||
@@ -153,7 +147,7 @@ func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) {
|
||||
ipnets = a.allowedNets
|
||||
}
|
||||
|
||||
if _, ok := ips.Get(ip); ok {
|
||||
if _, ok := ips[ip]; ok {
|
||||
return blocked, ip.String()
|
||||
}
|
||||
|
||||
@@ -241,7 +235,7 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var a *accessCtx
|
||||
var a *accessManager
|
||||
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -95,27 +95,27 @@ func TestIsBlockedIP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantRule string
|
||||
ip net.IP
|
||||
ip netip.Addr
|
||||
wantBlocked bool
|
||||
}{{
|
||||
name: "match_ip",
|
||||
wantRule: "1.2.3.4",
|
||||
ip: net.IP{1, 2, 3, 4},
|
||||
ip: netip.MustParseAddr("1.2.3.4"),
|
||||
wantBlocked: true,
|
||||
}, {
|
||||
name: "match_cidr",
|
||||
wantRule: "5.6.7.8/24",
|
||||
ip: net.IP{5, 6, 7, 100},
|
||||
ip: netip.MustParseAddr("5.6.7.100"),
|
||||
wantBlocked: true,
|
||||
}, {
|
||||
name: "no_match_ip",
|
||||
wantRule: "",
|
||||
ip: net.IP{9, 2, 3, 4},
|
||||
ip: netip.MustParseAddr("9.2.3.4"),
|
||||
wantBlocked: false,
|
||||
}, {
|
||||
name: "no_match_cidr",
|
||||
wantRule: "",
|
||||
ip: net.IP{9, 6, 7, 100},
|
||||
ip: netip.MustParseAddr("9.6.7.100"),
|
||||
wantBlocked: false,
|
||||
}}
|
||||
|
||||
|
||||
@@ -96,9 +96,16 @@ type FilteringConfig struct {
|
||||
// Access settings
|
||||
// --
|
||||
|
||||
AllowedClients []string `yaml:"allowed_clients"` // IP addresses of whitelist clients
|
||||
DisallowedClients []string `yaml:"disallowed_clients"` // IP addresses of clients that should be blocked
|
||||
BlockedHosts []string `yaml:"blocked_hosts"` // hosts that should be blocked
|
||||
// AllowedClients is the slice of IP addresses, CIDR networks, and ClientIDs
|
||||
// of allowed clients. If not empty, only these clients are allowed, and
|
||||
// [FilteringConfig.DisallowedClients] are ignored.
|
||||
AllowedClients []string `yaml:"allowed_clients"`
|
||||
|
||||
// DisallowedClients is the slice of IP addresses, CIDR networks, and
|
||||
// ClientIDs of disallowed clients.
|
||||
DisallowedClients []string `yaml:"disallowed_clients"`
|
||||
|
||||
BlockedHosts []string `yaml:"blocked_hosts"` // hosts that should be blocked
|
||||
// TrustedProxies is the list of IP addresses and CIDR networks to detect
|
||||
// proxy servers addresses the DoH requests from which should be handled.
|
||||
// The value of nil or an empty slice for this field makes Proxy not trust
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
//lint:file-ignore SA1019 TODO(a.garipov): Replace [*netutil.IPMap].
|
||||
|
||||
// To transfer information between modules
|
||||
type dnsContext struct {
|
||||
proxyCtx *proxy.DNSContext
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -25,6 +26,8 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
//lint:file-ignore SA1019 TODO(a.garipov): Replace [*netutil.IPMap].
|
||||
|
||||
// DefaultTimeout is the default upstream timeout
|
||||
const DefaultTimeout = 10 * time.Second
|
||||
|
||||
@@ -63,7 +66,7 @@ type Server struct {
|
||||
dhcpServer dhcpd.Interface // DHCP server instance (optional)
|
||||
queryLog querylog.QueryLog // Query log instance
|
||||
stats stats.Interface
|
||||
access *accessCtx
|
||||
access *accessManager
|
||||
|
||||
// localDomainSuffix is the suffix used to detect internal hosts. It
|
||||
// must be a valid domain name plus dots on each side.
|
||||
@@ -673,27 +676,37 @@ func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
blockedByIP := false
|
||||
if ip != nil {
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
||||
if err != nil {
|
||||
log.Error("dnsforward: bad client ip %v: %s", ip, err)
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
blockedByIP, rule = s.access.isBlockedIP(ipAddr)
|
||||
}
|
||||
|
||||
allowlistMode := s.access.allowlistMode()
|
||||
blockedByIP, rule := s.access.isBlockedIP(ip)
|
||||
blockedByClientID := s.access.isBlockedClientID(clientID)
|
||||
|
||||
// Allow if at least one of the checks allows in allowlist mode, but
|
||||
// block if at least one of the checks blocks in blocklist mode.
|
||||
// Allow if at least one of the checks allows in allowlist mode, but block
|
||||
// if at least one of the checks blocks in blocklist mode.
|
||||
if allowlistMode && blockedByIP && blockedByClientID {
|
||||
log.Debug("client %s (id %q) is not in access allowlist", ip, clientID)
|
||||
log.Debug("client %v (id %q) is not in access allowlist", ip, clientID)
|
||||
|
||||
// Return now without substituting the empty rule for the
|
||||
// clientID because the rule can't be empty here.
|
||||
return true, rule
|
||||
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
|
||||
log.Debug("client %s (id %q) is in access blocklist", ip, clientID)
|
||||
log.Debug("client %v (id %q) is in access blocklist", ip, clientID)
|
||||
|
||||
blocked = true
|
||||
}
|
||||
|
||||
if rule == "" {
|
||||
rule = clientID
|
||||
}
|
||||
rule = aghalg.Coalesce(rule, clientID)
|
||||
|
||||
return blocked, rule
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user