Pull request: all: allow clientid in access settings
Updates #2624. Updates #3162. Squashed commit of the following: commit 68860da717a23a0bfeba14b7fe10b5e4ad38726d Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:41:33 2021 +0300 all: imp types, names commit ebd4ec26636853d0d58c4e331e6a78feede20813 Merge: 239eb72116e5e09cAuthor: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:14:33 2021 +0300 Merge branch 'master' into 2624-clientid-access commit 239eb7215abc47e99a0300a0f4cf56002689b1a9 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:13:10 2021 +0300 all: fix client blocking check commit e6bece3ea8367b3cbe3d90702a3368c870ad4f13 Merge: 9935f2a39d1656b5Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 13:12:28 2021 +0300 Merge branch 'master' into 2624-clientid-access commit 9935f2a30bcfae2b853f3ef610c0ab7a56a8f448 Author: Ildar Kamalov <ik@adguard.com> Date: Tue Jun 29 11:26:51 2021 +0300 client: show block button for client id commit ed786a6a74a081cd89e9d67df3537a4fadd54831 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 15:56:23 2021 +0300 client: imp i18n commit 4fed21c68473ad408960c08a7d87624cabce1911 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 15:34:09 2021 +0300 all: imp i18n, docs commit 55e65c0d6b939560c53dcb834a4557eb3853d194 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 13:34:01 2021 +0300 all: fix cache, imp code, docs, tests commit c1e5a83e76deb44b1f92729bb9ddfcc6a96ac4a8 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu Jun 24 19:27:12 2021 +0300 all: allow clientid in access settings
This commit is contained in:
@@ -27,10 +27,9 @@ type EtcHostsContainer struct {
|
||||
lock sync.RWMutex
|
||||
// table is the host-to-IPs map.
|
||||
table map[string][]net.IP
|
||||
// tableReverse is the IP-to-hosts map.
|
||||
//
|
||||
// TODO(a.garipov): Make better use of newtypes. Perhaps a custom map.
|
||||
tableReverse map[string][]string
|
||||
// tableReverse is the IP-to-hosts map. The type of the values in the
|
||||
// map is []string.
|
||||
tableReverse *IPMap
|
||||
|
||||
hostsFn string // path to the main hosts-file
|
||||
hostsDirs []string // paths to OS-specific directories with hosts-files
|
||||
@@ -80,7 +79,7 @@ func (ehc *EtcHostsContainer) Init(hostsFn string) {
|
||||
var err error
|
||||
ehc.watcher, err = fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Error("etchostscontainer: %s", err)
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +140,7 @@ func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP {
|
||||
copy(ipsCopy, ips)
|
||||
}
|
||||
|
||||
log.Debug("etchostscontainer: answer: %s -> %v", host, ipsCopy)
|
||||
log.Debug("etchosts: answer: %s -> %v", host, ipsCopy)
|
||||
return ipsCopy
|
||||
}
|
||||
|
||||
@@ -151,38 +150,40 @@ func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts [
|
||||
return nil
|
||||
}
|
||||
|
||||
ipReal := UnreverseAddr(addr)
|
||||
if ipReal == nil {
|
||||
ip := UnreverseAddr(addr)
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipStr := ipReal.String()
|
||||
|
||||
ehc.lock.RLock()
|
||||
defer ehc.lock.RUnlock()
|
||||
|
||||
hosts = ehc.tableReverse[ipStr]
|
||||
|
||||
if len(hosts) == 0 {
|
||||
return nil // not found
|
||||
v, ok := ehc.tableReverse.Get(ip)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("etchostscontainer: reverse-lookup: %s -> %s", addr, hosts)
|
||||
hosts, ok = v.([]string)
|
||||
if !ok {
|
||||
log.Error("etchosts: bad type %T in tableReverse for %s", v, ip)
|
||||
|
||||
return nil
|
||||
} else if len(hosts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts)
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
// List returns an IP-to-hostnames table. It is safe for concurrent use.
|
||||
func (ehc *EtcHostsContainer) List() (ipToHosts map[string][]string) {
|
||||
// List returns an IP-to-hostnames table. The type of the values in the map is
|
||||
// []string. It is safe for concurrent use.
|
||||
func (ehc *EtcHostsContainer) List() (ipToHosts *IPMap) {
|
||||
ehc.lock.RLock()
|
||||
defer ehc.lock.RUnlock()
|
||||
|
||||
ipToHosts = make(map[string][]string, len(ehc.tableReverse))
|
||||
for k, v := range ehc.tableReverse {
|
||||
ipToHosts[k] = v
|
||||
}
|
||||
|
||||
return ipToHosts
|
||||
return ehc.tableReverse.ShallowClone()
|
||||
}
|
||||
|
||||
// update table
|
||||
@@ -205,29 +206,31 @@ func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string
|
||||
ok = true
|
||||
}
|
||||
if ok {
|
||||
log.Debug("etchostscontainer: added %s -> %s", ipAddr, host)
|
||||
log.Debug("etchosts: added %s -> %s", ipAddr, host)
|
||||
}
|
||||
}
|
||||
|
||||
// updateTableRev updates the reverse address table.
|
||||
func (ehc *EtcHostsContainer) updateTableRev(tableRev map[string][]string, newHost string, ipAddr net.IP) {
|
||||
ipStr := ipAddr.String()
|
||||
hosts, ok := tableRev[ipStr]
|
||||
func (ehc *EtcHostsContainer) updateTableRev(tableRev *IPMap, newHost string, ip net.IP) {
|
||||
v, ok := tableRev.Get(ip)
|
||||
if !ok {
|
||||
tableRev[ipStr] = []string{newHost}
|
||||
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost)
|
||||
tableRev.Set(ip, []string{newHost})
|
||||
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hosts, _ := v.([]string)
|
||||
for _, host := range hosts {
|
||||
if host == newHost {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tableRev[ipStr] = append(tableRev[ipStr], newHost)
|
||||
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost)
|
||||
hosts = append(hosts, newHost)
|
||||
tableRev.Set(ip, hosts)
|
||||
|
||||
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
||||
}
|
||||
|
||||
// parseHostsLine parses hosts from the fields.
|
||||
@@ -255,12 +258,12 @@ func parseHostsLine(fields []string) (hosts []string) {
|
||||
// line for one IP are supported.
|
||||
func (ehc *EtcHostsContainer) load(
|
||||
table map[string][]net.IP,
|
||||
tableRev map[string][]string,
|
||||
tableRev *IPMap,
|
||||
fn string,
|
||||
) {
|
||||
f, err := os.Open(fn)
|
||||
if err != nil {
|
||||
log.Error("etchostscontainer: %s", err)
|
||||
log.Error("etchosts: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -268,11 +271,11 @@ func (ehc *EtcHostsContainer) load(
|
||||
defer func() {
|
||||
derr := f.Close()
|
||||
if derr != nil {
|
||||
log.Error("etchostscontainer: closing file: %s", err)
|
||||
log.Error("etchosts: closing file: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debug("etchostscontainer: loading hosts from file %s", fn)
|
||||
log.Debug("etchosts: loading hosts from file %s", fn)
|
||||
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
@@ -296,7 +299,7 @@ func (ehc *EtcHostsContainer) load(
|
||||
|
||||
err = s.Err()
|
||||
if err != nil {
|
||||
log.Error("etchostscontainer: %s", err)
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -334,7 +337,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
|
||||
}
|
||||
|
||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||
log.Debug("etchostscontainer: modified: %s", event.Name)
|
||||
log.Debug("etchosts: modified: %s", event.Name)
|
||||
ehc.updateHosts()
|
||||
}
|
||||
|
||||
@@ -342,7 +345,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Error("etchostscontainer: %s", err)
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -350,7 +353,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
|
||||
// updateHosts - loads system hosts
|
||||
func (ehc *EtcHostsContainer) updateHosts() {
|
||||
table := make(map[string][]net.IP)
|
||||
tableRev := make(map[string][]string)
|
||||
tableRev := NewIPMap(0)
|
||||
|
||||
ehc.load(table, tableRev, ehc.hostsFn)
|
||||
|
||||
@@ -358,7 +361,7 @@ func (ehc *EtcHostsContainer) updateHosts() {
|
||||
des, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.Error("etchostscontainer: Opening directory: %q: %s", dir, err)
|
||||
log.Error("etchosts: Opening directory: %q: %s", dir, err)
|
||||
}
|
||||
|
||||
continue
|
||||
|
||||
@@ -70,7 +70,7 @@ func TestEtcHostsContainerResolution(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("hosts_file", func(t *testing.T) {
|
||||
names, ok := ehc.List()["127.0.0.1"]
|
||||
names, ok := ehc.List().Get(net.IP{127, 0, 0, 1})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, []string{"host", "localhost"}, names)
|
||||
})
|
||||
|
||||
112
internal/aghnet/ipmap.go
Normal file
112
internal/aghnet/ipmap.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// ipArr is a representation of an IP address as an array of bytes.
|
||||
type ipArr [16]byte
|
||||
|
||||
// String implements the fmt.Stringer interface for ipArr.
|
||||
func (a ipArr) String() (s string) {
|
||||
return net.IP(a[:]).String()
|
||||
}
|
||||
|
||||
// IPMap is a map of IP addresses.
|
||||
type IPMap struct {
|
||||
m map[ipArr]interface{}
|
||||
}
|
||||
|
||||
// NewIPMap returns a new empty IP map using hint as a size hint for the
|
||||
// underlying map.
|
||||
func NewIPMap(hint int) (m *IPMap) {
|
||||
return &IPMap{
|
||||
m: make(map[ipArr]interface{}, hint),
|
||||
}
|
||||
}
|
||||
|
||||
// ipToArr converts a net.IP into an ipArr.
|
||||
//
|
||||
// TODO(a.garipov): Use the slice-to-array conversion in Go 1.17.
|
||||
func ipToArr(ip net.IP) (a ipArr) {
|
||||
copy(a[:], ip.To16())
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// Del deletes ip from the map. Calling Del on a nil *IPMap has no effect, just
|
||||
// like delete on an empty map doesn't.
|
||||
func (m *IPMap) Del(ip net.IP) {
|
||||
if m != nil {
|
||||
delete(m.m, ipToArr(ip))
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the value from the map. Calling Get on a nil *IPMap returns nil
|
||||
// and false, just like indexing on an empty map does.
|
||||
func (m *IPMap) Get(ip net.IP) (v interface{}, ok bool) {
|
||||
if m != nil {
|
||||
v, ok = m.m[ipToArr(ip)]
|
||||
|
||||
return v, ok
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Len returns the length of the map. A nil *IPMap has a length of zero, just
|
||||
// like an empty map.
|
||||
func (m *IPMap) Len() (n int) {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(m.m)
|
||||
}
|
||||
|
||||
// Range calls f for each key and value present in the map in an undefined
|
||||
// order. If cont is false, range stops the iteration. Calling Range on a nil
|
||||
// *IPMap has no effect, just like ranging over a nil map.
|
||||
func (m *IPMap) Range(f func(ip net.IP, v interface{}) (cont bool)) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range m.m {
|
||||
if !f(net.IP(k[:]), v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the value. Set panics if the m is a nil *IPMap, just like a nil map
|
||||
// does.
|
||||
func (m *IPMap) Set(ip net.IP, v interface{}) {
|
||||
m.m[ipToArr(ip)] = v
|
||||
}
|
||||
|
||||
// ShallowClone returns a shallow clone of the map.
|
||||
func (m *IPMap) ShallowClone() (sclone *IPMap) {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sclone = NewIPMap(m.Len())
|
||||
m.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
sclone.Set(ip, v)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return sclone
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer interface for *IPMap.
|
||||
func (m *IPMap) String() (s string) {
|
||||
if m == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
|
||||
return fmt.Sprint(m.m)
|
||||
}
|
||||
142
internal/aghnet/ipmap_test.go
Normal file
142
internal/aghnet/ipmap_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIPMap_allocs(t *testing.T) {
|
||||
ip4 := net.IP{1, 2, 3, 4}
|
||||
m := NewIPMap(0)
|
||||
m.Set(ip4, 42)
|
||||
|
||||
t.Run("get", func(t *testing.T) {
|
||||
var v interface{}
|
||||
var ok bool
|
||||
allocs := testing.AllocsPerRun(100, func() {
|
||||
v, ok = m.Get(ip4)
|
||||
})
|
||||
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 42, v)
|
||||
|
||||
assert.Equal(t, float64(0), allocs)
|
||||
})
|
||||
|
||||
t.Run("len", func(t *testing.T) {
|
||||
var n int
|
||||
allocs := testing.AllocsPerRun(100, func() {
|
||||
n = m.Len()
|
||||
})
|
||||
|
||||
require.Equal(t, 1, n)
|
||||
|
||||
assert.Equal(t, float64(0), allocs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIPMap(t *testing.T) {
|
||||
ip4 := net.IP{1, 2, 3, 4}
|
||||
ip6 := net.IP{
|
||||
0x12, 0x34, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x56, 0x78,
|
||||
}
|
||||
|
||||
val := 42
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
var m *IPMap
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
m.Del(ip4)
|
||||
m.Del(ip6)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
v, ok := m.Get(ip4)
|
||||
assert.Nil(t, v)
|
||||
assert.False(t, ok)
|
||||
|
||||
v, ok = m.Get(ip6)
|
||||
assert.Nil(t, v)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
assert.Equal(t, 0, m.Len())
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
n := 0
|
||||
m.Range(func(_ net.IP, _ interface{}) (cont bool) {
|
||||
n++
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, 0, n)
|
||||
})
|
||||
|
||||
assert.Panics(t, func() {
|
||||
m.Set(ip4, val)
|
||||
})
|
||||
|
||||
assert.Panics(t, func() {
|
||||
m.Set(ip6, val)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
sclone := m.ShallowClone()
|
||||
assert.Nil(t, sclone)
|
||||
})
|
||||
})
|
||||
|
||||
testIPMap := func(t *testing.T, ip net.IP, s string) {
|
||||
m := NewIPMap(0)
|
||||
assert.Equal(t, 0, m.Len())
|
||||
|
||||
v, ok := m.Get(ip)
|
||||
assert.Nil(t, v)
|
||||
assert.False(t, ok)
|
||||
|
||||
m.Set(ip, val)
|
||||
v, ok = m.Get(ip)
|
||||
assert.Equal(t, val, v)
|
||||
assert.True(t, ok)
|
||||
|
||||
n := 0
|
||||
m.Range(func(ipKey net.IP, v interface{}) (cont bool) {
|
||||
assert.Equal(t, ip.To16(), ipKey)
|
||||
assert.Equal(t, val, v)
|
||||
|
||||
n++
|
||||
|
||||
return false
|
||||
})
|
||||
assert.Equal(t, 1, n)
|
||||
|
||||
sclone := m.ShallowClone()
|
||||
assert.Equal(t, m, sclone)
|
||||
|
||||
assert.Equal(t, s, m.String())
|
||||
|
||||
m.Del(ip)
|
||||
v, ok = m.Get(ip)
|
||||
assert.Nil(t, v)
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, 0, m.Len())
|
||||
}
|
||||
|
||||
t.Run("ipv4", func(t *testing.T) {
|
||||
testIPMap(t, ip4, "map[1.2.3.4:42]")
|
||||
})
|
||||
|
||||
t.Run("ipv6", func(t *testing.T) {
|
||||
testIPMap(t, ip6, "map[1234::5678:42]")
|
||||
})
|
||||
}
|
||||
@@ -6,138 +6,163 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
)
|
||||
|
||||
// accessCtx controls IP and client blocking that takes place before all other
|
||||
// processing. An accessCtx is safe for concurrent use.
|
||||
type accessCtx struct {
|
||||
lock sync.Mutex
|
||||
allowedIPs *aghnet.IPMap
|
||||
blockedIPs *aghnet.IPMap
|
||||
|
||||
// allowedClients are the IP addresses of clients in the allowlist.
|
||||
allowedClients *aghstrings.Set
|
||||
allowedClientIDs *aghstrings.Set
|
||||
blockedClientIDs *aghstrings.Set
|
||||
|
||||
// disallowedClients are the IP addresses of clients in the blocklist.
|
||||
disallowedClients *aghstrings.Set
|
||||
blockedHostsEng *urlfilter.DNSEngine
|
||||
|
||||
allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients
|
||||
disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked
|
||||
|
||||
blockedHostsEngine *urlfilter.DNSEngine // finds hosts that should be blocked
|
||||
// TODO(a.garipov): Create a type for a set of IP networks.
|
||||
// aghnet.IPNetSet?
|
||||
allowedNets []*net.IPNet
|
||||
blockedNets []*net.IPNet
|
||||
}
|
||||
|
||||
func newAccessCtx(allowedClients, disallowedClients, blockedHosts []string) (a *accessCtx, err error) {
|
||||
a = &accessCtx{
|
||||
allowedClients: aghstrings.NewSet(),
|
||||
disallowedClients: aghstrings.NewSet(),
|
||||
}
|
||||
// unit is a convenient alias for struct{}
|
||||
type unit = struct{}
|
||||
|
||||
err = processIPCIDRArray(a.allowedClients, &a.allowedClientsIPNet, allowedClients)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("processing allowed clients: %w", err)
|
||||
}
|
||||
// processAccessClients is a helper for processing a list of client strings,
|
||||
// which may be an IP address, a CIDR, or a ClientID.
|
||||
func processAccessClients(
|
||||
clientStrs []string,
|
||||
ips *aghnet.IPMap,
|
||||
nets *[]*net.IPNet,
|
||||
clientIDs *aghstrings.Set,
|
||||
) (err error) {
|
||||
for i, s := range clientStrs {
|
||||
if ip := net.ParseIP(s); ip != nil {
|
||||
ips.Set(ip, unit{})
|
||||
} else if cidrIP, ipnet, cidrErr := net.ParseCIDR(s); cidrErr == nil {
|
||||
ipnet.IP = cidrIP
|
||||
*nets = append(*nets, ipnet)
|
||||
} else {
|
||||
idErr := ValidateClientID(s)
|
||||
if idErr != nil {
|
||||
return fmt.Errorf(
|
||||
"value %q at index %d: bad ip, cidr, or clientid",
|
||||
s,
|
||||
i,
|
||||
)
|
||||
}
|
||||
|
||||
err = processIPCIDRArray(a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("processing disallowed clients: %w", err)
|
||||
}
|
||||
|
||||
b := &strings.Builder{}
|
||||
for _, s := range blockedHosts {
|
||||
aghstrings.WriteToBuilder(b, strings.ToLower(s), "\n")
|
||||
}
|
||||
|
||||
listArray := []filterlist.RuleList{}
|
||||
list := &filterlist.StringRuleList{
|
||||
ID: int(0),
|
||||
RulesText: b.String(),
|
||||
IgnoreCosmetic: true,
|
||||
}
|
||||
listArray = append(listArray, list)
|
||||
rulesStorage, err := filterlist.NewRuleStorage(listArray)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err)
|
||||
}
|
||||
a.blockedHostsEngine = urlfilter.NewDNSEngine(rulesStorage)
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Split array of IP or CIDR into 2 containers for fast search
|
||||
func processIPCIDRArray(dst *aghstrings.Set, dstIPNet *[]net.IPNet, src []string) error {
|
||||
for _, s := range src {
|
||||
ip := net.ParseIP(s)
|
||||
if ip != nil {
|
||||
dst.Add(s)
|
||||
|
||||
continue
|
||||
clientIDs.Add(s)
|
||||
}
|
||||
|
||||
_, ipnet, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*dstIPNet = append(*dstIPNet, *ipnet)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsBlockedIP - return TRUE if this client should be blocked
|
||||
// Returns the item from the "disallowedClients" list that lead to blocking IP.
|
||||
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty,
|
||||
// but the ip does not belong to it.
|
||||
func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) {
|
||||
ipStr := ip.String()
|
||||
// newAccessCtx creates a new accessCtx.
|
||||
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) {
|
||||
a = &accessCtx{
|
||||
allowedIPs: aghnet.NewIPMap(0),
|
||||
blockedIPs: aghnet.NewIPMap(0),
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
if a.allowedClients.Len() != 0 || len(a.allowedClientsIPNet) != 0 {
|
||||
if a.allowedClients.Has(ipStr) {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if len(a.allowedClientsIPNet) != 0 {
|
||||
for _, ipnet := range a.allowedClientsIPNet {
|
||||
if ipnet.Contains(ip) {
|
||||
return false, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, ""
|
||||
allowedClientIDs: aghstrings.NewSet(),
|
||||
blockedClientIDs: aghstrings.NewSet(),
|
||||
}
|
||||
|
||||
if a.disallowedClients.Has(ipStr) {
|
||||
return true, ipStr
|
||||
err = processAccessClients(allowed, a.allowedIPs, &a.allowedNets, a.allowedClientIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding allowed: %w", err)
|
||||
}
|
||||
|
||||
if len(a.disallowedClientsIPNet) != 0 {
|
||||
for _, ipnet := range a.disallowedClientsIPNet {
|
||||
if ipnet.Contains(ip) {
|
||||
return true, ipnet.String()
|
||||
}
|
||||
}
|
||||
err = processAccessClients(blocked, a.blockedIPs, &a.blockedNets, a.blockedClientIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding blocked: %w", err)
|
||||
}
|
||||
|
||||
return false, ""
|
||||
b := &strings.Builder{}
|
||||
for _, h := range blockedHosts {
|
||||
aghstrings.WriteToBuilder(b, strings.ToLower(h), "\n")
|
||||
}
|
||||
|
||||
lists := []filterlist.RuleList{
|
||||
&filterlist.StringRuleList{
|
||||
ID: int(0),
|
||||
RulesText: b.String(),
|
||||
IgnoreCosmetic: true,
|
||||
},
|
||||
}
|
||||
|
||||
rulesStrg, err := filterlist.NewRuleStorage(lists)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding blocked hosts: %w", err)
|
||||
}
|
||||
|
||||
a.blockedHostsEng = urlfilter.NewDNSEngine(rulesStrg)
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// IsBlockedDomain - return TRUE if this domain should be blocked
|
||||
func (a *accessCtx) IsBlockedDomain(host string) (ok bool) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
// allowlistMode returns true if this *accessCtx is in the allowlist mode.
|
||||
func (a *accessCtx) allowlistMode() (ok bool) {
|
||||
return a.allowedIPs.Len() != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
|
||||
}
|
||||
|
||||
_, ok = a.blockedHostsEngine.Match(strings.ToLower(host))
|
||||
// isBlockedClientID returns true if the ClientID should be blocked.
|
||||
func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
|
||||
allowlistMode := a.allowlistMode()
|
||||
if id == "" {
|
||||
// In allowlist mode, consider requests without client IDs
|
||||
// blocked by default.
|
||||
return allowlistMode
|
||||
}
|
||||
|
||||
if allowlistMode {
|
||||
return !a.allowedClientIDs.Has(id)
|
||||
}
|
||||
|
||||
return a.blockedClientIDs.Has(id)
|
||||
}
|
||||
|
||||
// isBlockedHost returns true if host should be blocked.
|
||||
func (a *accessCtx) isBlockedHost(host string) (ok bool) {
|
||||
_, ok = a.blockedHostsEng.Match(strings.ToLower(host))
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// isBlockedIP returns the status of the IP address blocking as well as the rule
|
||||
// that blocked it.
|
||||
func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) {
|
||||
blocked = true
|
||||
ips := a.blockedIPs
|
||||
ipnets := a.blockedNets
|
||||
|
||||
if a.allowlistMode() {
|
||||
// Enable allowlist mode and use the allowlist sets.
|
||||
blocked = false
|
||||
ips = a.allowedIPs
|
||||
ipnets = a.allowedNets
|
||||
}
|
||||
|
||||
if _, ok := ips.Get(ip); ok {
|
||||
return blocked, ip.String()
|
||||
}
|
||||
|
||||
for _, ipnet := range ipnets {
|
||||
if ipnet.Contains(ip) {
|
||||
return blocked, ipnet.String()
|
||||
}
|
||||
}
|
||||
|
||||
return !blocked, ""
|
||||
}
|
||||
|
||||
type accessListJSON struct {
|
||||
AllowedClients []string `json:"allowed_clients"`
|
||||
DisallowedClients []string `json:"disallowed_clients"`
|
||||
@@ -161,62 +186,43 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(j)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func checkIPCIDRArray(src []string) error {
|
||||
for _, s := range src {
|
||||
ip := net.ParseIP(s)
|
||||
if ip != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, _, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
j := accessListJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&j)
|
||||
list := accessListJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&list)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
httpError(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||
|
||||
err = checkIPCIDRArray(j.AllowedClients)
|
||||
if err == nil {
|
||||
err = checkIPCIDRArray(j.DisallowedClients)
|
||||
}
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var a *accessCtx
|
||||
a, err = newAccessCtx(j.AllowedClients, j.DisallowedClients, j.BlockedHosts)
|
||||
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer log.Debug("Access: updated lists: %d, %d, %d",
|
||||
len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts))
|
||||
defer log.Debug(
|
||||
"access: updated lists: %d, %d, %d",
|
||||
len(list.AllowedClients),
|
||||
len(list.DisallowedClients),
|
||||
len(list.BlockedHosts),
|
||||
)
|
||||
|
||||
defer s.conf.ConfigModified()
|
||||
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
s.conf.AllowedClients = j.AllowedClients
|
||||
s.conf.DisallowedClients = j.DisallowedClients
|
||||
s.conf.BlockedHosts = j.BlockedHosts
|
||||
s.conf.AllowedClients = list.AllowedClients
|
||||
s.conf.DisallowedClients = list.DisallowedClients
|
||||
s.conf.BlockedHosts = list.BlockedHosts
|
||||
s.access = a
|
||||
}
|
||||
|
||||
@@ -8,99 +8,23 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsBlockedIP(t *testing.T) {
|
||||
const (
|
||||
ip int = iota
|
||||
cidr
|
||||
)
|
||||
func TestIsBlockedClientID(t *testing.T) {
|
||||
clientID := "client-1"
|
||||
clients := []string{clientID}
|
||||
|
||||
rules := []string{
|
||||
ip: "1.1.1.1",
|
||||
cidr: "2.2.0.0/16",
|
||||
}
|
||||
a, err := newAccessCtx(clients, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
allowed bool
|
||||
ip net.IP
|
||||
wantDis bool
|
||||
wantRule string
|
||||
}{{
|
||||
name: "allow_ip",
|
||||
allowed: true,
|
||||
ip: net.IPv4(1, 1, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "disallow_ip",
|
||||
allowed: true,
|
||||
ip: net.IPv4(1, 1, 1, 2),
|
||||
wantDis: true,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_cidr",
|
||||
allowed: true,
|
||||
ip: net.IPv4(2, 2, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "disallow_cidr",
|
||||
allowed: true,
|
||||
ip: net.IPv4(2, 3, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_ip",
|
||||
allowed: false,
|
||||
ip: net.IPv4(1, 1, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: rules[ip],
|
||||
}, {
|
||||
name: "disallow_ip",
|
||||
allowed: false,
|
||||
ip: net.IPv4(1, 1, 1, 2),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_cidr",
|
||||
allowed: false,
|
||||
ip: net.IPv4(2, 2, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: rules[cidr],
|
||||
}, {
|
||||
name: "disallow_cidr",
|
||||
allowed: false,
|
||||
ip: net.IPv4(2, 3, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}}
|
||||
assert.False(t, a.isBlockedClientID(clientID))
|
||||
|
||||
for _, tc := range testCases {
|
||||
prefix := "allowed_"
|
||||
if !tc.allowed {
|
||||
prefix = "disallowed_"
|
||||
}
|
||||
a, err = newAccessCtx(nil, clients, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run(prefix+tc.name, func(t *testing.T) {
|
||||
allowedRules := rules
|
||||
var disallowedRules []string
|
||||
|
||||
if !tc.allowed {
|
||||
allowedRules, disallowedRules = disallowedRules, allowedRules
|
||||
}
|
||||
|
||||
aCtx, err := newAccessCtx(allowedRules, disallowedRules, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
disallowed, rule := aCtx.IsBlockedIP(tc.ip)
|
||||
assert.Equal(t, tc.wantDis, disallowed)
|
||||
assert.Equal(t, tc.wantRule, rule)
|
||||
})
|
||||
}
|
||||
assert.True(t, a.isBlockedClientID(clientID))
|
||||
}
|
||||
|
||||
func TestIsBlockedDomain(t *testing.T) {
|
||||
aCtx, err := newAccessCtx(nil, nil, []string{
|
||||
func TestIsBlockedHost(t *testing.T) {
|
||||
a, err := newAccessCtx(nil, nil, []string{
|
||||
"host1",
|
||||
"*.host.com",
|
||||
"||host3.com^",
|
||||
@@ -108,50 +32,106 @@ func TestIsBlockedDomain(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
domain string
|
||||
want bool
|
||||
name string
|
||||
host string
|
||||
want bool
|
||||
}{{
|
||||
name: "plain_match",
|
||||
domain: "host1",
|
||||
want: true,
|
||||
name: "plain_match",
|
||||
host: "host1",
|
||||
want: true,
|
||||
}, {
|
||||
name: "plain_mismatch",
|
||||
domain: "host2",
|
||||
want: false,
|
||||
name: "plain_mismatch",
|
||||
host: "host2",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-1_match_short",
|
||||
domain: "asdf.host.com",
|
||||
want: true,
|
||||
name: "subdomain_match_short",
|
||||
host: "asdf.host.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-1_match_long",
|
||||
domain: "qwer.asdf.host.com",
|
||||
want: true,
|
||||
name: "subdomain_match_long",
|
||||
host: "qwer.asdf.host.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-1_mismatch_no-lead",
|
||||
domain: "host.com",
|
||||
want: false,
|
||||
name: "subdomain_mismatch_no_lead",
|
||||
host: "host.com",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-1_mismatch_bad-asterisk",
|
||||
domain: "asdf.zhost.com",
|
||||
want: false,
|
||||
name: "subdomain_mismatch_bad_asterisk",
|
||||
host: "asdf.zhost.com",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-2_match_simple",
|
||||
domain: "host3.com",
|
||||
want: true,
|
||||
name: "rule_match_simple",
|
||||
host: "host3.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-2_match_complex",
|
||||
domain: "asdf.host3.com",
|
||||
want: true,
|
||||
name: "rule_match_complex",
|
||||
host: "asdf.host3.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-2_mismatch",
|
||||
domain: ".host3.com",
|
||||
want: false,
|
||||
name: "rule_mismatch",
|
||||
host: ".host3.com",
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain))
|
||||
assert.Equal(t, tc.want, a.isBlockedHost(tc.host))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBlockedIP(t *testing.T) {
|
||||
clients := []string{
|
||||
"1.2.3.4",
|
||||
"5.6.7.8/24",
|
||||
}
|
||||
|
||||
allowCtx, err := newAccessCtx(clients, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
blockCtx, err := newAccessCtx(nil, clients, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantRule string
|
||||
ip net.IP
|
||||
wantBlocked bool
|
||||
}{{
|
||||
name: "match_ip",
|
||||
wantRule: "1.2.3.4",
|
||||
ip: net.IP{1, 2, 3, 4},
|
||||
wantBlocked: true,
|
||||
}, {
|
||||
name: "match_cidr",
|
||||
wantRule: "5.6.7.8/24",
|
||||
ip: net.IP{5, 6, 7, 100},
|
||||
wantBlocked: true,
|
||||
}, {
|
||||
name: "no_match_ip",
|
||||
wantRule: "",
|
||||
ip: net.IP{9, 2, 3, 4},
|
||||
wantBlocked: false,
|
||||
}, {
|
||||
name: "no_match_cidr",
|
||||
wantRule: "",
|
||||
ip: net.IP{9, 6, 7, 100},
|
||||
wantBlocked: false,
|
||||
}}
|
||||
|
||||
t.Run("allow", func(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
blocked, rule := allowCtx.isBlockedIP(tc.ip)
|
||||
assert.Equal(t, !tc.wantBlocked, blocked)
|
||||
assert.Equal(t, tc.wantRule, rule)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("block", func(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
blocked, rule := blockCtx.isBlockedIP(tc.ip)
|
||||
assert.Equal(t, tc.wantBlocked, blocked)
|
||||
assert.Equal(t, tc.wantRule, rule)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
@@ -50,15 +51,15 @@ func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// processClientIDHTTPS extracts the client's ID from the path of the
|
||||
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
|
||||
// client's DNS-over-HTTPS request.
|
||||
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
|
||||
pctx := ctx.proxyCtx
|
||||
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
r := pctx.HTTPRequest
|
||||
if r == nil {
|
||||
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx http request of proto %s is nil",
|
||||
pctx.Proto,
|
||||
)
|
||||
}
|
||||
|
||||
origPath := r.URL.Path
|
||||
@@ -68,34 +69,25 @@ func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
if len(parts) == 0 || parts[0] != "dns-query" {
|
||||
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: invalid path %q", origPath)
|
||||
}
|
||||
|
||||
clientID := ""
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
// Just /dns-query, no client ID.
|
||||
return resultCodeSuccess
|
||||
return "", nil
|
||||
case 2:
|
||||
clientID = parts[1]
|
||||
default:
|
||||
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
||||
}
|
||||
|
||||
err := ValidateClientID(clientID)
|
||||
err = ValidateClientID(clientID)
|
||||
if err != nil {
|
||||
ctx.err = fmt.Errorf("client id check: %w", err)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: %w", err)
|
||||
}
|
||||
|
||||
ctx.clientID = clientID
|
||||
|
||||
return resultCodeSuccess
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
|
||||
@@ -108,53 +100,73 @@ type quicSession interface {
|
||||
ConnectionState() (cs quic.ConnectionState)
|
||||
}
|
||||
|
||||
// processClientID extracts the client's ID from the server name of the client's
|
||||
// DoT or DoQ request or the path of the client's DoH.
|
||||
func processClientID(dctx *dnsContext) (rc resultCode) {
|
||||
pctx := dctx.proxyCtx
|
||||
// clientIDFromDNSContext extracts the client's ID from the server name of the
|
||||
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
|
||||
// is not one of these, clientID is an empty string and err is nil.
|
||||
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
proto := pctx.Proto
|
||||
if proto == proxy.ProtoHTTPS {
|
||||
return processClientIDHTTPS(dctx)
|
||||
return clientIDFromDNSContextHTTPS(pctx)
|
||||
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
||||
return resultCodeSuccess
|
||||
return "", nil
|
||||
}
|
||||
|
||||
srvConf := dctx.srv.conf
|
||||
hostSrvName := srvConf.TLSConfig.ServerName
|
||||
hostSrvName := s.conf.ServerName
|
||||
if hostSrvName == "" {
|
||||
return resultCodeSuccess
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cliSrvName := ""
|
||||
if proto == proxy.ProtoTLS {
|
||||
switch proto {
|
||||
case proxy.ProtoTLS:
|
||||
conn := pctx.Conn
|
||||
tc, ok := conn.(tlsConn)
|
||||
if !ok {
|
||||
dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx conn of proto %s is %T, want *tls.Conn",
|
||||
proto,
|
||||
conn,
|
||||
)
|
||||
}
|
||||
|
||||
cliSrvName = tc.ConnectionState().ServerName
|
||||
} else if proto == proxy.ProtoQUIC {
|
||||
case proxy.ProtoQUIC:
|
||||
qs, ok := pctx.QUICSession.(quicSession)
|
||||
if !ok {
|
||||
dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx quic session of proto %s is %T, want quic.Session",
|
||||
proto,
|
||||
pctx.QUICSession,
|
||||
)
|
||||
}
|
||||
|
||||
cliSrvName = qs.ConnectionState().TLS.ServerName
|
||||
}
|
||||
|
||||
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck)
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
hostSrvName,
|
||||
cliSrvName,
|
||||
s.conf.StrictSNICheck,
|
||||
)
|
||||
if err != nil {
|
||||
dctx.err = fmt.Errorf("client id check: %w", err)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: %w", err)
|
||||
}
|
||||
|
||||
dctx.clientID = clientID
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// processClientID puts the clientID into the DNS context, if there is one.
|
||||
func (s *Server) processClientID(dctx *dnsContext) (rc resultCode) {
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
var key [8]byte
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
clientIDData := s.clientIDCache.Get(key[:])
|
||||
if clientIDData == nil {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
dctx.clientID = string(clientIDData)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -45,15 +45,14 @@ func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
|
||||
return cs
|
||||
}
|
||||
|
||||
func TestProcessClientID(t *testing.T) {
|
||||
func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
proto proxy.Proto
|
||||
hostSrvName string
|
||||
cliSrvName string
|
||||
wantClientID string
|
||||
wantErrMsg string
|
||||
wantRes resultCode
|
||||
strictSNI bool
|
||||
}{{
|
||||
name: "udp",
|
||||
@@ -62,7 +61,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: false,
|
||||
}, {
|
||||
name: "tls_no_client_id",
|
||||
@@ -71,7 +69,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_no_client_server_name",
|
||||
@@ -81,7 +78,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: client server name "" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_no_client_server_name_no_strict",
|
||||
@@ -90,7 +86,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: false,
|
||||
}, {
|
||||
name: "tls_client_id",
|
||||
@@ -99,7 +94,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "cli.example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_hostname_error",
|
||||
@@ -109,7 +103,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: client server name "cli.example.net" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_invalid_client_id",
|
||||
@@ -119,7 +112,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||
`invalid char '!' at index 0`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_too_long",
|
||||
@@ -131,7 +123,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
|
||||
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
|
||||
`label is too long, max: 63`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "quic_client_id",
|
||||
@@ -140,7 +131,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "cli.example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}}
|
||||
|
||||
@@ -150,6 +140,7 @@ func TestProcessClientID(t *testing.T) {
|
||||
ServerName: tc.hostSrvName,
|
||||
StrictSNICheck: tc.strictSNI,
|
||||
}
|
||||
|
||||
srv := &Server{
|
||||
conf: ServerConfig{TLSConfig: tlsConf},
|
||||
}
|
||||
@@ -168,79 +159,68 @@ func TestProcessClientID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
srv: srv,
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICSession: qs,
|
||||
},
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICSession: qs,
|
||||
}
|
||||
|
||||
res := processClientID(dctx)
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
||||
clientID, err := srv.clientIDFromDNSContext(pctx)
|
||||
assert.Equal(t, tc.wantClientID, clientID)
|
||||
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, dctx.err)
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, dctx.err)
|
||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessClientID_https(t *testing.T) {
|
||||
func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
path string
|
||||
wantClientID string
|
||||
wantErrMsg string
|
||||
wantRes resultCode
|
||||
}{{
|
||||
name: "no_client_id",
|
||||
path: "/dns-query",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "no_client_id_slash",
|
||||
path: "/dns-query/",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "client_id",
|
||||
path: "/dns-query/cli",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "client_id_slash",
|
||||
path: "/dns-query/cli/",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "bad_url",
|
||||
path: "/foo",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid path "/foo"`,
|
||||
wantRes: resultCodeError,
|
||||
}, {
|
||||
name: "extra",
|
||||
path: "/dns-query/cli/foo",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
|
||||
wantRes: resultCodeError,
|
||||
}, {
|
||||
name: "invalid_client_id",
|
||||
path: "/dns-query/!!!",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||
`invalid char '!' at index 0`,
|
||||
wantRes: resultCodeError,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -251,23 +231,20 @@ func TestProcessClientID_https(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Proto: proxy.ProtoHTTPS,
|
||||
HTTPRequest: r,
|
||||
},
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoHTTPS,
|
||||
HTTPRequest: r,
|
||||
}
|
||||
|
||||
res := processClientID(dctx)
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
||||
clientID, err := clientIDFromDNSContextHTTPS(pctx)
|
||||
assert.Equal(t, tc.wantClientID, clientID)
|
||||
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, dctx.err)
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, dctx.err)
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -331,7 +331,7 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
upstreams = aghstrings.FilterOut(upstreams, aghstrings.IsCommentOrEmpty)
|
||||
upstreamConfig, err := proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
&upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
},
|
||||
@@ -342,10 +342,10 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
|
||||
if len(upstreamConfig.Upstreams) == 0 {
|
||||
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
|
||||
var uc proxy.UpstreamConfig
|
||||
var uc *proxy.UpstreamConfig
|
||||
uc, err = proxy.ParseUpstreamsConfig(
|
||||
defaultDNS,
|
||||
upstream.Options{
|
||||
&upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
},
|
||||
@@ -356,7 +356,8 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
upstreamConfig.Upstreams = uc.Upstreams
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig = &upstreamConfig
|
||||
s.conf.UpstreamConfig = upstreamConfig
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
s.processInternalHosts,
|
||||
s.processRestrictLocal,
|
||||
s.processInternalIPAddrs,
|
||||
processClientID,
|
||||
s.processClientID,
|
||||
processFilteringBeforeRequest,
|
||||
s.processLocalPTR,
|
||||
s.processUpstream,
|
||||
@@ -165,7 +165,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) {
|
||||
s.tableHostToIP = t
|
||||
}
|
||||
|
||||
func (s *Server) setTableIPToHost(t ipToHostTable) {
|
||||
func (s *Server) setTableIPToHost(t *aghnet.IPMap) {
|
||||
s.tableIPToHostLock.Lock()
|
||||
defer s.tableIPToHostLock.Unlock()
|
||||
|
||||
@@ -188,13 +188,13 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
}
|
||||
|
||||
var hostToIP hostToIPTable
|
||||
var ipToHost ipToHostTable
|
||||
var ipToHost *aghnet.IPMap
|
||||
if add {
|
||||
hostToIP = make(hostToIPTable)
|
||||
ipToHost = make(ipToHostTable)
|
||||
|
||||
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
|
||||
hostToIP = make(hostToIPTable, len(ll))
|
||||
ipToHost = aghnet.NewIPMap(len(ll))
|
||||
|
||||
for _, l := range ll {
|
||||
// TODO(a.garipov): Remove this after we're finished
|
||||
// with the client hostname validations in the DHCP
|
||||
@@ -210,14 +210,14 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
|
||||
lowhost := strings.ToLower(l.Hostname)
|
||||
|
||||
ipToHost[l.IP.String()] = lowhost
|
||||
ipToHost.Set(l.IP, lowhost)
|
||||
|
||||
ip := make(net.IP, 4)
|
||||
copy(ip, l.IP.To4())
|
||||
hostToIP[lowhost] = ip
|
||||
}
|
||||
|
||||
log.Debug("dns: added %d A/PTR entries from DHCP", len(ipToHost))
|
||||
log.Debug("dns: added %d A/PTR entries from DHCP", ipToHost.Len())
|
||||
}
|
||||
|
||||
s.setTableHostToIP(hostToIP)
|
||||
@@ -377,7 +377,15 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
host, ok = s.tableIPToHost[ip.String()]
|
||||
var v interface{}
|
||||
v, ok = s.tableIPToHost.Get(ip)
|
||||
|
||||
var typOK bool
|
||||
if host, typOK = v.(string); !typOK {
|
||||
log.Error("dns: bad type %T in tableIPToHost for %s", v, ip)
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
return host, ok
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
@@ -26,6 +27,11 @@ import (
|
||||
// DefaultTimeout is the default upstream timeout
|
||||
const DefaultTimeout = 10 * time.Second
|
||||
|
||||
// defaultClientIDCacheCount is the default count of items in the LRU client ID
|
||||
// cache. The assumption here is that there won't be more than this many
|
||||
// requests between the BeforeRequestHandler stage and the actual processing.
|
||||
const defaultClientIDCacheCount = 1024
|
||||
|
||||
const (
|
||||
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
parentalBlockHost = "family-block.dns.adguard.com"
|
||||
@@ -44,12 +50,6 @@ var webRegistered bool
|
||||
// hostToIPTable is an alias for the type of Server.tableHostToIP.
|
||||
type hostToIPTable = map[string]net.IP
|
||||
|
||||
// ipToHostTable is an alias for the type of Server.tableIPToHost.
|
||||
//
|
||||
// TODO(a.garipov): Define an IPMap type in aghnet and use here and in other
|
||||
// places?
|
||||
type ipToHostTable = map[string]string
|
||||
|
||||
// Server is the main way to start a DNS server.
|
||||
//
|
||||
// Example:
|
||||
@@ -81,9 +81,13 @@ type Server struct {
|
||||
tableHostToIP hostToIPTable
|
||||
tableHostToIPLock sync.Mutex
|
||||
|
||||
tableIPToHost ipToHostTable
|
||||
tableIPToHost *aghnet.IPMap
|
||||
tableIPToHostLock sync.Mutex
|
||||
|
||||
// clientIDCache is a temporary storage for clientIDs that were
|
||||
// extracted during the BeforeRequestHandler stage.
|
||||
clientIDCache cache.Cache
|
||||
|
||||
// DNS proxy instance for internal usage
|
||||
// We don't Start() it and so no listen port is required.
|
||||
internalProxy *proxy.Proxy
|
||||
@@ -152,6 +156,10 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
subnetDetector: p.SubnetDetector,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||
clientIDCache: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: defaultClientIDCacheCount,
|
||||
}),
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Enable the refresher after the actual implementation
|
||||
@@ -414,19 +422,22 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||
|
||||
log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs)
|
||||
|
||||
var upsConfig proxy.UpstreamConfig
|
||||
upsConfig, err = proxy.ParseUpstreamsConfig(localAddrs, upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's ceritificates?
|
||||
})
|
||||
var upsConfig *proxy.UpstreamConfig
|
||||
upsConfig, err = proxy.ParseUpstreamsConfig(
|
||||
localAddrs,
|
||||
&upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's ceritificates?
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing upstreams: %w", err)
|
||||
}
|
||||
|
||||
s.localResolvers = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UpstreamConfig: &upsConfig,
|
||||
UpstreamConfig: upsConfig,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -577,11 +588,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// IsBlockedIP - return TRUE if this client should be blocked
|
||||
func (s *Server) IsBlockedIP(ip net.IP) (bool, string) {
|
||||
if ip == nil {
|
||||
return false, ""
|
||||
// IsBlockedClient returns true if the client is blocked by the current access
|
||||
// settings.
|
||||
func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
allowlistMode := s.access.allowlistMode()
|
||||
blockedByIP, rule := s.access.isBlockedIP(ip)
|
||||
blockedByClientID := s.access.isBlockedClientID(clientID)
|
||||
|
||||
// Allow if at least one of the checks allows in allowlist mode, but
|
||||
// block if at least one of the checks blocks in blocklist mode.
|
||||
if allowlistMode && blockedByIP && blockedByClientID {
|
||||
log.Debug("client %s (id %q) is not in access allowlist", ip, clientID)
|
||||
|
||||
// Return now without substituting the empty rule for the
|
||||
// clientID because the rule can't be empty here.
|
||||
return true, rule
|
||||
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
|
||||
log.Debug("client %s (id %q) is in access blocklist", ip, clientID)
|
||||
|
||||
blocked = true
|
||||
}
|
||||
|
||||
return s.access.IsBlockedIP(ip)
|
||||
if rule == "" {
|
||||
rule = clientID
|
||||
}
|
||||
|
||||
return blocked, rule
|
||||
}
|
||||
|
||||
@@ -257,19 +257,22 @@ func TestServer(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
net string
|
||||
proto proxy.Proto
|
||||
}{{
|
||||
name: "message_over_udp",
|
||||
net: "",
|
||||
proto: proxy.ProtoUDP,
|
||||
}, {
|
||||
name: "message_over_tcp",
|
||||
net: "tcp",
|
||||
proto: proxy.ProtoTCP,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
addr := s.dnsProxy.Addr(tc.proto)
|
||||
client := dns.Client{Net: tc.proto}
|
||||
client := dns.Client{Net: tc.net}
|
||||
|
||||
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
|
||||
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||
@@ -324,7 +327,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
// Message over UDP.
|
||||
req := createGoogleATestMessage()
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
client := dns.Client{Net: proxy.ProtoUDP}
|
||||
client := &dns.Client{}
|
||||
|
||||
reply, _, err := client.Exchange(req, addr.String())
|
||||
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||
@@ -376,7 +379,7 @@ func TestDoQServer(t *testing.T) {
|
||||
|
||||
// Create a DNS-over-QUIC upstream.
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
|
||||
opts := upstream.Options{InsecureSkipVerify: true}
|
||||
opts := &upstream.Options{InsecureSkipVerify: true}
|
||||
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -420,7 +423,7 @@ func TestServerRace(t *testing.T) {
|
||||
|
||||
// Message over UDP.
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
conn, err := dns.Dial(proxy.ProtoUDP, addr.String())
|
||||
conn, err := dns.Dial("udp", addr.String())
|
||||
require.NoErrorf(t, err, "cannot connect to the proxy: %s", err)
|
||||
|
||||
sendTestMessagesAsync(t, conn)
|
||||
@@ -445,7 +448,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
client := dns.Client{Net: proxy.ProtoUDP}
|
||||
client := &dns.Client{}
|
||||
|
||||
yandexIP := net.IP{213, 180, 193, 56}
|
||||
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||
@@ -507,7 +510,6 @@ func TestInvalidRequest(t *testing.T) {
|
||||
|
||||
// Send a DNS request without question.
|
||||
_, _, err := (&dns.Client{
|
||||
Net: proxy.ProtoUDP,
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}).Exchange(&req, addr)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -11,23 +12,39 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
||||
ip := aghnet.IPFromAddr(d.Addr)
|
||||
disallowed, _ := s.access.IsBlockedIP(ip)
|
||||
if disallowed {
|
||||
log.Tracef("Client IP %s is blocked by settings", ip)
|
||||
// beforeRequestHandler is the handler that is called before any other
|
||||
// processing, including logs. It performs access checks and puts the client
|
||||
// ID, if there is one, into the server's cache.
|
||||
func (s *Server) beforeRequestHandler(
|
||||
_ *proxy.Proxy,
|
||||
pctx *proxy.DNSContext,
|
||||
) (reply bool, err error) {
|
||||
ip := aghnet.IPFromAddr(pctx.Addr)
|
||||
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("getting clientid: %w", err)
|
||||
}
|
||||
|
||||
blocked, _ := s.IsBlockedClient(ip, clientID)
|
||||
if blocked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if len(d.Req.Question) == 1 {
|
||||
host := strings.TrimSuffix(d.Req.Question[0].Name, ".")
|
||||
if s.access.IsBlockedDomain(host) {
|
||||
log.Tracef("domain %s is blocked by access settings", host)
|
||||
if len(pctx.Req.Question) == 1 {
|
||||
host := strings.TrimSuffix(pctx.Req.Question[0].Name, ".")
|
||||
if s.access.isBlockedHost(host) {
|
||||
log.Debug("host %s is in access blocklist", host)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if clientID != "" {
|
||||
key := [8]byte{}
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
s.clientIDCache.Set(key[:], []byte(clientID))
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (string, error) {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: empty")
|
||||
}
|
||||
|
||||
if _, err := upstream.NewResolver(boot, upstream.Options{Timeout: 0}); err != nil {
|
||||
if _, err := upstream.NewResolver(boot, nil); err != nil {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -348,7 +348,7 @@ func ValidateUpstreams(upstreams []string) (err error) {
|
||||
|
||||
_, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
&upstream.Options{
|
||||
Bootstrap: []string{},
|
||||
Timeout: DefaultTimeout,
|
||||
},
|
||||
@@ -546,7 +546,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
|
||||
|
||||
log.Debug("checking if dns server %q works...", input)
|
||||
var u upstream.Upstream
|
||||
u, err = upstream.AddressToUpstream(input, upstream.Options{
|
||||
u, err = upstream.AddressToUpstream(input, &upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: timeout,
|
||||
})
|
||||
|
||||
@@ -46,7 +46,7 @@ func (l *testStats) Update(e stats.Entry) {
|
||||
func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
proto proxy.Proto
|
||||
addr net.Addr
|
||||
clientID string
|
||||
wantLogProto querylog.ClientProto
|
||||
@@ -156,7 +156,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RParental,
|
||||
}}
|
||||
|
||||
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
|
||||
ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -49,7 +49,7 @@ func (d *DNSFilter) initSecurityServices() error {
|
||||
var err error
|
||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||
d.parentalServer = defaultParentalServer
|
||||
opts := upstream.Options{
|
||||
opts := &upstream.Options{
|
||||
Timeout: dnsTimeout,
|
||||
ServerIPAddrs: []net.IP{
|
||||
{94, 140, 14, 15},
|
||||
|
||||
@@ -78,10 +78,13 @@ type RuntimeClientWHOISInfo struct {
|
||||
type clientsContainer struct {
|
||||
// TODO(a.garipov): Perhaps use a number of separate indices for
|
||||
// different types (string, net.IP, and so on).
|
||||
list map[string]*Client // name -> client
|
||||
idIndex map[string]*Client // ID -> client
|
||||
ipToRC map[string]*RuntimeClient // IP -> runtime client
|
||||
lock sync.Mutex
|
||||
list map[string]*Client // name -> client
|
||||
idIndex map[string]*Client // ID -> client
|
||||
|
||||
// ipToRC is the IP address to *RuntimeClient map.
|
||||
ipToRC *aghnet.IPMap
|
||||
|
||||
lock sync.Mutex
|
||||
|
||||
allTags *aghstrings.Set
|
||||
|
||||
@@ -109,7 +112,7 @@ func (clients *clientsContainer) Init(
|
||||
}
|
||||
clients.list = make(map[string]*Client)
|
||||
clients.idIndex = make(map[string]*Client)
|
||||
clients.ipToRC = make(map[string]*RuntimeClient)
|
||||
clients.ipToRC = aghnet.NewIPMap(0)
|
||||
|
||||
clients.allTags = aghstrings.NewSet(clientTags...)
|
||||
|
||||
@@ -250,18 +253,17 @@ func (clients *clientsContainer) onHostsChanged() {
|
||||
clients.addFromHostsFile()
|
||||
}
|
||||
|
||||
// Exists checks if client with this ID already exists.
|
||||
func (clients *clientsContainer) Exists(id string, source clientSource) (ok bool) {
|
||||
// Exists checks if client with this IP address already exists.
|
||||
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok = clients.findLocked(id)
|
||||
_, ok = clients.findLocked(ip.String())
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
var rc *RuntimeClient
|
||||
rc, ok = clients.ipToRC[id]
|
||||
rc, ok := clients.findRuntimeClientLocked(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@@ -288,13 +290,14 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
|
||||
for _, id := range ids {
|
||||
var name string
|
||||
whois := &querylog.ClientWHOIS{}
|
||||
ip := net.ParseIP(id)
|
||||
|
||||
c, ok := clients.Find(id)
|
||||
if ok {
|
||||
name = c.Name
|
||||
} else {
|
||||
var rc RuntimeClient
|
||||
rc, ok = clients.FindRuntimeClient(id)
|
||||
} else if ip != nil {
|
||||
var rc *RuntimeClient
|
||||
rc, ok = clients.FindRuntimeClient(ip)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
@@ -303,8 +306,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
|
||||
whois = toQueryLogWHOIS(rc.WHOISInfo)
|
||||
}
|
||||
|
||||
ip := net.ParseIP(id)
|
||||
disallowed, disallowedRule := clients.dnsServer.IsBlockedIP(ip)
|
||||
disallowed, disallowedRule := clients.dnsServer.IsBlockedClient(ip, id)
|
||||
|
||||
return &querylog.Client{
|
||||
Name: name,
|
||||
@@ -356,10 +358,10 @@ func (clients *clientsContainer) findUpstreams(
|
||||
return c.upstreamConfig, nil
|
||||
}
|
||||
|
||||
var conf proxy.UpstreamConfig
|
||||
var conf *proxy.UpstreamConfig
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
&upstream.Options{
|
||||
Bootstrap: config.DNS.BootstrapDNS,
|
||||
Timeout: config.DNS.UpstreamTimeout.Duration,
|
||||
},
|
||||
@@ -368,9 +370,9 @@ func (clients *clientsContainer) findUpstreams(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.upstreamConfig = &conf
|
||||
c.upstreamConfig = conf
|
||||
|
||||
return &conf, nil
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// findLocked searches for a client by its ID. For internal use only.
|
||||
@@ -423,22 +425,35 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findRuntimeClientLocked finds a runtime client by their IP address. For
|
||||
// internal use only.
|
||||
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
|
||||
var v interface{}
|
||||
v, ok = clients.ipToRC.Get(ip)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
rc, ok = v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return rc, true
|
||||
}
|
||||
|
||||
// FindRuntimeClient finds a runtime client by their IP.
|
||||
func (clients *clientsContainer) FindRuntimeClient(ip string) (RuntimeClient, bool) {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
if ipAddr == nil {
|
||||
return RuntimeClient{}, false
|
||||
func (clients *clientsContainer) FindRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) {
|
||||
if ip == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if ok {
|
||||
return *rc, true
|
||||
}
|
||||
|
||||
return RuntimeClient{}, false
|
||||
return clients.findRuntimeClientLocked(ip)
|
||||
}
|
||||
|
||||
// check validates the client.
|
||||
@@ -621,17 +636,17 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
|
||||
}
|
||||
|
||||
// SetWHOISInfo sets the WHOIS information for a client.
|
||||
func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISInfo) {
|
||||
func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.findLocked(ip)
|
||||
_, ok := clients.findLocked(ip.String())
|
||||
if ok {
|
||||
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
||||
return
|
||||
}
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
rc, ok := clients.findRuntimeClientLocked(ip)
|
||||
if ok {
|
||||
rc.WHOISInfo = wi
|
||||
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
|
||||
@@ -646,14 +661,15 @@ func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISI
|
||||
}
|
||||
|
||||
rc.WHOISInfo = wi
|
||||
clients.ipToRC[ip] = rc
|
||||
|
||||
clients.ipToRC.Set(ip, rc)
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
}
|
||||
|
||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
|
||||
// taken into account. ok is true if the pairing was added.
|
||||
func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok bool, err error) {
|
||||
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
@@ -663,9 +679,9 @@ func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok
|
||||
}
|
||||
|
||||
// addHostLocked adds a new IP-hostname pairing. For internal use only.
|
||||
func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource) (ok bool) {
|
||||
func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) {
|
||||
var rc *RuntimeClient
|
||||
rc, ok = clients.ipToRC[ip]
|
||||
rc, ok = clients.findRuntimeClientLocked(ip)
|
||||
if ok {
|
||||
if rc.Source > src {
|
||||
return false
|
||||
@@ -679,10 +695,10 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource
|
||||
WHOISInfo: &RuntimeClientWHOISInfo{},
|
||||
}
|
||||
|
||||
clients.ipToRC[ip] = rc
|
||||
clients.ipToRC.Set(ip, rc)
|
||||
}
|
||||
|
||||
log.Debug("clients: added %q -> %q [%d]", ip, host, len(clients.ipToRC))
|
||||
log.Debug("clients: added %s -> %q [%d]", ip, host, clients.ipToRC.Len())
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -690,12 +706,21 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource
|
||||
// rmHostsBySrc removes all entries that match the specified source.
|
||||
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||
n := 0
|
||||
for k, v := range clients.ipToRC {
|
||||
if v.Source == src {
|
||||
delete(clients.ipToRC, k)
|
||||
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if rc.Source == src {
|
||||
clients.ipToRC.Del(ip)
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
log.Debug("clients: removed %d client aliases", n)
|
||||
}
|
||||
@@ -715,16 +740,23 @@ func (clients *clientsContainer) addFromHostsFile() {
|
||||
clients.rmHostsBySrc(ClientSourceHostsFile)
|
||||
|
||||
n := 0
|
||||
for ip, names := range hosts {
|
||||
hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
names, ok := v.([]string)
|
||||
if !ok {
|
||||
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
ok := clients.addHostLocked(ip, name, ClientSourceHostsFile)
|
||||
ok = clients.addHostLocked(ip, name, ClientSourceHostsFile)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Clients: added %d client aliases from system hosts-file", n)
|
||||
return true
|
||||
})
|
||||
|
||||
log.Debug("clients: added %d client aliases from system hosts-file", n)
|
||||
}
|
||||
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
@@ -752,15 +784,16 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
// TODO(a.garipov): Rewrite to use bufio.Scanner.
|
||||
lines := strings.Split(string(data), "\n")
|
||||
for _, ln := range lines {
|
||||
open := strings.Index(ln, " (")
|
||||
close := strings.Index(ln, ") ")
|
||||
if open == -1 || close == -1 || open >= close {
|
||||
lparen := strings.Index(ln, " (")
|
||||
rparen := strings.Index(ln, ") ")
|
||||
if lparen == -1 || rparen == -1 || lparen >= rparen {
|
||||
continue
|
||||
}
|
||||
|
||||
host := ln[:open]
|
||||
ip := ln[open+2 : close]
|
||||
if aghnet.ValidateDomainName(host) != nil || net.ParseIP(ip) == nil {
|
||||
host := ln[:lparen]
|
||||
ipStr := ln[lparen+2 : rparen]
|
||||
ip := net.ParseIP(ipStr)
|
||||
if aghnet.ValidateDomainName(host) != nil || ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -796,7 +829,7 @@ func (clients *clientsContainer) updateFromDHCP(add bool) {
|
||||
continue
|
||||
}
|
||||
|
||||
ok := clients.addHostLocked(l.IP.String(), l.Hostname, ClientSourceDHCP)
|
||||
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ func TestClients(t *testing.T) {
|
||||
|
||||
ok, err := clients.Add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c = &Client{
|
||||
@@ -35,23 +36,27 @@ func TestClients(t *testing.T) {
|
||||
|
||||
ok, err = clients.Add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c, ok = clients.Find("1.1.1.1")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
||||
c, ok = clients.Find("1:2:3::4")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
||||
c, ok = clients.Find("2.2.2.2")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client2", c.Name)
|
||||
|
||||
assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
|
||||
assert.False(t, clients.Exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile))
|
||||
})
|
||||
|
||||
t.Run("add_fail_name", func(t *testing.T) {
|
||||
@@ -101,8 +106,8 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||
assert.False(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
|
||||
|
||||
err = clients.Update("client1", &Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
@@ -113,21 +118,25 @@ func TestClients(t *testing.T) {
|
||||
|
||||
c, ok := clients.Find("1.1.1.2")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1-renamed", c.Name)
|
||||
assert.True(t, c.UseOwnSettings)
|
||||
|
||||
nilCli, ok := clients.list["client1"]
|
||||
require.False(t, ok)
|
||||
|
||||
assert.Nil(t, nilCli)
|
||||
|
||||
require.Len(t, c.IDs, 1)
|
||||
|
||||
assert.Equal(t, "1.1.1.2", c.IDs[0])
|
||||
})
|
||||
|
||||
t.Run("del_success", func(t *testing.T) {
|
||||
ok := clients.Del("client1-renamed")
|
||||
require.True(t, ok)
|
||||
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||
|
||||
assert.False(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
|
||||
})
|
||||
|
||||
t.Run("del_fail", func(t *testing.T) {
|
||||
@@ -136,37 +145,44 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("addhost_success", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
||||
ip := net.IP{1, 1, 1, 1}
|
||||
|
||||
ok, err := clients.AddHost(ip, "host", ClientSourceARP)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||
ok, err = clients.AddHost(ip, "host2", ClientSourceARP)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||
ok, err = clients.AddHost(ip, "host3", ClientSourceHostsFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists(ip, ClientSourceHostsFile))
|
||||
})
|
||||
|
||||
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.2.3.4", "from_arp", ClientSourceARP)
|
||||
ip := net.IP{1, 2, 3, 4}
|
||||
|
||||
ok, err := clients.AddHost(ip, "from_arp", ClientSourceARP)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.True(t, clients.Exists(ip, ClientSourceARP))
|
||||
|
||||
assert.True(t, clients.Exists("1.2.3.4", ClientSourceARP))
|
||||
|
||||
ok, err = clients.AddHost("1.2.3.4", "from_dhcp", ClientSourceDHCP)
|
||||
ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.True(t, clients.Exists("1.2.3.4", ClientSourceDHCP))
|
||||
assert.True(t, ok)
|
||||
assert.True(t, clients.Exists(ip, ClientSourceDHCP))
|
||||
})
|
||||
|
||||
t.Run("addhost_fail", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||
ok, err := clients.AddHost(net.IP{1, 1, 1, 1}, "host1", ClientSourceRDNS)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
@@ -183,31 +199,39 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("new_client", func(t *testing.T) {
|
||||
clients.SetWHOISInfo("1.1.1.255", whois)
|
||||
ip := net.IP{1, 1, 1, 255}
|
||||
clients.SetWHOISInfo(ip, whois)
|
||||
v, _ := clients.ipToRC.Get(ip)
|
||||
require.NotNil(t, v)
|
||||
|
||||
require.NotNil(t, clients.ipToRC["1.1.1.255"])
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
h := clients.ipToRC["1.1.1.255"]
|
||||
require.NotNil(t, h)
|
||||
|
||||
assert.Equal(t, h.WHOISInfo, whois)
|
||||
assert.Equal(t, rc.WHOISInfo, whois)
|
||||
})
|
||||
|
||||
t.Run("existing_auto-client", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
|
||||
ip := net.IP{1, 1, 1, 1}
|
||||
ok, err := clients.AddHost(ip, "host", ClientSourceRDNS)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.SetWHOISInfo("1.1.1.1", whois)
|
||||
clients.SetWHOISInfo(ip, whois)
|
||||
v, _ := clients.ipToRC.Get(ip)
|
||||
require.NotNil(t, v)
|
||||
|
||||
require.NotNil(t, clients.ipToRC["1.1.1.1"])
|
||||
h := clients.ipToRC["1.1.1.1"]
|
||||
require.NotNil(t, h)
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, h.WHOISInfo, whois)
|
||||
assert.Equal(t, rc.WHOISInfo, whois)
|
||||
})
|
||||
|
||||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||
ip := net.IP{1, 1, 1, 2}
|
||||
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
Name: "client1",
|
||||
@@ -215,8 +239,10 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.SetWHOISInfo("1.1.1.2", whois)
|
||||
require.Nil(t, clients.ipToRC["1.1.1.2"])
|
||||
clients.SetWHOISInfo(ip, whois)
|
||||
v, _ := clients.ipToRC.Get(ip)
|
||||
require.Nil(t, v)
|
||||
|
||||
assert.True(t, clients.Del("client1"))
|
||||
})
|
||||
}
|
||||
@@ -228,16 +254,18 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
clients.Init(nil, nil, nil)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
ip := net.IP{1, 1, 1, 1}
|
||||
|
||||
// Add a client.
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||
IDs: []string{ip.String(), "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||
Name: "client1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
|
||||
ok, err = clients.AddHost(ip, "test", ClientSourceRDNS)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
@@ -245,7 +273,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
t.Run("complicated", func(t *testing.T) {
|
||||
var err error
|
||||
|
||||
testIP := net.IP{1, 2, 3, 4}
|
||||
ip := net.IP{1, 2, 3, 4}
|
||||
|
||||
// First, init a DHCP server with a single static lease.
|
||||
config := dhcpd.ServerConfig{
|
||||
@@ -267,7 +295,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
|
||||
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: testIP,
|
||||
IP: ip,
|
||||
Hostname: "testhost",
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
})
|
||||
@@ -275,7 +303,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
|
||||
// Add a new client with the same IP as for a client with MAC.
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{testIP.String()},
|
||||
IDs: []string{ip.String()},
|
||||
Name: "client2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// clientJSON is a common structure used by several handlers to deal with
|
||||
@@ -44,13 +46,13 @@ type clientJSON struct {
|
||||
type runtimeClientJSON struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||
|
||||
IP string `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
Source string `json:"source"`
|
||||
IP net.IP `json:"ip"`
|
||||
}
|
||||
|
||||
type clientListJSON struct {
|
||||
Clients []clientJSON `json:"clients"`
|
||||
Clients []*clientJSON `json:"clients"`
|
||||
RuntimeClients []runtimeClientJSON `json:"auto_clients"`
|
||||
Tags []string `json:"supported_tags"`
|
||||
}
|
||||
@@ -66,11 +68,20 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
|
||||
cj := clientToJSON(c)
|
||||
data.Clients = append(data.Clients, cj)
|
||||
}
|
||||
for ip, rc := range clients.ipToRC {
|
||||
|
||||
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
cj := runtimeClientJSON{
|
||||
IP: ip,
|
||||
Name: rc.Host,
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
|
||||
Name: rc.Host,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
cj.Source = "etc/hosts"
|
||||
@@ -86,7 +97,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
|
||||
}
|
||||
|
||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
data.Tags = clientTags
|
||||
|
||||
@@ -118,8 +131,8 @@ func jsonToClient(cj clientJSON) (c *Client) {
|
||||
}
|
||||
|
||||
// Convert Client object to JSON
|
||||
func clientToJSON(c *Client) clientJSON {
|
||||
cj := clientJSON{
|
||||
func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
return &clientJSON{
|
||||
Name: c.Name,
|
||||
IDs: c.IDs,
|
||||
Tags: c.Tags,
|
||||
@@ -134,19 +147,6 @@ func clientToJSON(c *Client) clientJSON {
|
||||
|
||||
Upstreams: c.Upstreams,
|
||||
}
|
||||
|
||||
return cj
|
||||
}
|
||||
|
||||
// runtimeClientToJSON converts a RuntimeClient into a JSON struct.
|
||||
func runtimeClientToJSON(ip string, rc RuntimeClient) (cj clientJSON) {
|
||||
cj = clientJSON{
|
||||
Name: rc.Host,
|
||||
IDs: []string{ip},
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
}
|
||||
|
||||
return cj
|
||||
}
|
||||
|
||||
// Add a new client
|
||||
@@ -230,7 +230,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
// Get the list of clients by IP address list
|
||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
data := []map[string]clientJSON{}
|
||||
data := []map[string]*clientJSON{}
|
||||
for i := 0; i < len(q); i++ {
|
||||
idStr := q.Get(fmt.Sprintf("ip%d", i))
|
||||
if idStr == "" {
|
||||
@@ -239,20 +239,16 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
|
||||
ip := net.ParseIP(idStr)
|
||||
c, ok := clients.Find(idStr)
|
||||
var cj clientJSON
|
||||
var cj *clientJSON
|
||||
if !ok {
|
||||
var found bool
|
||||
cj, found = clients.findRuntime(ip, idStr)
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
cj = clients.findRuntime(ip, idStr)
|
||||
} else {
|
||||
cj = clientToJSON(c)
|
||||
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
}
|
||||
|
||||
data = append(data, map[string]clientJSON{
|
||||
data = append(data, map[string]*clientJSON{
|
||||
idStr: cj,
|
||||
})
|
||||
}
|
||||
@@ -265,39 +261,37 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
}
|
||||
|
||||
// findRuntime looks up the IP in runtime and temporary storages, like
|
||||
// /etc/hosts tables, DHCP leases, or blocklists.
|
||||
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj clientJSON, found bool) {
|
||||
if ip == nil {
|
||||
return cj, false
|
||||
}
|
||||
|
||||
rc, ok := clients.FindRuntimeClient(idStr)
|
||||
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
||||
// non-nil.
|
||||
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
|
||||
rc, ok := clients.FindRuntimeClient(ip)
|
||||
if !ok {
|
||||
// It is still possible that the IP used to be in the runtime
|
||||
// clients list, but then the server was reloaded. So, check
|
||||
// the DNS server's blocked IP list.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
||||
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
|
||||
if rule == "" {
|
||||
return clientJSON{}, false
|
||||
}
|
||||
|
||||
cj = clientJSON{
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
cj = &clientJSON{
|
||||
IDs: []string{idStr},
|
||||
Disallowed: &disallowed,
|
||||
DisallowedRule: &rule,
|
||||
WHOISInfo: &RuntimeClientWHOISInfo{},
|
||||
}
|
||||
|
||||
return cj, true
|
||||
return cj
|
||||
}
|
||||
|
||||
cj = runtimeClientToJSON(idStr, rc)
|
||||
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
|
||||
cj = &clientJSON{
|
||||
Name: rc.Host,
|
||||
IDs: []string{idStr},
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
}
|
||||
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
|
||||
return cj, true
|
||||
return cj
|
||||
}
|
||||
|
||||
// RegisterClientsHandlers registers HTTP handlers
|
||||
|
||||
@@ -105,8 +105,8 @@ func isRunning() bool {
|
||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||
}
|
||||
|
||||
func onDNSRequest(d *proxy.DNSContext) {
|
||||
ip := aghnet.IPFromAddr(d.Addr)
|
||||
func onDNSRequest(pctx *proxy.DNSContext) {
|
||||
ip := aghnet.IPFromAddr(pctx.Addr)
|
||||
if ip == nil {
|
||||
// This would be quite weird if we get here.
|
||||
return
|
||||
|
||||
@@ -503,7 +503,7 @@ Please note, that this is crucial for a server to be able to use privileged port
|
||||
You have two options:
|
||||
1. Run AdGuard Home with root privileges
|
||||
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
|
||||
https://github.com/AdguardTeam/AdGuardHome/internal/wiki/Getting-Started#running-without-superuser`
|
||||
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`
|
||||
|
||||
log.Fatal(msg)
|
||||
}
|
||||
|
||||
@@ -102,12 +102,7 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) {
|
||||
func (r *RDNS) Begin(ip net.IP) {
|
||||
r.ensurePrivateCache()
|
||||
|
||||
if r.isCached(ip) {
|
||||
return
|
||||
}
|
||||
|
||||
id := ip.String()
|
||||
if r.clients.Exists(id, ClientSourceRDNS) {
|
||||
if r.isCached(ip) || r.clients.Exists(ip, ClientSourceRDNS) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -138,6 +133,6 @@ func (r *RDNS) workerLoop() {
|
||||
|
||||
// Don't handle any errors since AddHost doesn't return non-nil
|
||||
// errors for now.
|
||||
_, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS)
|
||||
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
@@ -84,7 +85,7 @@ func TestRDNS_Begin(t *testing.T) {
|
||||
clients: &clientsContainer{
|
||||
list: map[string]*Client{},
|
||||
idIndex: tc.cliIDIndex,
|
||||
ipToRC: map[string]*RuntimeClient{},
|
||||
ipToRC: aghnet.NewIPMap(0),
|
||||
allTags: aghstrings.NewSet(),
|
||||
},
|
||||
}
|
||||
@@ -204,7 +205,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
cc := &clientsContainer{
|
||||
list: map[string]*Client{},
|
||||
idIndex: map[string]*Client{},
|
||||
ipToRC: map[string]*RuntimeClient{},
|
||||
ipToRC: aghnet.NewIPMap(0),
|
||||
allTags: aghstrings.NewSet(),
|
||||
}
|
||||
ch := make(chan net.IP)
|
||||
@@ -236,7 +237,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, cc.Exists(tc.cliIP.String(), ClientSourceRDNS))
|
||||
assert.True(t, cc.Exists(tc.cliIP, ClientSourceRDNS))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,7 +252,6 @@ func (w *WHOIS) workerLoop() {
|
||||
continue
|
||||
}
|
||||
|
||||
id := ip.String()
|
||||
w.clients.SetWHOISInfo(id, info)
|
||||
w.clients.SetWHOISInfo(ip, info)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -720,7 +720,10 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP {
|
||||
a := convertMapToSlice(m, int(maxCount))
|
||||
d := []net.IP{}
|
||||
for _, it := range a {
|
||||
d = append(d, net.ParseIP(it.Name))
|
||||
ip := net.ParseIP(it.Name)
|
||||
if ip != nil {
|
||||
d = append(d, ip)
|
||||
}
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user