Merge branch 'master' into 3717-fix-qq-blocked
This commit is contained in:
75
internal/aghalg/aghalg.go
Normal file
75
internal/aghalg/aghalg.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Package aghalg contains common generic algorithms and data structures.
|
||||
//
|
||||
// TODO(a.garipov): Update to use type parameters in Go 1.18.
|
||||
package aghalg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// comparable is an alias for interface{}. Values passed as arguments of this
|
||||
// type alias must be comparable.
|
||||
//
|
||||
// TODO(a.garipov): Remove in Go 1.18.
|
||||
type comparable = interface{}
|
||||
|
||||
// UniqChecker allows validating uniqueness of comparable items.
|
||||
type UniqChecker map[comparable]int64
|
||||
|
||||
// Add adds a value to the validator. v must not be nil.
|
||||
func (uc UniqChecker) Add(elems ...comparable) {
|
||||
for _, e := range elems {
|
||||
uc[e]++
|
||||
}
|
||||
}
|
||||
|
||||
// Merge returns a checker containing data from both uc and other.
|
||||
func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
|
||||
merged = make(UniqChecker, len(uc)+len(other))
|
||||
for elem, num := range uc {
|
||||
merged[elem] += num
|
||||
}
|
||||
|
||||
for elem, num := range other {
|
||||
merged[elem] += num
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
// Validate returns an error enumerating all elements that aren't unique.
|
||||
// isBefore is an optional sorting function to make the error message
|
||||
// deterministic.
|
||||
func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err error) {
|
||||
var dup []comparable
|
||||
for elem, num := range uc {
|
||||
if num > 1 {
|
||||
dup = append(dup, elem)
|
||||
}
|
||||
}
|
||||
|
||||
if len(dup) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if isBefore != nil {
|
||||
sort.Slice(dup, func(i, j int) (less bool) {
|
||||
return isBefore(dup[i], dup[j])
|
||||
})
|
||||
}
|
||||
|
||||
return fmt.Errorf("duplicated values: %v", dup)
|
||||
}
|
||||
|
||||
// IntIsBefore is a helper sort function for UniqChecker.Validate.
|
||||
// a and b must be of type int.
|
||||
func IntIsBefore(a, b comparable) (less bool) {
|
||||
return a.(int) < b.(int)
|
||||
}
|
||||
|
||||
// StringIsBefore is a helper sort function for UniqChecker.Validate.
|
||||
// a and b must be of type string.
|
||||
func StringIsBefore(a, b comparable) (less bool) {
|
||||
return a.(string) < b.(string)
|
||||
}
|
||||
24
internal/aghhttp/aghhttp.go
Normal file
24
internal/aghhttp/aghhttp.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Package aghhttp provides some common methods to work with HTTP.
|
||||
package aghhttp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// OK responds with word OK.
|
||||
func OK(w http.ResponseWriter) {
|
||||
if _, err := io.WriteString(w, "OK\n"); err != nil {
|
||||
log.Error("couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Error writes formatted message to w and also logs it.
|
||||
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Error("%s %s: %s", r.Method, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
@@ -11,7 +11,7 @@ type LimitReachedError struct {
|
||||
Limit int64
|
||||
}
|
||||
|
||||
// Error implements error interface for LimitReachedError.
|
||||
// Error implements the error interface for LimitReachedError.
|
||||
//
|
||||
// TODO(a.garipov): Think about error string format.
|
||||
func (lre *LimitReachedError) Error() string {
|
||||
@@ -35,7 +35,7 @@ func (lr *limitedReader) Read(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
if int64(len(p)) > lr.n {
|
||||
p = p[0:lr.n]
|
||||
p = p[:lr.n]
|
||||
}
|
||||
|
||||
n, err = lr.r.Read(p)
|
||||
|
||||
@@ -1,38 +1,38 @@
|
||||
package aghio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLimitReader(t *testing.T) {
|
||||
testCases := []struct {
|
||||
want error
|
||||
name string
|
||||
n int64
|
||||
wantErrMsg string
|
||||
name string
|
||||
n int64
|
||||
}{{
|
||||
want: nil,
|
||||
name: "positive",
|
||||
n: 1,
|
||||
wantErrMsg: "",
|
||||
name: "positive",
|
||||
n: 1,
|
||||
}, {
|
||||
want: nil,
|
||||
name: "zero",
|
||||
n: 0,
|
||||
wantErrMsg: "",
|
||||
name: "zero",
|
||||
n: 0,
|
||||
}, {
|
||||
want: fmt.Errorf("aghio: invalid n in LimitReader: -1"),
|
||||
name: "negative",
|
||||
n: -1,
|
||||
wantErrMsg: "aghio: invalid n in LimitReader: -1",
|
||||
name: "negative",
|
||||
n: -1,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := LimitReader(nil, tc.n)
|
||||
assert.Equal(t, tc.want, err)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -73,36 +73,23 @@ func TestLimitedReader_Read(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
|
||||
lreader, err := LimitReader(readCloser, tc.limit)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lreader)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
|
||||
buf := make([]byte, tc.limit+1)
|
||||
n, rerr := lreader.Read(buf)
|
||||
require.Equal(t, rerr, tc.err)
|
||||
|
||||
lreader, err := LimitReader(readCloser, tc.limit)
|
||||
require.NoError(t, err)
|
||||
|
||||
n, err := lreader.Read(buf)
|
||||
require.Equal(t, tc.err, err)
|
||||
assert.Equal(t, tc.want, n)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitedReader_LimitReachedError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
err error
|
||||
name string
|
||||
want string
|
||||
}{{
|
||||
err: &LimitReachedError{
|
||||
Limit: 0,
|
||||
},
|
||||
name: "simplest",
|
||||
want: "attempted to read more than 0 bytes",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, tc.err.Error())
|
||||
})
|
||||
}
|
||||
testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &LimitReachedError{
|
||||
Limit: 0,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -19,7 +19,8 @@ import (
|
||||
"github.com/insomniacslk/dhcp/iana"
|
||||
)
|
||||
|
||||
// defaultDiscoverTime is the
|
||||
// defaultDiscoverTime is the default timeout of checking another DHCP server
|
||||
// response.
|
||||
const defaultDiscoverTime = 3 * time.Second
|
||||
|
||||
func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
|
||||
@@ -110,7 +111,7 @@ func discover4(iface *net.Interface, dstAddr *net.UDPAddr, hostname string) (ok
|
||||
// is spoiled.
|
||||
//
|
||||
// It's also known that listening on the specified interface's address
|
||||
// ignores broadcasted packets when reading.
|
||||
// ignores broadcast packets when reading.
|
||||
var c net.PacketConn
|
||||
if c, err = listenPacketReusable(iface.Name, "udp4", ":68"); err != nil {
|
||||
return false, fmt.Errorf("couldn't listen on :68: %w", err)
|
||||
|
||||
@@ -1,387 +0,0 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type onChangedT func()
|
||||
|
||||
// EtcHostsContainer - automatic DNS records
|
||||
//
|
||||
// TODO(e.burkov): Move the logic under interface. Refactor. Probably remove
|
||||
// the resolving logic.
|
||||
type EtcHostsContainer struct {
|
||||
// lock protects table and tableReverse.
|
||||
lock sync.RWMutex
|
||||
// table is the host-to-IPs map.
|
||||
table map[string][]net.IP
|
||||
// tableReverse is the IP-to-hosts map. The type of the values in the
|
||||
// map is []string.
|
||||
tableReverse *netutil.IPMap
|
||||
|
||||
hostsFn string // path to the main hosts-file
|
||||
hostsDirs []string // paths to OS-specific directories with hosts-files
|
||||
watcher *fsnotify.Watcher // file and directory watcher object
|
||||
|
||||
// onlyWritesChan used to contain only writing events from watcher.
|
||||
onlyWritesChan chan fsnotify.Event
|
||||
|
||||
onChanged onChangedT // notification to other modules
|
||||
}
|
||||
|
||||
// SetOnChanged - set callback function that will be called when the data is changed
|
||||
func (ehc *EtcHostsContainer) SetOnChanged(onChanged onChangedT) {
|
||||
ehc.onChanged = onChanged
|
||||
}
|
||||
|
||||
// Notify other modules
|
||||
func (ehc *EtcHostsContainer) notify() {
|
||||
if ehc.onChanged == nil {
|
||||
return
|
||||
}
|
||||
ehc.onChanged()
|
||||
}
|
||||
|
||||
// Init - initialize
|
||||
// hostsFn: Override default name for the hosts-file (optional)
|
||||
func (ehc *EtcHostsContainer) Init(hostsFn string) {
|
||||
ehc.table = make(map[string][]net.IP)
|
||||
ehc.onlyWritesChan = make(chan fsnotify.Event, 2)
|
||||
|
||||
ehc.hostsFn = "/etc/hosts"
|
||||
if runtime.GOOS == "windows" {
|
||||
ehc.hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
|
||||
}
|
||||
if len(hostsFn) != 0 {
|
||||
ehc.hostsFn = hostsFn
|
||||
}
|
||||
|
||||
if aghos.IsOpenWrt() {
|
||||
// OpenWrt: "/tmp/hosts/dhcp.cfg01411c".
|
||||
ehc.hostsDirs = append(ehc.hostsDirs, "/tmp/hosts")
|
||||
}
|
||||
|
||||
// Load hosts initially
|
||||
ehc.updateHosts()
|
||||
|
||||
var err error
|
||||
ehc.watcher, err = fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start - start module
|
||||
func (ehc *EtcHostsContainer) Start() {
|
||||
if ehc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("Start etchostscontainer module")
|
||||
|
||||
ehc.updateHosts()
|
||||
|
||||
if ehc.watcher != nil {
|
||||
go ehc.watcherLoop()
|
||||
|
||||
err := ehc.watcher.Add(ehc.hostsFn)
|
||||
if err != nil {
|
||||
log.Error("Error while initializing watcher for a file %s: %s", ehc.hostsFn, err)
|
||||
}
|
||||
|
||||
for _, dir := range ehc.hostsDirs {
|
||||
err = ehc.watcher.Add(dir)
|
||||
if err != nil {
|
||||
log.Error("Error while initializing watcher for a directory %s: %s", dir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close - close module
|
||||
func (ehc *EtcHostsContainer) Close() {
|
||||
if ehc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ehc.watcher != nil {
|
||||
_ = ehc.watcher.Close()
|
||||
}
|
||||
|
||||
// Don't close onlyWritesChan here and let onlyWrites close it after
|
||||
// watcher.Events is closed to prevent close races.
|
||||
}
|
||||
|
||||
// Process returns the list of IP addresses for the hostname or nil if nothing
|
||||
// found.
|
||||
func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP {
|
||||
if qtype == dns.TypePTR {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ipsCopy []net.IP
|
||||
ehc.lock.RLock()
|
||||
defer ehc.lock.RUnlock()
|
||||
|
||||
if ips, ok := ehc.table[host]; ok {
|
||||
ipsCopy = make([]net.IP, len(ips))
|
||||
copy(ipsCopy, ips)
|
||||
}
|
||||
|
||||
log.Debug("etchosts: answer: %s -> %v", host, ipsCopy)
|
||||
return ipsCopy
|
||||
}
|
||||
|
||||
// ProcessReverse processes a PTR request. It returns nil if nothing is found.
|
||||
func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts []string) {
|
||||
if qtype != dns.TypePTR {
|
||||
return nil
|
||||
}
|
||||
|
||||
ip, err := netutil.IPFromReversedAddr(addr)
|
||||
if err != nil {
|
||||
log.Error("etchosts: reversed addr: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
ehc.lock.RLock()
|
||||
defer ehc.lock.RUnlock()
|
||||
|
||||
v, ok := ehc.tableReverse.Get(ip)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
hosts, ok = v.([]string)
|
||||
if !ok {
|
||||
log.Error("etchosts: bad type %T in tableReverse for %s", v, ip)
|
||||
|
||||
return nil
|
||||
} else if len(hosts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts)
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
// List returns an IP-to-hostnames table. The type of the values in the map is
|
||||
// []string. It is safe for concurrent use.
|
||||
func (ehc *EtcHostsContainer) List() (ipToHosts *netutil.IPMap) {
|
||||
ehc.lock.RLock()
|
||||
defer ehc.lock.RUnlock()
|
||||
|
||||
return ehc.tableReverse.ShallowClone()
|
||||
}
|
||||
|
||||
// update table
|
||||
func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string, ipAddr net.IP) {
|
||||
ips, ok := table[host]
|
||||
if ok {
|
||||
for _, ip := range ips {
|
||||
if ip.Equal(ipAddr) {
|
||||
// IP already exists: don't add duplicates
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
ips = append(ips, ipAddr)
|
||||
table[host] = ips
|
||||
}
|
||||
} else {
|
||||
table[host] = []net.IP{ipAddr}
|
||||
ok = true
|
||||
}
|
||||
if ok {
|
||||
log.Debug("etchosts: added %s -> %s", ipAddr, host)
|
||||
}
|
||||
}
|
||||
|
||||
// updateTableRev updates the reverse address table.
|
||||
func (ehc *EtcHostsContainer) updateTableRev(tableRev *netutil.IPMap, newHost string, ip net.IP) {
|
||||
v, ok := tableRev.Get(ip)
|
||||
if !ok {
|
||||
tableRev.Set(ip, []string{newHost})
|
||||
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hosts, _ := v.([]string)
|
||||
for _, host := range hosts {
|
||||
if host == newHost {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
hosts = append(hosts, newHost)
|
||||
tableRev.Set(ip, hosts)
|
||||
|
||||
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
||||
}
|
||||
|
||||
// parseHostsLine parses hosts from the fields.
|
||||
func parseHostsLine(fields []string) (hosts []string) {
|
||||
for _, f := range fields {
|
||||
hashIdx := strings.IndexByte(f, '#')
|
||||
if hashIdx == 0 {
|
||||
// The rest of the fields are a part of the comment.
|
||||
// Skip immediately.
|
||||
return
|
||||
} else if hashIdx > 0 {
|
||||
// Only a part of the field is a comment.
|
||||
hosts = append(hosts, f[:hashIdx])
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
hosts = append(hosts, f)
|
||||
}
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
// load reads IP-hostname pairs from the hosts file. Multiple hostnames per
|
||||
// line for one IP are supported.
|
||||
func (ehc *EtcHostsContainer) load(
|
||||
table map[string][]net.IP,
|
||||
tableRev *netutil.IPMap,
|
||||
fn string,
|
||||
) {
|
||||
f, err := os.Open(fn)
|
||||
if err != nil {
|
||||
log.Error("etchosts: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
derr := f.Close()
|
||||
if derr != nil {
|
||||
log.Error("etchosts: closing file: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debug("etchosts: loading hosts from file %s", fn)
|
||||
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
line := strings.TrimSpace(s.Text())
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
ip := net.ParseIP(fields[0])
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
hosts := parseHostsLine(fields[1:])
|
||||
for _, host := range hosts {
|
||||
ehc.updateTable(table, host, ip)
|
||||
ehc.updateTableRev(tableRev, host, ip)
|
||||
}
|
||||
}
|
||||
|
||||
err = s.Err()
|
||||
if err != nil {
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// onlyWrites is a filter for (*fsnotify.Watcher).Events.
|
||||
func (ehc *EtcHostsContainer) onlyWrites() {
|
||||
for event := range ehc.watcher.Events {
|
||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||
ehc.onlyWritesChan <- event
|
||||
}
|
||||
}
|
||||
|
||||
close(ehc.onlyWritesChan)
|
||||
}
|
||||
|
||||
// Receive notifications from fsnotify package
|
||||
func (ehc *EtcHostsContainer) watcherLoop() {
|
||||
go ehc.onlyWrites()
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-ehc.onlyWritesChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Assume that we sometimes have the same event occurred
|
||||
// several times.
|
||||
repeat := true
|
||||
for repeat {
|
||||
select {
|
||||
case _, ok = <-ehc.onlyWritesChan:
|
||||
repeat = ok
|
||||
default:
|
||||
repeat = false
|
||||
}
|
||||
}
|
||||
|
||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||
log.Debug("etchosts: modified: %s", event.Name)
|
||||
ehc.updateHosts()
|
||||
}
|
||||
|
||||
case err, ok := <-ehc.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateHosts - loads system hosts
|
||||
func (ehc *EtcHostsContainer) updateHosts() {
|
||||
table := make(map[string][]net.IP)
|
||||
tableRev := netutil.NewIPMap(0)
|
||||
|
||||
ehc.load(table, tableRev, ehc.hostsFn)
|
||||
|
||||
for _, dir := range ehc.hostsDirs {
|
||||
des, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.Error("etchosts: Opening directory: %q: %s", dir, err)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
for _, de := range des {
|
||||
ehc.load(table, tableRev, filepath.Join(dir, de.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
func() {
|
||||
ehc.lock.Lock()
|
||||
defer ehc.lock.Unlock()
|
||||
|
||||
ehc.table = table
|
||||
ehc.tableReverse = tableRev
|
||||
}()
|
||||
|
||||
ehc.notify()
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func prepareTestFile(t *testing.T) (f *os.File) {
|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
f, err := os.CreateTemp(dir, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, f)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, f.Close())
|
||||
})
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func assertWriting(t *testing.T, f *os.File, strs ...string) {
|
||||
t.Helper()
|
||||
|
||||
for _, str := range strs {
|
||||
n, err := f.WriteString(str)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, n, len(str))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEtcHostsContainerResolution(t *testing.T) {
|
||||
ehc := &EtcHostsContainer{}
|
||||
|
||||
f := prepareTestFile(t)
|
||||
|
||||
assertWriting(t, f,
|
||||
" 127.0.0.1 host localhost # comment \n",
|
||||
" ::1 localhost#comment \n",
|
||||
)
|
||||
ehc.Init(f.Name())
|
||||
|
||||
t.Run("existing_host", func(t *testing.T) {
|
||||
ips := ehc.Process("localhost", dns.TypeA)
|
||||
require.Len(t, ips, 1)
|
||||
assert.Equal(t, net.IPv4(127, 0, 0, 1), ips[0])
|
||||
})
|
||||
|
||||
t.Run("unknown_host", func(t *testing.T) {
|
||||
ips := ehc.Process("newhost", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
|
||||
// Comment.
|
||||
ips = ehc.Process("comment", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
})
|
||||
|
||||
t.Run("hosts_file", func(t *testing.T) {
|
||||
names, ok := ehc.List().Get(net.IP{127, 0, 0, 1})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, []string{"host", "localhost"}, names)
|
||||
})
|
||||
|
||||
t.Run("ptr", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
wantIP string
|
||||
wantHost string
|
||||
wantLen int
|
||||
}{
|
||||
{wantIP: "127.0.0.1", wantHost: "host", wantLen: 2},
|
||||
{wantIP: "::1", wantHost: "localhost", wantLen: 1},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
a, err := dns.ReverseAddr(tc.wantIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
a = strings.TrimSuffix(a, ".")
|
||||
hosts := ehc.ProcessReverse(a, dns.TypePTR)
|
||||
require.Len(t, hosts, tc.wantLen)
|
||||
assert.Equal(t, tc.wantHost, hosts[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEtcHostsContainerFSNotify(t *testing.T) {
|
||||
ehc := &EtcHostsContainer{}
|
||||
|
||||
f := prepareTestFile(t)
|
||||
|
||||
assertWriting(t, f, " 127.0.0.1 host localhost \n")
|
||||
ehc.Init(f.Name())
|
||||
|
||||
t.Run("unknown_host", func(t *testing.T) {
|
||||
ips := ehc.Process("newhost", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
})
|
||||
|
||||
// Start monitoring for changes.
|
||||
ehc.Start()
|
||||
t.Cleanup(ehc.Close)
|
||||
|
||||
assertWriting(t, f, "127.0.0.2 newhost\n")
|
||||
require.NoError(t, f.Sync())
|
||||
|
||||
// Wait until fsnotify has triggered and processed the file-modification
|
||||
// event.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
t.Run("notified", func(t *testing.T) {
|
||||
ips := ehc.Process("newhost", dns.TypeA)
|
||||
assert.NotNil(t, ips)
|
||||
require.Len(t, ips, 1)
|
||||
assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0]))
|
||||
})
|
||||
}
|
||||
504
internal/aghnet/hostscontainer.go
Normal file
504
internal/aghnet/hostscontainer.go
Normal file
@@ -0,0 +1,504 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// DefaultHostsPaths returns the slice of paths default for the operating system
|
||||
// to files and directories which are containing the hosts database. The result
|
||||
// is intended to be used within fs.FS so the initial slash is omitted.
|
||||
func DefaultHostsPaths() (paths []string) {
|
||||
return defaultHostsPaths()
|
||||
}
|
||||
|
||||
// requestMatcher combines the logic for matching requests and translating the
|
||||
// appropriate rules.
|
||||
type requestMatcher struct {
|
||||
// stateLock protects all the fields of requestMatcher.
|
||||
stateLock *sync.RWMutex
|
||||
|
||||
// rulesStrg stores the rules obtained from the hosts' file.
|
||||
rulesStrg *filterlist.RuleStorage
|
||||
// engine serves rulesStrg.
|
||||
engine *urlfilter.DNSEngine
|
||||
|
||||
// translator maps generated $dnsrewrite rules into hosts-syntax rules.
|
||||
//
|
||||
// TODO(e.burkov): Store the filename from which the rule was parsed.
|
||||
translator map[string]string
|
||||
}
|
||||
|
||||
// MatchRequest processes the request rewriting hostnames and addresses read
|
||||
// from the operating system's hosts files. res is nil for any request having
|
||||
// not an A/AAAA or PTR type, see man 5 hosts.
|
||||
//
|
||||
// It's safe for concurrent use.
|
||||
func (rm *requestMatcher) MatchRequest(
|
||||
req urlfilter.DNSRequest,
|
||||
) (res *urlfilter.DNSResult, ok bool) {
|
||||
switch req.DNSType {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
|
||||
log.Debug("%s: handling the request", hostsContainerPref)
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
rm.stateLock.RLock()
|
||||
defer rm.stateLock.RUnlock()
|
||||
|
||||
return rm.engine.MatchRequest(req)
|
||||
}
|
||||
|
||||
// Translate returns the source hosts-syntax rule for the generated dnsrewrite
|
||||
// rule or an empty string if the last doesn't exist. The returned rules are in
|
||||
// a processed format like:
|
||||
//
|
||||
// ip host1 host2 ...
|
||||
//
|
||||
func (rm *requestMatcher) Translate(rule string) (hostRule string) {
|
||||
rm.stateLock.RLock()
|
||||
defer rm.stateLock.RUnlock()
|
||||
|
||||
return rm.translator[rule]
|
||||
}
|
||||
|
||||
// resetEng updates container's engine and the translation map.
|
||||
func (rm *requestMatcher) resetEng(rulesStrg *filterlist.RuleStorage, tr map[string]string) {
|
||||
rm.stateLock.Lock()
|
||||
defer rm.stateLock.Unlock()
|
||||
|
||||
rm.rulesStrg = rulesStrg
|
||||
rm.engine = urlfilter.NewDNSEngine(rm.rulesStrg)
|
||||
|
||||
rm.translator = tr
|
||||
}
|
||||
|
||||
// hostsContainerPref is a prefix for logging and wrapping errors in
|
||||
// HostsContainer's methods.
|
||||
const hostsContainerPref = "hosts container"
|
||||
|
||||
// HostsContainer stores the relevant hosts database provided by the OS and
|
||||
// processes both A/AAAA and PTR DNS requests for those.
|
||||
type HostsContainer struct {
|
||||
// requestMatcher matches the requests and translates the rules. It's
|
||||
// embedded to implement MatchRequest and Translate for *HostsContainer.
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): Consider fully merging into HostsContainer.
|
||||
requestMatcher
|
||||
|
||||
// done is the channel to sign closing the container.
|
||||
done chan struct{}
|
||||
|
||||
// updates is the channel for receiving updated hosts.
|
||||
updates chan *netutil.IPMap
|
||||
|
||||
// last is the set of hosts that was cached within last detected change.
|
||||
last *netutil.IPMap
|
||||
|
||||
// fsys is the working file system to read hosts files from.
|
||||
fsys fs.FS
|
||||
|
||||
// w tracks the changes in specified files and directories.
|
||||
w aghos.FSWatcher
|
||||
|
||||
// patterns stores specified paths in the fs.Glob-compatible form.
|
||||
patterns []string
|
||||
|
||||
// listID is the identifier for the list of generated rules.
|
||||
listID int
|
||||
}
|
||||
|
||||
// ErrNoHostsPaths is returned when there are no valid paths to watch passed to
|
||||
// the HostsContainer.
|
||||
const ErrNoHostsPaths errors.Error = "no valid paths to hosts files provided"
|
||||
|
||||
// NewHostsContainer creates a container of hosts, that watches the paths with
|
||||
// w. listID is used as an identifier of the underlying rules list. paths
|
||||
// shouldn't be empty and each of paths should locate either a file or a
|
||||
// directory in fsys. fsys and w must be non-nil.
|
||||
func NewHostsContainer(
|
||||
listID int,
|
||||
fsys fs.FS,
|
||||
w aghos.FSWatcher,
|
||||
paths ...string,
|
||||
) (hc *HostsContainer, err error) {
|
||||
defer func() { err = errors.Annotate(err, "%s: %w", hostsContainerPref) }()
|
||||
|
||||
if len(paths) == 0 {
|
||||
return nil, ErrNoHostsPaths
|
||||
}
|
||||
|
||||
var patterns []string
|
||||
patterns, err = pathsToPatterns(fsys, paths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(patterns) == 0 {
|
||||
return nil, ErrNoHostsPaths
|
||||
}
|
||||
|
||||
hc = &HostsContainer{
|
||||
requestMatcher: requestMatcher{
|
||||
stateLock: &sync.RWMutex{},
|
||||
},
|
||||
listID: listID,
|
||||
done: make(chan struct{}, 1),
|
||||
updates: make(chan *netutil.IPMap, 1),
|
||||
fsys: fsys,
|
||||
w: w,
|
||||
patterns: patterns,
|
||||
}
|
||||
|
||||
log.Debug("%s: starting", hostsContainerPref)
|
||||
|
||||
// Load initially.
|
||||
if err = hc.refresh(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
if err = w.Add(p); err != nil {
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, fmt.Errorf("adding path: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("%s: %s is expected to exist but doesn't", hostsContainerPref, p)
|
||||
}
|
||||
}
|
||||
|
||||
go hc.handleEvents()
|
||||
|
||||
return hc, nil
|
||||
}
|
||||
|
||||
// Close implements the io.Closer interface for *HostsContainer. Close must
|
||||
// only be called once. The returned err is always nil.
|
||||
func (hc *HostsContainer) Close() (err error) {
|
||||
log.Debug("%s: closing", hostsContainerPref)
|
||||
|
||||
close(hc.done)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Upd returns the channel into which the updates are sent. The receivable
|
||||
// map's values are guaranteed to be of type of *stringutil.Set.
|
||||
func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) {
|
||||
return hc.updates
|
||||
}
|
||||
|
||||
// pathsToPatterns converts paths into patterns compatible with fs.Glob.
|
||||
func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error) {
|
||||
for i, p := range paths {
|
||||
var fi fs.FileInfo
|
||||
fi, err = fs.Stat(fsys, p)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't put a filename here since it's already added by fs.Stat.
|
||||
return nil, fmt.Errorf("path at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
if fi.IsDir() {
|
||||
p = path.Join(p, "*")
|
||||
}
|
||||
|
||||
patterns = append(patterns, p)
|
||||
}
|
||||
|
||||
return patterns, nil
|
||||
}
|
||||
|
||||
// handleEvents concurrently handles the file system events. It closes the
|
||||
// update channel of HostsContainer when finishes. It's used to be called
|
||||
// within a separate goroutine.
|
||||
func (hc *HostsContainer) handleEvents() {
|
||||
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPref))
|
||||
|
||||
defer close(hc.updates)
|
||||
|
||||
ok, eventsCh := true, hc.w.Events()
|
||||
for ok {
|
||||
select {
|
||||
case _, ok = <-eventsCh:
|
||||
if !ok {
|
||||
log.Debug("%s: watcher closed the events channel", hostsContainerPref)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if err := hc.refresh(); err != nil {
|
||||
log.Error("%s: %s", hostsContainerPref, err)
|
||||
}
|
||||
case _, ok = <-hc.done:
|
||||
// Go on.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hostsParser is a helper type to parse rules from the operating system's hosts
|
||||
// file. It exists for only a single refreshing session.
|
||||
type hostsParser struct {
|
||||
// rulesBuilder builds the resulting rules list content.
|
||||
rulesBuilder *strings.Builder
|
||||
|
||||
// translations maps generated rules into actual hosts file lines.
|
||||
translations map[string]string
|
||||
|
||||
// table stores only the unique IP-hostname pairs. It's also sent to the
|
||||
// updates channel afterwards.
|
||||
table *netutil.IPMap
|
||||
}
|
||||
|
||||
// newHostsParser creates a new *hostsParser with buffers of size taken from the
|
||||
// previous parse.
|
||||
func (hc *HostsContainer) newHostsParser() (hp *hostsParser) {
|
||||
return &hostsParser{
|
||||
rulesBuilder: &strings.Builder{},
|
||||
translations: map[string]string{},
|
||||
table: netutil.NewIPMap(hc.last.Len()),
|
||||
}
|
||||
}
|
||||
|
||||
// parseFile is a aghos.FileWalker for parsing the files with hosts syntax. It
|
||||
// never signs to stop walking and never returns any additional patterns.
|
||||
//
|
||||
// See man hosts(5).
|
||||
func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
ip, hosts := hp.parseLine(s.Text())
|
||||
if ip == nil || len(hosts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
hp.addPairs(ip, hosts)
|
||||
}
|
||||
|
||||
return nil, true, s.Err()
|
||||
}
|
||||
|
||||
// parseLine parses the line having the hosts syntax ignoring invalid ones.
|
||||
func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if ip = net.ParseIP(fields[0]); ip == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, f := range fields[1:] {
|
||||
hashIdx := strings.IndexByte(f, '#')
|
||||
if hashIdx == 0 {
|
||||
// The rest of the fields are a part of the comment so return.
|
||||
break
|
||||
} else if hashIdx > 0 {
|
||||
// Only a part of the field is a comment.
|
||||
f = f[:hashIdx]
|
||||
}
|
||||
|
||||
// Make sure that invalid hosts aren't turned into rules.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/3946.
|
||||
//
|
||||
// TODO(e.burkov): Investigate if hosts may contain DNS-SD domains.
|
||||
err := netutil.ValidateDomainName(f)
|
||||
if err != nil {
|
||||
log.Error("%s: host %q is invalid, ignoring", hostsContainerPref, f)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
hosts = append(hosts, f)
|
||||
}
|
||||
|
||||
return ip, hosts
|
||||
}
|
||||
|
||||
// addPair puts the pair of ip and host to the rules builder if needed. For
|
||||
// each ip the first member of hosts will become the main one.
|
||||
func (hp *hostsParser) addPairs(ip net.IP, hosts []string) {
|
||||
v, ok := hp.table.Get(ip)
|
||||
if !ok {
|
||||
// This ip is added at the first time.
|
||||
v = stringutil.NewSet()
|
||||
hp.table.Set(ip, v)
|
||||
}
|
||||
|
||||
var set *stringutil.Set
|
||||
set, ok = v.(*stringutil.Set)
|
||||
if !ok {
|
||||
log.Debug("%s: adding pairs: unexpected value type %T", hostsContainerPref, v)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
processed := strings.Join(append([]string{ip.String()}, hosts...), " ")
|
||||
for _, h := range hosts {
|
||||
if set.Has(h) {
|
||||
continue
|
||||
}
|
||||
|
||||
set.Add(h)
|
||||
|
||||
rule, rulePtr := hp.writeRules(h, ip)
|
||||
hp.translations[rule], hp.translations[rulePtr] = processed, processed
|
||||
|
||||
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, h)
|
||||
}
|
||||
}
|
||||
|
||||
// writeRules writes the actual rule for the qtype and the PTR for the
|
||||
// host-ip pair into internal builders.
|
||||
func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) {
|
||||
arpa, err := netutil.IPToReversedAddr(ip)
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
const (
|
||||
nl = "\n"
|
||||
|
||||
rwSuccess = "^$dnsrewrite=NOERROR;"
|
||||
rwSuccessPTR = "^$dnsrewrite=NOERROR;PTR;"
|
||||
|
||||
modLen = len(rules.MaskPipe) + len(rwSuccess) + len(";")
|
||||
modLenPTR = len(rules.MaskPipe) + len(rwSuccessPTR)
|
||||
)
|
||||
|
||||
var qtype string
|
||||
// The validation of the IP address has been performed earlier so it is
|
||||
// guaranteed to be either an IPv4 or an IPv6.
|
||||
if ip.To4() != nil {
|
||||
qtype = "A"
|
||||
} else {
|
||||
qtype = "AAAA"
|
||||
}
|
||||
|
||||
ipStr := ip.String()
|
||||
fqdn := dns.Fqdn(host)
|
||||
|
||||
ruleBuilder := &strings.Builder{}
|
||||
ruleBuilder.Grow(modLen + len(host) + len(qtype) + len(ipStr))
|
||||
stringutil.WriteToBuilder(ruleBuilder, rules.MaskPipe, host, rwSuccess, qtype, ";", ipStr)
|
||||
rule = ruleBuilder.String()
|
||||
|
||||
ruleBuilder.Reset()
|
||||
|
||||
ruleBuilder.Grow(modLenPTR + len(arpa) + len(fqdn))
|
||||
stringutil.WriteToBuilder(ruleBuilder, rules.MaskPipe, arpa, rwSuccessPTR, fqdn)
|
||||
|
||||
rulePtr = ruleBuilder.String()
|
||||
|
||||
hp.rulesBuilder.Grow(len(rule) + len(rulePtr) + 2*len(nl))
|
||||
stringutil.WriteToBuilder(hp.rulesBuilder, rule, nl, rulePtr, nl)
|
||||
|
||||
return rule, rulePtr
|
||||
}
|
||||
|
||||
// equalSet returns true if the internal hosts table just parsed equals target.
|
||||
func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) {
|
||||
if target == nil {
|
||||
// hp.table shouldn't appear nil since it's initialized on each refresh.
|
||||
return target == hp.table
|
||||
}
|
||||
|
||||
if hp.table.Len() != target.Len() {
|
||||
return false
|
||||
}
|
||||
|
||||
hp.table.Range(func(ip net.IP, b interface{}) (cont bool) {
|
||||
// ok is set to true if the target doesn't contain ip or if the
|
||||
// appropriate hosts set isn't equal to the checked one.
|
||||
if a, hasIP := target.Get(ip); !hasIP {
|
||||
ok = true
|
||||
} else if hosts, aok := a.(*stringutil.Set); aok {
|
||||
ok = !hosts.Equal(b.(*stringutil.Set))
|
||||
}
|
||||
|
||||
// Continue only if maps has no discrepancies.
|
||||
return !ok
|
||||
})
|
||||
|
||||
// Return true if every value from the IP map has no discrepancies with the
|
||||
// appropriate one from the target.
|
||||
return !ok
|
||||
}
|
||||
|
||||
// sendUpd tries to send the parsed data to the ch.
|
||||
func (hp *hostsParser) sendUpd(ch chan *netutil.IPMap) {
|
||||
log.Debug("%s: sending upd", hostsContainerPref)
|
||||
|
||||
upd := hp.table
|
||||
select {
|
||||
case ch <- upd:
|
||||
// Updates are delivered. Go on.
|
||||
case <-ch:
|
||||
ch <- upd
|
||||
log.Debug("%s: replaced the last update", hostsContainerPref)
|
||||
case ch <- upd:
|
||||
// The previous update was just read and the next one pushed. Go on.
|
||||
default:
|
||||
log.Error("%s: the updates channel is broken", hostsContainerPref)
|
||||
}
|
||||
}
|
||||
|
||||
// newStrg creates a new rules storage from parsed data.
|
||||
func (hp *hostsParser) newStrg(id int) (s *filterlist.RuleStorage, err error) {
|
||||
return filterlist.NewRuleStorage([]filterlist.RuleList{&filterlist.StringRuleList{
|
||||
ID: id,
|
||||
RulesText: hp.rulesBuilder.String(),
|
||||
IgnoreCosmetic: true,
|
||||
}})
|
||||
}
|
||||
|
||||
// refresh gets the data from specified files and propagates the updates if
|
||||
// needed.
|
||||
//
|
||||
// TODO(e.burkov): Accept a parameter to specify the files to refresh.
|
||||
func (hc *HostsContainer) refresh() (err error) {
|
||||
log.Debug("%s: refreshing", hostsContainerPref)
|
||||
|
||||
hp := hc.newHostsParser()
|
||||
if _, err = aghos.FileWalker(hp.parseFile).Walk(hc.fsys, hc.patterns...); err != nil {
|
||||
return fmt.Errorf("refreshing : %w", err)
|
||||
}
|
||||
|
||||
if hp.equalSet(hc.last) {
|
||||
log.Debug("%s: no changes detected", hostsContainerPref)
|
||||
|
||||
return nil
|
||||
}
|
||||
defer hp.sendUpd(hc.updates)
|
||||
|
||||
hc.last = hp.table.ShallowClone()
|
||||
|
||||
var rulesStrg *filterlist.RuleStorage
|
||||
if rulesStrg, err = hp.newStrg(hc.listID); err != nil {
|
||||
return fmt.Errorf("initializing rules storage: %w", err)
|
||||
}
|
||||
|
||||
hc.resetEng(rulesStrg, hp.translations)
|
||||
|
||||
return nil
|
||||
}
|
||||
18
internal/aghnet/hostscontainer_linux.go
Normal file
18
internal/aghnet/hostscontainer_linux.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
)
|
||||
|
||||
func defaultHostsPaths() (paths []string) {
|
||||
paths = []string{"etc/hosts"}
|
||||
|
||||
if aghos.IsOpenWrt() {
|
||||
paths = append(paths, "tmp/hosts")
|
||||
}
|
||||
|
||||
return paths
|
||||
}
|
||||
8
internal/aghnet/hostscontainer_others.go
Normal file
8
internal/aghnet/hostscontainer_others.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !(windows || linux)
|
||||
// +build !windows,!linux
|
||||
|
||||
package aghnet
|
||||
|
||||
func defaultHostsPaths() (paths []string) {
|
||||
return []string{"etc/hosts"}
|
||||
}
|
||||
612
internal/aghnet/hostscontainer_test.go
Normal file
612
internal/aghnet/hostscontainer_test.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
nl = "\n"
|
||||
sp = " "
|
||||
)
|
||||
|
||||
func TestNewHostsContainer(t *testing.T) {
|
||||
const dirname = "dir"
|
||||
const filename = "file1"
|
||||
|
||||
p := path.Join(dirname, filename)
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
p: &fstest.MapFile{Data: []byte("127.0.0.1 localhost")},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
wantErr error
|
||||
name string
|
||||
paths []string
|
||||
}{{
|
||||
wantErr: nil,
|
||||
name: "one_file",
|
||||
paths: []string{p},
|
||||
}, {
|
||||
wantErr: ErrNoHostsPaths,
|
||||
name: "no_files",
|
||||
paths: []string{},
|
||||
}, {
|
||||
wantErr: ErrNoHostsPaths,
|
||||
name: "non-existent_file",
|
||||
paths: []string{path.Join(dirname, filename+"2")},
|
||||
}, {
|
||||
wantErr: nil,
|
||||
name: "whole_dir",
|
||||
paths: []string{dirname},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
onAdd := func(name string) (err error) {
|
||||
assert.Contains(t, tc.paths, name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var eventsCalledCounter uint32
|
||||
eventsCh := make(chan struct{})
|
||||
onEvents := func() (e <-chan struct{}) {
|
||||
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
||||
|
||||
return eventsCh
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(0, testFS, &aghtest.FSWatcher{
|
||||
OnEvents: onEvents,
|
||||
OnAdd: onAdd,
|
||||
OnClose: func() (err error) { panic("not implemented") },
|
||||
}, tc.paths...)
|
||||
if tc.wantErr != nil {
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
|
||||
assert.Nil(t, hc)
|
||||
|
||||
return
|
||||
}
|
||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, hc)
|
||||
|
||||
assert.NotNil(t, <-hc.Upd())
|
||||
|
||||
eventsCh <- struct{}{}
|
||||
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("nil_fs", func(t *testing.T) {
|
||||
require.Panics(t, func() {
|
||||
_, _ = NewHostsContainer(0, nil, &aghtest.FSWatcher{
|
||||
// Those shouldn't panic.
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
OnAdd: func(name string) (err error) { return nil },
|
||||
OnClose: func() (err error) { return nil },
|
||||
}, p)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("nil_watcher", func(t *testing.T) {
|
||||
require.Panics(t, func() {
|
||||
_, _ = NewHostsContainer(0, testFS, nil, p)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("err_watcher", func(t *testing.T) {
|
||||
const errOnAdd errors.Error = "error"
|
||||
|
||||
errWatcher := &aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) { panic("not implemented") },
|
||||
OnAdd: func(name string) (err error) { return errOnAdd },
|
||||
OnClose: func() (err error) { panic("not implemented") },
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(0, testFS, errWatcher, p)
|
||||
require.ErrorIs(t, err, errOnAdd)
|
||||
|
||||
assert.Nil(t, hc)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHostsContainer_refresh(t *testing.T) {
|
||||
// TODO(e.burkov): Test the case with no actual updates.
|
||||
|
||||
ip := net.IP{127, 0, 0, 1}
|
||||
ipStr := ip.String()
|
||||
|
||||
testFS := fstest.MapFS{"dir/file1": &fstest.MapFile{Data: []byte(ipStr + ` hostname` + nl)}}
|
||||
|
||||
// event is a convenient alias for an empty struct{} to emit test events.
|
||||
type event = struct{}
|
||||
|
||||
eventsCh := make(chan event, 1)
|
||||
t.Cleanup(func() { close(eventsCh) })
|
||||
|
||||
w := &aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan event) { return eventsCh },
|
||||
OnAdd: func(name string) (err error) {
|
||||
assert.Equal(t, "dir", name)
|
||||
|
||||
return nil
|
||||
},
|
||||
OnClose: func() (err error) { panic("not implemented") },
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(0, testFS, w, "dir")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||
|
||||
checkRefresh := func(t *testing.T, wantHosts *stringutil.Set) {
|
||||
upd, ok := <-hc.Upd()
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upd)
|
||||
|
||||
assert.Equal(t, 1, upd.Len())
|
||||
|
||||
v, ok := upd.Get(ip)
|
||||
require.True(t, ok)
|
||||
|
||||
var set *stringutil.Set
|
||||
set, ok = v.(*stringutil.Set)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.True(t, set.Equal(wantHosts))
|
||||
}
|
||||
|
||||
t.Run("initial_refresh", func(t *testing.T) {
|
||||
checkRefresh(t, stringutil.NewSet("hostname"))
|
||||
})
|
||||
|
||||
t.Run("second_refresh", func(t *testing.T) {
|
||||
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
|
||||
eventsCh <- event{}
|
||||
checkRefresh(t, stringutil.NewSet("hostname", "alias"))
|
||||
})
|
||||
|
||||
t.Run("double_refresh", func(t *testing.T) {
|
||||
// Make a change once.
|
||||
testFS["dir/file1"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
|
||||
eventsCh <- event{}
|
||||
|
||||
// Require the changes are written.
|
||||
require.Eventually(t, func() bool {
|
||||
res, ok := hc.MatchRequest(urlfilter.DNSRequest{
|
||||
Hostname: "hostname",
|
||||
DNSType: dns.TypeA,
|
||||
})
|
||||
|
||||
return !ok && res.DNSRewrites() == nil
|
||||
}, 5*time.Second, time.Second/2)
|
||||
|
||||
// Make a change again.
|
||||
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` hostname` + nl)}
|
||||
eventsCh <- event{}
|
||||
|
||||
// Require the changes are written.
|
||||
require.Eventually(t, func() bool {
|
||||
res, ok := hc.MatchRequest(urlfilter.DNSRequest{
|
||||
Hostname: "hostname",
|
||||
DNSType: dns.TypeA,
|
||||
})
|
||||
|
||||
return !ok && res.DNSRewrites() != nil
|
||||
}, 5*time.Second, time.Second/2)
|
||||
|
||||
assert.Len(t, hc.Upd(), 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHostsContainer_PathsToPatterns(t *testing.T) {
|
||||
gsfs := fstest.MapFS{
|
||||
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
|
||||
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
|
||||
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
paths []string
|
||||
want []string
|
||||
}{{
|
||||
name: "no_paths",
|
||||
paths: nil,
|
||||
want: nil,
|
||||
}, {
|
||||
name: "single_file",
|
||||
paths: []string{"dir_0/file_1"},
|
||||
want: []string{"dir_0/file_1"},
|
||||
}, {
|
||||
name: "several_files",
|
||||
paths: []string{"dir_0/file_1", "dir_0/file_2"},
|
||||
want: []string{"dir_0/file_1", "dir_0/file_2"},
|
||||
}, {
|
||||
name: "whole_dir",
|
||||
paths: []string{"dir_0"},
|
||||
want: []string{"dir_0/*"},
|
||||
}, {
|
||||
name: "file_and_dir",
|
||||
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
|
||||
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
|
||||
}, {
|
||||
name: "non-existing",
|
||||
paths: []string{path.Join("dir_0", "file_3")},
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
patterns, err := pathsToPatterns(gsfs, tc.paths)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, patterns)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("bad_file", func(t *testing.T) {
|
||||
const errStat errors.Error = "bad file"
|
||||
|
||||
badFS := &aghtest.StatFS{
|
||||
OnStat: func(name string) (fs.FileInfo, error) {
|
||||
return nil, errStat
|
||||
},
|
||||
}
|
||||
|
||||
_, err := pathsToPatterns(badFS, []string{""})
|
||||
assert.ErrorIs(t, err, errStat)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHostsContainer_Translate(t *testing.T) {
|
||||
testdata := os.DirFS("./testdata")
|
||||
stubWatcher := aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
OnAdd: func(name string) (err error) { return nil },
|
||||
OnClose: func() (err error) { panic("not implemented") },
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
rule string
|
||||
wantTrans []string
|
||||
}{{
|
||||
name: "simplehost",
|
||||
rule: "|simplehost^$dnsrewrite=NOERROR;A;1.0.0.1",
|
||||
wantTrans: []string{"1.0.0.1", "simplehost"},
|
||||
}, {
|
||||
name: "hello",
|
||||
rule: "|hello^$dnsrewrite=NOERROR;A;1.0.0.0",
|
||||
wantTrans: []string{"1.0.0.0", "hello", "hello.world"},
|
||||
}, {
|
||||
name: "hello-alias",
|
||||
rule: "|hello.world.again^$dnsrewrite=NOERROR;A;1.0.0.0",
|
||||
wantTrans: []string{"1.0.0.0", "hello.world.again"},
|
||||
}, {
|
||||
name: "simplehost_v6",
|
||||
rule: "|simplehost^$dnsrewrite=NOERROR;AAAA;::1",
|
||||
wantTrans: []string{"::1", "simplehost"},
|
||||
}, {
|
||||
name: "hello_v6",
|
||||
rule: "|hello^$dnsrewrite=NOERROR;AAAA;::",
|
||||
wantTrans: []string{"::", "hello", "hello.world"},
|
||||
}, {
|
||||
name: "hello_v6-alias",
|
||||
rule: "|hello.world.again^$dnsrewrite=NOERROR;AAAA;::",
|
||||
wantTrans: []string{"::", "hello.world.again"},
|
||||
}, {
|
||||
name: "simplehost_ptr",
|
||||
rule: "|1.0.0.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;simplehost.",
|
||||
wantTrans: []string{"1.0.0.1", "simplehost"},
|
||||
}, {
|
||||
name: "hello_ptr",
|
||||
rule: "|0.0.0.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;hello.",
|
||||
wantTrans: []string{"1.0.0.0", "hello", "hello.world"},
|
||||
}, {
|
||||
name: "hello_ptr-alias",
|
||||
rule: "|0.0.0.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;hello.world.again.",
|
||||
wantTrans: []string{"1.0.0.0", "hello.world.again"},
|
||||
}, {
|
||||
name: "simplehost_ptr_v6",
|
||||
rule: "|1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa" +
|
||||
"^$dnsrewrite=NOERROR;PTR;simplehost.",
|
||||
wantTrans: []string{"::1", "simplehost"},
|
||||
}, {
|
||||
name: "hello_ptr_v6",
|
||||
rule: "|0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa" +
|
||||
"^$dnsrewrite=NOERROR;PTR;hello.",
|
||||
wantTrans: []string{"::", "hello", "hello.world"},
|
||||
}, {
|
||||
name: "hello_ptr_v6-alias",
|
||||
rule: "|0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa" +
|
||||
"^$dnsrewrite=NOERROR;PTR;hello.world.again.",
|
||||
wantTrans: []string{"::", "hello.world.again"},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := stringutil.NewSet(strings.Fields(hc.Translate(tc.rule))...)
|
||||
assert.True(t, stringutil.NewSet(tc.wantTrans...).Equal(got))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostsContainer(t *testing.T) {
|
||||
const listID = 1234
|
||||
|
||||
testdata := os.DirFS("./testdata")
|
||||
|
||||
testCases := []struct {
|
||||
want []*rules.DNSRewrite
|
||||
name string
|
||||
req urlfilter.DNSRequest
|
||||
}{{
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 1),
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.ParseIP("::1"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "simple",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "simplehost",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 0),
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.ParseIP("::"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "hello_alias",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "hello.world",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 0),
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.ParseIP("::"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "other_line_alias",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "hello.world.again",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "hello_subdomain",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "say.hello",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "hello_alias_subdomain",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "say.hello.world",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
Value: net.IPv4(1, 0, 0, 2),
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::2"),
|
||||
}},
|
||||
name: "lots_of_aliases",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "for.testing",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypePTR,
|
||||
Value: "simplehost.",
|
||||
}},
|
||||
name: "reverse",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "1.0.0.1.in-addr.arpa",
|
||||
DNSType: dns.TypePTR,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "non-existing",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "nonexisting",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: nil,
|
||||
name: "bad_type",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "1.0.0.1.in-addr.arpa",
|
||||
DNSType: dns.TypeSRV,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
Value: net.IPv4(4, 2, 1, 6),
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::42"),
|
||||
}},
|
||||
name: "issue_4216_4_6",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "domain",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
Value: net.IPv4(7, 5, 3, 1),
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
Value: net.IPv4(1, 3, 5, 7),
|
||||
}},
|
||||
name: "issue_4216_4",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "domain4",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::13"),
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::31"),
|
||||
}},
|
||||
name: "issue_4216_6",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "domain6",
|
||||
DNSType: dns.TypeAAAA,
|
||||
},
|
||||
}}
|
||||
|
||||
stubWatcher := aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
OnAdd: func(name string) (err error) { return nil },
|
||||
OnClose: func() (err error) { panic("not implemented") },
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res, ok := hc.MatchRequest(tc.req)
|
||||
require.False(t, ok)
|
||||
|
||||
if tc.want == nil {
|
||||
assert.Nil(t, res)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, res)
|
||||
|
||||
rewrites := res.DNSRewrites()
|
||||
require.Len(t, rewrites, len(tc.want))
|
||||
|
||||
for i, rewrite := range rewrites {
|
||||
require.Equal(t, listID, rewrite.FilterListID)
|
||||
|
||||
rw := rewrite.DNSRewrite
|
||||
require.NotNil(t, rw)
|
||||
|
||||
assert.Equal(t, tc.want[i], rw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUniqueRules_ParseLine(t *testing.T) {
|
||||
ip := net.IP{127, 0, 0, 1}
|
||||
ipStr := ip.String()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
line string
|
||||
wantIP net.IP
|
||||
wantHosts []string
|
||||
}{{
|
||||
name: "simple",
|
||||
line: ipStr + ` hostname`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"hostname"},
|
||||
}, {
|
||||
name: "aliases",
|
||||
line: ipStr + ` hostname alias`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"hostname", "alias"},
|
||||
}, {
|
||||
name: "invalid_line",
|
||||
line: ipStr,
|
||||
wantIP: nil,
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "invalid_line_hostname",
|
||||
line: ipStr + ` # hostname`,
|
||||
wantIP: ip,
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "commented_aliases",
|
||||
line: ipStr + ` hostname # alias`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"hostname"},
|
||||
}, {
|
||||
name: "whole_comment",
|
||||
line: `# ` + ipStr + ` hostname`,
|
||||
wantIP: nil,
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "partial_comment",
|
||||
line: ipStr + ` host#name`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"host"},
|
||||
}, {
|
||||
name: "empty",
|
||||
line: ``,
|
||||
wantIP: nil,
|
||||
wantHosts: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
hp := hostsParser{}
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, hosts := hp.parseLine(tc.line)
|
||||
assert.True(t, tc.wantIP.Equal(got))
|
||||
assert.Equal(t, tc.wantHosts, hosts)
|
||||
})
|
||||
}
|
||||
}
|
||||
33
internal/aghnet/hostscontainer_windows.go
Normal file
33
internal/aghnet/hostscontainer_windows.go
Normal file
@@ -0,0 +1,33 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func defaultHostsPaths() (paths []string) {
|
||||
sysDir, err := windows.GetSystemDirectory()
|
||||
if err != nil {
|
||||
log.Error("getting system directory: %s", err)
|
||||
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Split all the elements of the path to join them afterwards. This is
|
||||
// needed to make the Windows-specific path string returned by
|
||||
// windows.GetSystemDirectory to be compatible with fs.FS.
|
||||
pathElems := strings.Split(sysDir, string(os.PathSeparator))
|
||||
if len(pathElems) > 0 && pathElems[0] == filepath.VolumeName(sysDir) {
|
||||
pathElems = pathElems[1:]
|
||||
}
|
||||
|
||||
return []string{path.Join(append(pathElems, "drivers/etc/hosts")...)}
|
||||
}
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// IPVersion is a documentational alias for int. Use it when the integer means
|
||||
// IP version.
|
||||
// IPVersion is a alias for int for documentation purposes. Use it when the
|
||||
// integer means IP version.
|
||||
type IPVersion = int
|
||||
|
||||
// IP version constants.
|
||||
@@ -25,6 +25,13 @@ type NetIface interface {
|
||||
|
||||
// IfaceIPAddrs returns the interface's IP addresses.
|
||||
func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) {
|
||||
switch ipv {
|
||||
case IPVersion4, IPVersion6:
|
||||
// Go on.
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ip version %d", ipv)
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -41,20 +48,16 @@ func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Assume that net.(*Interface).Addrs can only return valid IPv4
|
||||
// and IPv6 addresses. Thus, if it isn't an IPv4 address, it
|
||||
// must be an IPv6 one.
|
||||
switch ipv {
|
||||
case IPVersion4:
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
// Assume that net.(*Interface).Addrs can only return valid IPv4 and
|
||||
// IPv6 addresses. Thus, if it isn't an IPv4 address, it must be an
|
||||
// IPv6 one.
|
||||
ip4 := ip.To4()
|
||||
if ipv == IPVersion4 {
|
||||
if ip4 != nil {
|
||||
ips = append(ips, ip4)
|
||||
}
|
||||
case IPVersion6:
|
||||
if ip6 := ip.To4(); ip6 == nil {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ip version %d", ipv)
|
||||
} else if ip4 == nil {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +70,7 @@ func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) {
|
||||
//
|
||||
// It makes up to maxAttempts attempts to get the addresses if there are none,
|
||||
// each time using the provided backoff. Sometimes an interface needs a few
|
||||
// seconds to really ititialize.
|
||||
// seconds to really initialize.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2304.
|
||||
func IfaceDNSIPAddrs(
|
||||
@@ -92,18 +95,20 @@ func IfaceDNSIPAddrs(
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
n--
|
||||
|
||||
switch len(addrs) {
|
||||
case 0:
|
||||
// Don't return errors in case the users want to try and enable
|
||||
// the DHCP server later.
|
||||
// Don't return errors in case the users want to try and enable the DHCP
|
||||
// server later.
|
||||
t := time.Duration(n) * backoff
|
||||
log.Error("dhcpv%d: no ip for iface after %d attempts and %s", ipv, n, t)
|
||||
|
||||
return nil, nil
|
||||
case 1:
|
||||
// Some Android devices use 8.8.8.8 if there is not a secondary
|
||||
// DNS server. Fix that by setting the secondary DNS address to
|
||||
// the same address.
|
||||
// Some Android devices use 8.8.8.8 if there is not a secondary DNS
|
||||
// server. Fix that by setting the secondary DNS address to the same
|
||||
// address.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/1708.
|
||||
log.Debug("dhcpv%d: setting secondary dns ip to itself", ipv)
|
||||
@@ -116,3 +121,11 @@ func IfaceDNSIPAddrs(
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// interfaceName is a string containing network interface's name. The name is
|
||||
// used in file walking methods.
|
||||
type interfaceName string
|
||||
|
||||
// Use interfaceName in the OS-independent code since it's actually only used in
|
||||
// several OS-dependent implementations which causes linting issues.
|
||||
var _ = interfaceName("")
|
||||
|
||||
@@ -5,13 +5,15 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeIface is a stub implementation of aghnet.NetIface to simplify testing.
|
||||
type fakeIface struct {
|
||||
addrs []net.Addr
|
||||
err error
|
||||
addrs []net.Addr
|
||||
}
|
||||
|
||||
// Addrs implements the NetIface interface for *fakeIface.
|
||||
@@ -33,61 +35,86 @@ func TestIfaceIPAddrs(t *testing.T) {
|
||||
addr6 := &net.IPNet{IP: ip6}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
iface NetIface
|
||||
ipv IPVersion
|
||||
want []net.IP
|
||||
wantErr error
|
||||
iface NetIface
|
||||
name string
|
||||
wantErrMsg string
|
||||
want []net.IP
|
||||
ipv IPVersion
|
||||
}{{
|
||||
name: "ipv4_success",
|
||||
iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil},
|
||||
ipv: IPVersion4,
|
||||
want: []net.IP{ip4},
|
||||
wantErr: nil,
|
||||
iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil},
|
||||
name: "ipv4_success",
|
||||
wantErrMsg: "",
|
||||
want: []net.IP{ip4},
|
||||
ipv: IPVersion4,
|
||||
}, {
|
||||
name: "ipv4_success_with_ipv6",
|
||||
iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil},
|
||||
ipv: IPVersion4,
|
||||
want: []net.IP{ip4},
|
||||
wantErr: nil,
|
||||
iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil},
|
||||
name: "ipv4_success_with_ipv6",
|
||||
wantErrMsg: "",
|
||||
want: []net.IP{ip4},
|
||||
ipv: IPVersion4,
|
||||
}, {
|
||||
name: "ipv4_error",
|
||||
iface: &fakeIface{addrs: []net.Addr{addr4}, err: errTest},
|
||||
ipv: IPVersion4,
|
||||
want: nil,
|
||||
wantErr: errTest,
|
||||
iface: &fakeIface{addrs: []net.Addr{addr4}, err: errTest},
|
||||
name: "ipv4_error",
|
||||
wantErrMsg: errTest.Error(),
|
||||
want: nil,
|
||||
ipv: IPVersion4,
|
||||
}, {
|
||||
name: "ipv6_success",
|
||||
iface: &fakeIface{addrs: []net.Addr{addr6}, err: nil},
|
||||
ipv: IPVersion6,
|
||||
want: []net.IP{ip6},
|
||||
wantErr: nil,
|
||||
iface: &fakeIface{addrs: []net.Addr{addr6}, err: nil},
|
||||
name: "ipv6_success",
|
||||
wantErrMsg: "",
|
||||
want: []net.IP{ip6},
|
||||
ipv: IPVersion6,
|
||||
}, {
|
||||
name: "ipv6_success_with_ipv4",
|
||||
iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil},
|
||||
ipv: IPVersion6,
|
||||
want: []net.IP{ip6},
|
||||
wantErr: nil,
|
||||
iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil},
|
||||
name: "ipv6_success_with_ipv4",
|
||||
wantErrMsg: "",
|
||||
want: []net.IP{ip6},
|
||||
ipv: IPVersion6,
|
||||
}, {
|
||||
name: "ipv6_error",
|
||||
iface: &fakeIface{addrs: []net.Addr{addr6}, err: errTest},
|
||||
ipv: IPVersion6,
|
||||
want: nil,
|
||||
wantErr: errTest,
|
||||
iface: &fakeIface{addrs: []net.Addr{addr6}, err: errTest},
|
||||
name: "ipv6_error",
|
||||
wantErrMsg: errTest.Error(),
|
||||
want: nil,
|
||||
ipv: IPVersion6,
|
||||
}, {
|
||||
iface: &fakeIface{addrs: nil, err: nil},
|
||||
name: "bad_proto",
|
||||
wantErrMsg: "invalid ip version 10",
|
||||
want: nil,
|
||||
ipv: IPVersion6 + IPVersion4,
|
||||
}, {
|
||||
iface: &fakeIface{addrs: []net.Addr{&net.IPAddr{IP: ip4}}, err: nil},
|
||||
name: "ipaddr_v4",
|
||||
wantErrMsg: "",
|
||||
want: []net.IP{ip4},
|
||||
ipv: IPVersion4,
|
||||
}, {
|
||||
iface: &fakeIface{addrs: []net.Addr{&net.IPAddr{IP: ip6, Zone: ""}}, err: nil},
|
||||
name: "ipaddr_v6",
|
||||
wantErrMsg: "",
|
||||
want: []net.IP{ip6},
|
||||
ipv: IPVersion6,
|
||||
}, {
|
||||
iface: &fakeIface{addrs: []net.Addr{&net.UnixAddr{}}, err: nil},
|
||||
name: "non-ipv4",
|
||||
wantErrMsg: "",
|
||||
want: nil,
|
||||
ipv: IPVersion4,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, gotErr := IfaceIPAddrs(tc.iface, tc.ipv)
|
||||
require.True(t, errors.Is(gotErr, tc.wantErr))
|
||||
got, err := IfaceIPAddrs(tc.iface, tc.ipv)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type waitingFakeIface struct {
|
||||
addrs []net.Addr
|
||||
err error
|
||||
addrs []net.Addr
|
||||
n int
|
||||
}
|
||||
|
||||
@@ -116,11 +143,11 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
|
||||
addr6 := &net.IPNet{IP: ip6}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
iface NetIface
|
||||
ipv IPVersion
|
||||
want []net.IP
|
||||
wantErr error
|
||||
name string
|
||||
want []net.IP
|
||||
ipv IPVersion
|
||||
}{{
|
||||
name: "ipv4_success",
|
||||
iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil},
|
||||
@@ -169,12 +196,25 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
|
||||
ipv: IPVersion6,
|
||||
want: []net.IP{ip6, ip6},
|
||||
wantErr: nil,
|
||||
}, {
|
||||
name: "empty",
|
||||
iface: &fakeIface{addrs: nil, err: nil},
|
||||
ipv: IPVersion4,
|
||||
want: nil,
|
||||
wantErr: nil,
|
||||
}, {
|
||||
name: "many",
|
||||
iface: &fakeIface{addrs: []net.Addr{addr4, addr4}},
|
||||
ipv: IPVersion4,
|
||||
want: []net.IP{ip4, ip4},
|
||||
wantErr: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, gotErr := IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0)
|
||||
require.True(t, errors.Is(gotErr, tc.wantErr))
|
||||
got, err := IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0)
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
43
internal/aghnet/ipmut.go
Normal file
43
internal/aghnet/ipmut.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// IPMutFunc is the signature of a function which modifies the IP address
|
||||
// instance. It should be safe for concurrent use.
|
||||
type IPMutFunc func(ip net.IP)
|
||||
|
||||
// nopIPMutFunc is the IPMutFunc that does nothing.
|
||||
func nopIPMutFunc(net.IP) {}
|
||||
|
||||
// IPMut is a type-safe wrapper of atomic.Value to store the IPMutFunc.
|
||||
type IPMut struct {
|
||||
f atomic.Value
|
||||
}
|
||||
|
||||
// NewIPMut returns the new properly initialized *IPMut. The m is guaranteed to
|
||||
// always store non-nil IPMutFunc which is safe to call.
|
||||
func NewIPMut(f IPMutFunc) (m *IPMut) {
|
||||
m = &IPMut{
|
||||
f: atomic.Value{},
|
||||
}
|
||||
m.Store(f)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Store sets the IPMutFunc to return from Func. It's safe for concurrent use.
|
||||
// If f is nil, the stored function is the no-op one.
|
||||
func (m *IPMut) Store(f IPMutFunc) {
|
||||
if f == nil {
|
||||
f = nopIPMutFunc
|
||||
}
|
||||
m.f.Store(f)
|
||||
}
|
||||
|
||||
// Load returns the previously stored IPMutFunc.
|
||||
func (m *IPMut) Load() (f IPMutFunc) {
|
||||
return m.f.Load().(IPMutFunc)
|
||||
}
|
||||
44
internal/aghnet/ipmut_test.go
Normal file
44
internal/aghnet/ipmut_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIPMut(t *testing.T) {
|
||||
testIPs := []net.IP{{
|
||||
127, 0, 0, 1,
|
||||
}, {
|
||||
192, 168, 0, 1,
|
||||
}, {
|
||||
8, 8, 8, 8,
|
||||
}}
|
||||
|
||||
t.Run("nil_no_mut", func(t *testing.T) {
|
||||
ipmut := NewIPMut(nil)
|
||||
|
||||
ips := netutil.CloneIPs(testIPs)
|
||||
for i := range ips {
|
||||
ipmut.Load()(ips[i])
|
||||
assert.True(t, ips[i].Equal(testIPs[i]))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not_nil_mut", func(t *testing.T) {
|
||||
ipmut := NewIPMut(func(ip net.IP) {
|
||||
for i := range ip {
|
||||
ip[i] = 0
|
||||
}
|
||||
})
|
||||
want := netutil.IPv4Zero()
|
||||
|
||||
ips := netutil.CloneIPs(testIPs)
|
||||
for i := range ips {
|
||||
ipmut.Load()(ips[i])
|
||||
assert.True(t, ips[i].Equal(want))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -20,7 +20,12 @@ type IpsetManager interface {
|
||||
//
|
||||
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
|
||||
//
|
||||
// The error is of type *aghos.UnsupportedError if the OS is not supported.
|
||||
// If ipsetConf is empty, msg and err are nil. The error is of type
|
||||
// *aghos.UnsupportedError if the OS is not supported.
|
||||
func NewIpsetManager(ipsetConf []string) (mgr IpsetManager, err error) {
|
||||
if len(ipsetConf) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return newIpsetMgr(ipsetConf)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/digineo/go-ipset/v2"
|
||||
"github.com/mdlayher/netlink"
|
||||
"github.com/ti-mo/netfilter"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// How to test on a real Linux machine:
|
||||
@@ -42,11 +43,17 @@ import (
|
||||
|
||||
// newIpsetMgr returns a new Linux ipset manager.
|
||||
func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) {
|
||||
dial := func(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error) {
|
||||
return ipset.Dial(pf, conf)
|
||||
return newIpsetMgrWithDialer(ipsetConf, defaultDial)
|
||||
}
|
||||
|
||||
// defaultDial is the default netfilter dialing function.
|
||||
func defaultDial(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error) {
|
||||
conn, err = ipset.Dial(pf, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newIpsetMgrWithDialer(ipsetConf, dial)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// ipsetConn is the ipset conn interface.
|
||||
@@ -103,8 +110,8 @@ func (m *ipsetMgr) dialNetfilter(conf *netlink.Config) (err error) {
|
||||
// The kernel API does not actually require two sockets but package
|
||||
// github.com/digineo/go-ipset does.
|
||||
//
|
||||
// TODO(a.garipov): Perhaps we can ditch package ipset altogether and
|
||||
// just use packages netfilter and netlink.
|
||||
// TODO(a.garipov): Perhaps we can ditch package ipset altogether and just
|
||||
// use packages netfilter and netlink.
|
||||
m.ipv4Conn, err = m.dial(netfilter.ProtoIPv4, conf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing v4: %w", err)
|
||||
@@ -214,6 +221,14 @@ func newIpsetMgrWithDialer(ipsetConf []string, dial ipsetDialer) (mgr IpsetManag
|
||||
|
||||
err = m.dialNetfilter(&netlink.Config{})
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EPROTONOSUPPORT) {
|
||||
// The implementation doesn't support this protocol version. Just
|
||||
// issue a warning.
|
||||
log.Info("ipset: dialing netfilter: warning: %s", err)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("dialing netfilter: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,13 +4,11 @@ package aghnet
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -44,8 +42,7 @@ func GatewayIP(ifaceName string) net.IP {
|
||||
|
||||
fields := strings.Fields(string(d))
|
||||
// The meaningful "ip route" command output should contain the word
|
||||
// "default" at first field and default gateway IP address at third
|
||||
// field.
|
||||
// "default" at first field and default gateway IP address at third field.
|
||||
if len(fields) < 3 || fields[0] != "default" {
|
||||
return nil
|
||||
}
|
||||
@@ -189,79 +186,35 @@ func GetSubnet(ifaceName string) *net.IPNet {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPortAvailable - check if TCP port is available
|
||||
func CheckPortAvailable(host net.IP, port int) error {
|
||||
ln, err := net.Listen("tcp", netutil.JoinHostPort(host.String(), port))
|
||||
// CheckPort checks if the port is available for binding. network is expected
|
||||
// to be one of "udp" and "tcp".
|
||||
func CheckPort(network string, ip net.IP, port int) (err error) {
|
||||
var c io.Closer
|
||||
addr := netutil.IPPort{IP: ip, Port: port}.String()
|
||||
switch network {
|
||||
case "tcp":
|
||||
c, err = net.Listen(network, addr)
|
||||
case "udp":
|
||||
c, err = net.ListenPacket(network, addr)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = ln.Close()
|
||||
|
||||
// It seems that net.Listener.Close() doesn't close file descriptors right away.
|
||||
// We wait for some time and hope that this fd will be closed.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return nil
|
||||
return closePortChecker(c)
|
||||
}
|
||||
|
||||
// CheckPacketPortAvailable - check if UDP port is available
|
||||
func CheckPacketPortAvailable(host net.IP, port int) error {
|
||||
ln, err := net.ListenPacket("udp", netutil.JoinHostPort(host.String(), port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = ln.Close()
|
||||
|
||||
// It seems that net.Listener.Close() doesn't close file descriptors right away.
|
||||
// We wait for some time and hope that this fd will be closed.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return err
|
||||
}
|
||||
|
||||
// ErrorIsAddrInUse - check if error is "address already in use"
|
||||
func ErrorIsAddrInUse(err error) bool {
|
||||
errOpError, ok := err.(*net.OpError)
|
||||
if !ok {
|
||||
// IsAddrInUse checks if err is about unsuccessful address binding.
|
||||
func IsAddrInUse(err error) (ok bool) {
|
||||
var sysErr syscall.Errno
|
||||
if !errors.As(err, &sysErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
errSyscallError, ok := errOpError.Err.(*os.SyscallError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
errErrno, ok := errSyscallError.Err.(syscall.Errno)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
const WSAEADDRINUSE = 10048
|
||||
return errErrno == WSAEADDRINUSE
|
||||
}
|
||||
|
||||
return errErrno == syscall.EADDRINUSE
|
||||
}
|
||||
|
||||
// SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport
|
||||
// does not necessarily contain a port.
|
||||
func SplitHost(hostport string) (host string, err error) {
|
||||
host, _, err = net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
// Check for the missing port error. If it is that error, just
|
||||
// use the host as is.
|
||||
//
|
||||
// See the source code for net.SplitHostPort.
|
||||
const missingPort = "missing port in address"
|
||||
|
||||
addrErr := &net.AddrError{}
|
||||
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
|
||||
return "", err
|
||||
}
|
||||
|
||||
host = hostport
|
||||
}
|
||||
|
||||
return host, nil
|
||||
return isAddrInUse(sysErr)
|
||||
}
|
||||
|
||||
// CollectAllIfacesAddrs returns the slice of all network interfaces IP
|
||||
|
||||
@@ -18,9 +18,11 @@ func canBindPrivilegedPorts() (can bool, err error) {
|
||||
}
|
||||
|
||||
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
||||
const filename = "/etc/rc.conf"
|
||||
const rcConfFilename = "etc/rc.conf"
|
||||
|
||||
return aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig).Walk(filename)
|
||||
walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig)
|
||||
|
||||
return walker.Walk(aghos.RootDirFS(), rcConfFilename)
|
||||
}
|
||||
|
||||
// rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to
|
||||
|
||||
@@ -85,17 +85,17 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
|
||||
|
||||
iface := interfaceName(ifaceName)
|
||||
|
||||
for _, pair := range []struct {
|
||||
for _, pair := range [...]struct {
|
||||
aghos.FileWalker
|
||||
filename string
|
||||
}{{
|
||||
FileWalker: iface.dhcpcdStaticConfig,
|
||||
filename: "/etc/dhcpcd.conf",
|
||||
filename: "etc/dhcpcd.conf",
|
||||
}, {
|
||||
FileWalker: iface.ifacesStaticConfig,
|
||||
filename: "/etc/network/interfaces",
|
||||
filename: "etc/network/interfaces",
|
||||
}} {
|
||||
has, err = pair.Walk(pair.filename)
|
||||
has, err = pair.Walk(aghos.RootDirFS(), pair.filename)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const nl = "\n"
|
||||
|
||||
func TestDHCPCDStaticConfig(t *testing.T) {
|
||||
const iface interfaceName = `wlan0`
|
||||
|
||||
|
||||
@@ -18,9 +18,9 @@ func canBindPrivilegedPorts() (can bool, err error) {
|
||||
}
|
||||
|
||||
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
||||
filename := fmt.Sprintf("/etc/hostname.%s", ifaceName)
|
||||
filename := fmt.Sprintf("etc/hostname.%s", ifaceName)
|
||||
|
||||
return aghos.FileWalker(hostnameIfStaticConfig).Walk(filename)
|
||||
return aghos.FileWalker(hostnameIfStaticConfig).Walk(aghos.RootDirFS(), filename)
|
||||
}
|
||||
|
||||
// hostnameIfStaticConfig checks if the interface is configured by
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
//go:build !(linux || darwin || freebsd || openbsd)
|
||||
// +build !linux,!darwin,!freebsd,!openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
)
|
||||
|
||||
func canBindPrivilegedPorts() (can bool, err error) {
|
||||
return aghos.HaveAdminRights()
|
||||
}
|
||||
|
||||
func ifaceHasStaticIP(string) (ok bool, err error) {
|
||||
return false, aghos.Unsupported("checking static ip")
|
||||
}
|
||||
|
||||
func ifaceSetStaticIP(string) (err error) {
|
||||
return aghos.Unsupported("setting static ip")
|
||||
}
|
||||
@@ -4,16 +4,31 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetValidNetInterfacesForWeb(t *testing.T) {
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func TestGetInterfaceByIP(t *testing.T) {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
require.NoErrorf(t, err, "cannot get net interfaces: %s", err)
|
||||
require.NotEmpty(t, ifaces, "no net interfaces found")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ifaces)
|
||||
|
||||
for _, iface := range ifaces {
|
||||
require.NotEmptyf(t, iface.Addresses, "no addresses found for %s", iface.Name)
|
||||
t.Run(iface.Name, func(t *testing.T) {
|
||||
require.NotEmpty(t, iface.Addresses)
|
||||
|
||||
for _, ip := range iface.Addresses {
|
||||
ifaceName := GetInterfaceByIP(ip)
|
||||
require.Equal(t, iface.Name, ifaceName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,3 +79,49 @@ func TestBroadcastFromIPNet(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPort(t *testing.T) {
|
||||
t.Run("tcp_bound", func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
ipp := netutil.IPPortFromAddr(l.Addr())
|
||||
require.NotNil(t, ipp)
|
||||
require.NotNil(t, ipp.IP)
|
||||
require.NotZero(t, ipp.Port)
|
||||
|
||||
err = CheckPort("tcp", ipp.IP, ipp.Port)
|
||||
target := &net.OpError{}
|
||||
require.ErrorAs(t, err, &target)
|
||||
|
||||
assert.Equal(t, "listen", target.Op)
|
||||
})
|
||||
|
||||
t.Run("udp_bound", func(t *testing.T) {
|
||||
conn, err := net.ListenPacket("udp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, conn.Close)
|
||||
|
||||
ipp := netutil.IPPortFromAddr(conn.LocalAddr())
|
||||
require.NotNil(t, ipp)
|
||||
require.NotNil(t, ipp.IP)
|
||||
require.NotZero(t, ipp.Port)
|
||||
|
||||
err = CheckPort("udp", ipp.IP, ipp.Port)
|
||||
target := &net.OpError{}
|
||||
require.ErrorAs(t, err, &target)
|
||||
|
||||
assert.Equal(t, "listen", target.Op)
|
||||
})
|
||||
|
||||
t.Run("bad_network", func(t *testing.T) {
|
||||
err := CheckPort("bad_network", nil, 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("can_bind", func(t *testing.T) {
|
||||
err := CheckPort("udp", net.IP{0, 0, 0, 0}, 0)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,8 +1,20 @@
|
||||
//go:build openbsd || freebsd || linux
|
||||
// +build openbsd freebsd linux
|
||||
//go:build openbsd || freebsd || linux || darwin
|
||||
// +build openbsd freebsd linux darwin
|
||||
|
||||
package aghnet
|
||||
|
||||
// interfaceName is a string containing network interface's name. The name is
|
||||
// used in file walking methods.
|
||||
type interfaceName string
|
||||
import (
|
||||
"io"
|
||||
"syscall"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// closePortChecker closes c. c must be non-nil.
|
||||
func closePortChecker(c io.Closer) (err error) {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
func isAddrInUse(err syscall.Errno) (ok bool) {
|
||||
return errors.Is(err, syscall.EADDRINUSE)
|
||||
}
|
||||
|
||||
46
internal/aghnet/net_windows.go
Normal file
46
internal/aghnet/net_windows.go
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build !(linux || darwin || freebsd || openbsd)
|
||||
// +build !linux,!darwin,!freebsd,!openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"io"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func canBindPrivilegedPorts() (can bool, err error) {
|
||||
return aghos.HaveAdminRights()
|
||||
}
|
||||
|
||||
func ifaceHasStaticIP(string) (ok bool, err error) {
|
||||
return false, aghos.Unsupported("checking static ip")
|
||||
}
|
||||
|
||||
func ifaceSetStaticIP(string) (err error) {
|
||||
return aghos.Unsupported("setting static ip")
|
||||
}
|
||||
|
||||
// closePortChecker closes c. c must be non-nil.
|
||||
func closePortChecker(c io.Closer) (err error) {
|
||||
if err = c.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// It seems that net.Listener.Close() doesn't close file descriptors right
|
||||
// away. We wait for some time and hope that this fd will be closed.
|
||||
//
|
||||
// TODO(e.burkov): Investigate the purpose of the line and perhaps use more
|
||||
// reliable approach.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAddrInUse(err syscall.Errno) (ok bool) {
|
||||
return errors.Is(err, windows.WSAEADDRINUSE)
|
||||
}
|
||||
@@ -6,16 +6,18 @@ import (
|
||||
|
||||
// SubnetDetector describes IP address properties.
|
||||
type SubnetDetector struct {
|
||||
// spNets is the slice of special-purpose address registries as defined
|
||||
// by RFC-6890 (https://tools.ietf.org/html/rfc6890).
|
||||
// spNets is the collection of special-purpose address registries as defined
|
||||
// by RFC 6890.
|
||||
spNets []*net.IPNet
|
||||
|
||||
// locServedNets is the slice of locally-served networks as defined by
|
||||
// RFC-6303 (https://tools.ietf.org/html/rfc6303).
|
||||
// locServedNets is the collection of locally-served networks as defined by
|
||||
// RFC 6303.
|
||||
locServedNets []*net.IPNet
|
||||
}
|
||||
|
||||
// NewSubnetDetector returns a new IP detector.
|
||||
//
|
||||
// TODO(a.garipov): Decide whether an error is actually needed.
|
||||
func NewSubnetDetector() (snd *SubnetDetector, err error) {
|
||||
spNets := []string{
|
||||
// "This" network.
|
||||
|
||||
@@ -79,8 +79,8 @@ func TestSystemResolvers_DialFunc(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
conn, err := imp.dialFunc(context.Background(), "", tc.address)
|
||||
|
||||
require.Nil(t, conn)
|
||||
|
||||
assert.ErrorIs(t, err, tc.want)
|
||||
})
|
||||
}
|
||||
|
||||
38
internal/aghnet/testdata/etc_hosts
vendored
Normal file
38
internal/aghnet/testdata/etc_hosts
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
#
|
||||
# Test /etc/hosts file
|
||||
#
|
||||
|
||||
1.0.0.1 simplehost
|
||||
1.0.0.0 hello hello.world
|
||||
|
||||
# See https://github.com/AdguardTeam/AdGuardHome/issues/3846.
|
||||
1.0.0.2 a.whole lot.of aliases for.testing
|
||||
|
||||
# See https://github.com/AdguardTeam/AdGuardHome/issues/3946.
|
||||
1.0.0.3 *
|
||||
1.0.0.4 *.com
|
||||
|
||||
# See https://github.com/AdguardTeam/AdGuardHome/issues/4079.
|
||||
1.0.0.0 hello.world.again
|
||||
|
||||
# Duplicates of a main host and an alias.
|
||||
1.0.0.1 simplehost
|
||||
1.0.0.0 hello.world
|
||||
|
||||
# Same for IPv6.
|
||||
::1 simplehost
|
||||
:: hello hello.world
|
||||
::2 a.whole lot.of aliases for.testing
|
||||
::3 *
|
||||
::4 *.com
|
||||
:: hello.world.again
|
||||
::1 simplehost
|
||||
:: hello.world
|
||||
|
||||
# See https://github.com/AdguardTeam/AdGuardHome/issues/4216.
|
||||
4.2.1.6 domain domain.alias
|
||||
::42 domain.alias domain
|
||||
1.3.5.7 domain4 domain4.alias
|
||||
7.5.3.1 domain4.alias domain4
|
||||
::13 domain6 domain6.alias
|
||||
::31 domain6.alias domain6
|
||||
@@ -3,10 +3,8 @@ package aghos
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"io/fs"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
@@ -14,27 +12,28 @@ import (
|
||||
// FileWalker is the signature of a function called for files in the file tree.
|
||||
// As opposed to filepath.Walk it only walk the files (not directories) matching
|
||||
// the provided pattern and those returned by function itself. All patterns
|
||||
// should be valid for filepath.Glob. If cont is false, the walking terminates.
|
||||
// Each opened file is also limited for reading to MaxWalkedFileSize.
|
||||
// should be valid for fs.Glob. If FileWalker returns false for cont then
|
||||
// walking terminates. Prefer using bufio.Scanner to read the r since the input
|
||||
// is not limited.
|
||||
//
|
||||
// TODO(e.burkov): Consider moving to the separate package like pathutil.
|
||||
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
|
||||
//
|
||||
// TODO(e.burkov): Think about passing filename or any additional data.
|
||||
type FileWalker func(r io.Reader) (patterns []string, cont bool, err error)
|
||||
|
||||
// MaxWalkedFileSize is the maximum length of the file that FileWalker can
|
||||
// check.
|
||||
const MaxWalkedFileSize = 1024 * 1024
|
||||
|
||||
// checkFile tries to open and process a single file located on sourcePath.
|
||||
func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, err error) {
|
||||
var f *os.File
|
||||
f, err = os.Open(sourcePath)
|
||||
// checkFile tries to open and process a single file located on sourcePath in
|
||||
// the specified fsys. The path is skipped if it's a directory.
|
||||
func checkFile(
|
||||
fsys fs.FS,
|
||||
c FileWalker,
|
||||
sourcePath string,
|
||||
) (patterns []string, cont bool, err error) {
|
||||
var f fs.File
|
||||
f, err = fsys.Open(sourcePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// Ignore non-existing files since this may only happen
|
||||
// when the file was removed after filepath.Glob matched
|
||||
// it.
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// Ignore non-existing files since this may only happen when the
|
||||
// file was removed after filepath.Glob matched it.
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
@@ -42,23 +41,28 @@ func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, e
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||
|
||||
var r io.Reader
|
||||
// Ignore the error since LimitReader function returns error only if
|
||||
// passed limit value is less than zero, but the constant used.
|
||||
//
|
||||
// TODO(e.burkov): Make variable.
|
||||
r, _ = aghio.LimitReader(f, MaxWalkedFileSize)
|
||||
var fi fs.FileInfo
|
||||
if fi, err = f.Stat(); err != nil {
|
||||
return nil, true, err
|
||||
} else if fi.IsDir() {
|
||||
// Skip the directories.
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
return c(r)
|
||||
return c(f)
|
||||
}
|
||||
|
||||
// handlePatterns parses the patterns and ignores duplicates using srcSet.
|
||||
// srcSet must be non-nil.
|
||||
func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, err error) {
|
||||
// handlePatterns parses the patterns in fsys and ignores duplicates using
|
||||
// srcSet. srcSet must be non-nil.
|
||||
func handlePatterns(
|
||||
fsys fs.FS,
|
||||
srcSet *stringutil.Set,
|
||||
patterns ...string,
|
||||
) (sub []string, err error) {
|
||||
sub = make([]string, 0, len(patterns))
|
||||
for _, p := range patterns {
|
||||
var matches []string
|
||||
matches, err = filepath.Glob(p)
|
||||
matches, err = fs.Glob(fsys, p)
|
||||
if err != nil {
|
||||
// Enrich error with the pattern because filepath.Glob
|
||||
// doesn't do it.
|
||||
@@ -78,14 +82,14 @@ func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, e
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// Walk starts walking the files defined by initPattern. It only returns true
|
||||
// if c signed to stop walking.
|
||||
func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
|
||||
// The slice of sources keeps the order in which the files are walked
|
||||
// since srcSet.Values() returns strings in undefined order.
|
||||
// Walk starts walking the files in fsys defined by patterns from initial.
|
||||
// It only returns true if fw signed to stop walking.
|
||||
func (fw FileWalker) Walk(fsys fs.FS, initial ...string) (ok bool, err error) {
|
||||
// The slice of sources keeps the order in which the files are walked since
|
||||
// srcSet.Values() returns strings in undefined order.
|
||||
srcSet := stringutil.NewSet()
|
||||
var src []string
|
||||
src, err = handlePatterns(srcSet, initPattern)
|
||||
src, err = handlePatterns(fsys, srcSet, initial...)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -97,7 +101,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
|
||||
var patterns []string
|
||||
var cont bool
|
||||
filename = src[i]
|
||||
patterns, cont, err = checkFile(c, src[i])
|
||||
patterns, cont, err = checkFile(fsys, fw, src[i])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -107,7 +111,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
|
||||
}
|
||||
|
||||
var subsrc []string
|
||||
subsrc, err = handlePatterns(srcSet, patterns...)
|
||||
subsrc, err = handlePatterns(fsys, srcSet, patterns...)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -4,56 +4,19 @@ import (
|
||||
"bufio"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"path"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testFSDir maps entries' names to entries which should either be a testFSDir
|
||||
// or byte slice.
|
||||
type testFSDir map[string]interface{}
|
||||
|
||||
// testFSGen is used to generate a temporary filesystem consisting of
|
||||
// directories and plain text files from itself.
|
||||
type testFSGen testFSDir
|
||||
|
||||
// gen returns the name of top directory of the generated filesystem.
|
||||
func (g testFSGen) gen(t *testing.T) (dirName string) {
|
||||
t.Helper()
|
||||
|
||||
dirName = t.TempDir()
|
||||
g.rangeThrough(t, dirName)
|
||||
|
||||
return dirName
|
||||
}
|
||||
|
||||
func (g testFSGen) rangeThrough(t *testing.T, dirName string) {
|
||||
const perm fs.FileMode = 0o777
|
||||
|
||||
for k, e := range g {
|
||||
switch e := e.(type) {
|
||||
case []byte:
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dirName, k), e, perm))
|
||||
|
||||
case testFSDir:
|
||||
newDir := filepath.Join(dirName, k)
|
||||
require.NoError(t, os.Mkdir(newDir, perm))
|
||||
|
||||
testFSGen(e).rangeThrough(t, newDir)
|
||||
default:
|
||||
t.Fatalf("unexpected entry type %T", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileWalker_Walk(t *testing.T) {
|
||||
const attribute = `000`
|
||||
|
||||
makeFileWalker := func(dirName string) (fw FileWalker) {
|
||||
makeFileWalker := func(_ string) (fw FileWalker) {
|
||||
return func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
@@ -63,7 +26,7 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
}
|
||||
|
||||
if len(line) != 0 {
|
||||
patterns = append(patterns, filepath.Join(dirName, line))
|
||||
patterns = append(patterns, path.Join(".", line))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,136 +37,150 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
const nl = "\n"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
testFS testFSGen
|
||||
testFS fstest.MapFS
|
||||
want assert.BoolAssertionFunc
|
||||
initPattern string
|
||||
want bool
|
||||
name string
|
||||
}{{
|
||||
name: "simple",
|
||||
testFS: testFSGen{
|
||||
"simple_0001.txt": []byte(attribute + nl),
|
||||
testFS: fstest.MapFS{
|
||||
"simple_0001.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
|
||||
},
|
||||
initPattern: "simple_0001.txt",
|
||||
want: true,
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "chain",
|
||||
testFS: testFSGen{
|
||||
"chain_0001.txt": []byte(`chain_0002.txt` + nl),
|
||||
"chain_0002.txt": []byte(`chain_0003.txt` + nl),
|
||||
"chain_0003.txt": []byte(attribute + nl),
|
||||
testFS: fstest.MapFS{
|
||||
"chain_0001.txt": &fstest.MapFile{Data: []byte(`chain_0002.txt` + nl)},
|
||||
"chain_0002.txt": &fstest.MapFile{Data: []byte(`chain_0003.txt` + nl)},
|
||||
"chain_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
|
||||
},
|
||||
initPattern: "chain_0001.txt",
|
||||
want: true,
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "several",
|
||||
testFS: testFSGen{
|
||||
"several_0001.txt": []byte(`several_*` + nl),
|
||||
"several_0002.txt": []byte(`several_0001.txt` + nl),
|
||||
"several_0003.txt": []byte(attribute + nl),
|
||||
testFS: fstest.MapFS{
|
||||
"several_0001.txt": &fstest.MapFile{Data: []byte(`several_*` + nl)},
|
||||
"several_0002.txt": &fstest.MapFile{Data: []byte(`several_0001.txt` + nl)},
|
||||
"several_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
|
||||
},
|
||||
initPattern: "several_0001.txt",
|
||||
want: true,
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "no",
|
||||
testFS: testFSGen{
|
||||
"no_0001.txt": []byte(nl),
|
||||
"no_0002.txt": []byte(nl),
|
||||
"no_0003.txt": []byte(nl),
|
||||
testFS: fstest.MapFS{
|
||||
"no_0001.txt": &fstest.MapFile{Data: []byte(nl)},
|
||||
"no_0002.txt": &fstest.MapFile{Data: []byte(nl)},
|
||||
"no_0003.txt": &fstest.MapFile{Data: []byte(nl)},
|
||||
},
|
||||
initPattern: "no_*",
|
||||
want: false,
|
||||
want: assert.False,
|
||||
}, {
|
||||
name: "subdirectory",
|
||||
testFS: testFSGen{
|
||||
"dir": testFSDir{
|
||||
"subdir_0002.txt": []byte(attribute + nl),
|
||||
testFS: fstest.MapFS{
|
||||
path.Join("dir", "subdir_0002.txt"): &fstest.MapFile{
|
||||
Data: []byte(attribute + nl),
|
||||
},
|
||||
"subdir_0001.txt": []byte(`dir/*`),
|
||||
"subdir_0001.txt": &fstest.MapFile{Data: []byte(`dir/*`)},
|
||||
},
|
||||
initPattern: "subdir_0001.txt",
|
||||
want: true,
|
||||
want: assert.True,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
testDir := tc.testFS.gen(t)
|
||||
fw := makeFileWalker(testDir)
|
||||
fw := makeFileWalker("")
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ok, err := fw.Walk(filepath.Join(testDir, tc.initPattern))
|
||||
ok, err := fw.Walk(tc.testFS, tc.initPattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, ok)
|
||||
tc.want(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("pattern_malformed", func(t *testing.T) {
|
||||
ok, err := makeFileWalker("").Walk("[]")
|
||||
f := fstest.MapFS{}
|
||||
ok, err := makeFileWalker("").Walk(f, "[]")
|
||||
require.Error(t, err)
|
||||
|
||||
assert.False(t, ok)
|
||||
assert.ErrorIs(t, err, filepath.ErrBadPattern)
|
||||
assert.ErrorIs(t, err, path.ErrBadPattern)
|
||||
})
|
||||
|
||||
t.Run("bad_filename", func(t *testing.T) {
|
||||
dir := testFSGen{
|
||||
"bad_filename.txt": []byte("[]"),
|
||||
}.gen(t)
|
||||
fw := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
const filename = "bad_filename.txt"
|
||||
|
||||
f := fstest.MapFS{
|
||||
filename: &fstest.MapFile{Data: []byte("[]")},
|
||||
}
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
patterns = append(patterns, s.Text())
|
||||
}
|
||||
|
||||
return patterns, true, s.Err()
|
||||
})
|
||||
|
||||
ok, err := fw.Walk(filepath.Join(dir, "bad_filename.txt"))
|
||||
}).Walk(f, filename)
|
||||
require.Error(t, err)
|
||||
|
||||
assert.False(t, ok)
|
||||
assert.ErrorIs(t, err, filepath.ErrBadPattern)
|
||||
assert.ErrorIs(t, err, path.ErrBadPattern)
|
||||
})
|
||||
|
||||
t.Run("itself_error", func(t *testing.T) {
|
||||
const rerr errors.Error = "returned error"
|
||||
|
||||
dir := testFSGen{
|
||||
"mockfile.txt": []byte(`mockdata`),
|
||||
}.gen(t)
|
||||
f := fstest.MapFS{
|
||||
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
|
||||
}
|
||||
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
||||
return nil, true, rerr
|
||||
}).Walk(filepath.Join(dir, "*"))
|
||||
require.Error(t, err)
|
||||
require.False(t, ok)
|
||||
}).Walk(f, "*")
|
||||
require.ErrorIs(t, err, rerr)
|
||||
|
||||
assert.ErrorIs(t, err, rerr)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
type errFS struct {
|
||||
fs.GlobFS
|
||||
}
|
||||
|
||||
const errErrFSOpen errors.Error = "this error is always returned"
|
||||
|
||||
func (efs *errFS) Open(name string) (fs.File, error) {
|
||||
return nil, errErrFSOpen
|
||||
}
|
||||
|
||||
func TestWalkerFunc_CheckFile(t *testing.T) {
|
||||
emptyFS := fstest.MapFS{}
|
||||
|
||||
t.Run("non-existing", func(t *testing.T) {
|
||||
_, ok, err := checkFile(nil, "lol")
|
||||
_, ok, err := checkFile(emptyFS, nil, "lol")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("invalid_argument", func(t *testing.T) {
|
||||
const badPath = "\x00"
|
||||
|
||||
_, ok, err := checkFile(nil, badPath)
|
||||
require.Error(t, err)
|
||||
_, ok, err := checkFile(&errFS{}, nil, "")
|
||||
require.ErrorIs(t, err, errErrFSOpen)
|
||||
|
||||
assert.False(t, ok)
|
||||
// TODO(e.burkov): Use assert.ErrorsIs within the error from
|
||||
// less platform-dependent package instead of syscall.EINVAL.
|
||||
//
|
||||
// See https://github.com/golang/go/issues/46849 and
|
||||
// https://github.com/golang/go/issues/30322.
|
||||
pathErr := &os.PathError{}
|
||||
require.ErrorAs(t, err, &pathErr)
|
||||
assert.Equal(t, "open", pathErr.Op)
|
||||
assert.Equal(t, badPath, pathErr.Path)
|
||||
})
|
||||
|
||||
t.Run("ignore_dirs", func(t *testing.T) {
|
||||
const dirName = "dir"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
|
||||
}
|
||||
|
||||
patterns, ok, err := checkFile(testFS, nil, dirName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, patterns)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
133
internal/aghos/fswatcher.go
Normal file
133
internal/aghos/fswatcher.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// event is a convenient alias for an empty struct to signal that watching
|
||||
// event happened.
|
||||
type event = struct{}
|
||||
|
||||
// FSWatcher tracks all the fyle system events and notifies about those.
|
||||
//
|
||||
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
|
||||
type FSWatcher interface {
|
||||
io.Closer
|
||||
|
||||
// Events should return a read-only channel which notifies about events.
|
||||
Events() (e <-chan event)
|
||||
|
||||
// Add should check if the file named name is accessible and starts tracking
|
||||
// it.
|
||||
Add(name string) (err error)
|
||||
}
|
||||
|
||||
// osWatcher tracks the file system provided by the OS.
|
||||
type osWatcher struct {
|
||||
// w is the actual notifier that is handled by osWatcher.
|
||||
w *fsnotify.Watcher
|
||||
|
||||
// events is the channel to notify.
|
||||
events chan event
|
||||
}
|
||||
|
||||
const (
|
||||
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
|
||||
// methods.
|
||||
osWatcherPref = "os watcher"
|
||||
)
|
||||
|
||||
// NewOSWritesWatcher creates FSWatcher that tracks the real file system of the
|
||||
// OS and notifies only about writing events.
|
||||
func NewOSWritesWatcher() (w FSWatcher, err error) {
|
||||
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
|
||||
|
||||
var watcher *fsnotify.Watcher
|
||||
watcher, err = fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating watcher: %w", err)
|
||||
}
|
||||
|
||||
fsw := &osWatcher{
|
||||
w: watcher,
|
||||
events: make(chan event, 1),
|
||||
}
|
||||
|
||||
go fsw.handleErrors()
|
||||
go fsw.handleEvents()
|
||||
|
||||
return fsw, nil
|
||||
}
|
||||
|
||||
// handleErrors handles accompanying errors. It used to be called in a separate
|
||||
// goroutine.
|
||||
func (w *osWatcher) handleErrors() {
|
||||
defer log.OnPanic(fmt.Sprintf("%s: handling errors", osWatcherPref))
|
||||
|
||||
for err := range w.w.Errors {
|
||||
log.Error("%s: %s", osWatcherPref, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Events implements the FSWatcher interface for *osWatcher.
|
||||
func (w *osWatcher) Events() (e <-chan event) {
|
||||
return w.events
|
||||
}
|
||||
|
||||
// Add implements the FSWatcher interface for *osWatcher.
|
||||
//
|
||||
// TODO(e.burkov): Make it accept non-existing files to detect it's creating.
|
||||
func (w *osWatcher) Add(name string) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
|
||||
|
||||
if _, err = fs.Stat(RootDirFS(), name); err != nil {
|
||||
return fmt.Errorf("checking file %q: %w", name, err)
|
||||
}
|
||||
|
||||
return w.w.Add(filepath.Join("/", name))
|
||||
}
|
||||
|
||||
// Close implements the FSWatcher interface for *osWatcher.
|
||||
func (w *osWatcher) Close() (err error) {
|
||||
return w.w.Close()
|
||||
}
|
||||
|
||||
// handleEvents notifies about the received file system's event if needed. It
|
||||
// used to be called in a separate goroutine.
|
||||
func (w *osWatcher) handleEvents() {
|
||||
defer log.OnPanic(fmt.Sprintf("%s: handling events", osWatcherPref))
|
||||
|
||||
defer close(w.events)
|
||||
|
||||
ch := w.w.Events
|
||||
for e := range ch {
|
||||
if e.Op&fsnotify.Write == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip the following events assuming that sometimes the same event
|
||||
// occurrs several times.
|
||||
for ok := true; ok; {
|
||||
select {
|
||||
case _, ok = <-ch:
|
||||
// Go on.
|
||||
default:
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case w.events <- event{}:
|
||||
// Go on.
|
||||
default:
|
||||
log.Debug("%s: events buffer is full", osWatcherPref)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"runtime"
|
||||
@@ -89,7 +91,7 @@ func PIDByCommand(command string, except ...int) (pid int, err error) {
|
||||
}
|
||||
|
||||
var instNum int
|
||||
pid, instNum, err = parsePSOutput(stdout, command, except...)
|
||||
pid, instNum, err = parsePSOutput(stdout, command, except)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -116,16 +118,15 @@ func PIDByCommand(command string, except ...int) (pid int, err error) {
|
||||
}
|
||||
|
||||
// parsePSOutput scans the output of ps searching the largest PID of the process
|
||||
// associated with cmdName ignoring PIDs from ignore. Valid r's line shoud be
|
||||
// like:
|
||||
// associated with cmdName ignoring PIDs from ignore. A valid line from
|
||||
// r should look like these:
|
||||
//
|
||||
// 123 ./example-cmd
|
||||
// 1230 some/base/path/example-cmd
|
||||
// 3210 example-cmd
|
||||
//
|
||||
func parsePSOutput(r io.Reader, cmdName string, ignore ...int) (largest, instNum int, err error) {
|
||||
func parsePSOutput(r io.Reader, cmdName string, ignore []int) (largest, instNum int, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
ScanLoop:
|
||||
for s.Scan() {
|
||||
fields := strings.Fields(s.Text())
|
||||
if len(fields) != 2 || path.Base(fields[1]) != cmdName {
|
||||
@@ -133,16 +134,10 @@ ScanLoop:
|
||||
}
|
||||
|
||||
cur, aerr := strconv.Atoi(fields[0])
|
||||
if aerr != nil || cur < 0 {
|
||||
if aerr != nil || cur < 0 || intIn(cur, ignore) {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pid := range ignore {
|
||||
if cur == pid {
|
||||
continue ScanLoop
|
||||
}
|
||||
}
|
||||
|
||||
instNum++
|
||||
if cur > largest {
|
||||
largest = cur
|
||||
@@ -155,7 +150,25 @@ ScanLoop:
|
||||
return largest, instNum, nil
|
||||
}
|
||||
|
||||
// intIn returns true if nums contains n.
|
||||
func intIn(n int, nums []int) (ok bool) {
|
||||
for _, nn := range nums {
|
||||
if n == nn {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOpenWrt returns true if host OS is OpenWrt.
|
||||
func IsOpenWrt() (ok bool) {
|
||||
return isOpenWrt()
|
||||
}
|
||||
|
||||
// RootDirFS returns the fs.FS rooted at the operating system's root.
|
||||
func RootDirFS() (fsys fs.FS) {
|
||||
// Use empty string since os.DirFS implicitly prepends a slash to it. This
|
||||
// behavior is undocumented but it currently works.
|
||||
return os.DirFS("")
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@ func haveAdminRights() (bool, error) {
|
||||
}
|
||||
|
||||
func isOpenWrt() (ok bool) {
|
||||
const etcReleasePattern = "etc/*release*"
|
||||
|
||||
var err error
|
||||
ok, err = FileWalker(func(r io.Reader) (_ []string, cont bool, err error) {
|
||||
const osNameData = "openwrt"
|
||||
@@ -39,7 +41,7 @@ func isOpenWrt() (ok bool) {
|
||||
}
|
||||
|
||||
return nil, !stringutil.ContainsFold(string(data), osNameData), nil
|
||||
}).Walk("/etc/*release*")
|
||||
}).Walk(RootDirFS(), etcReleasePattern)
|
||||
|
||||
return err == nil && ok
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func TestLargestLabeled(t *testing.T) {
|
||||
r := bytes.NewReader(tc.data)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
pid, instNum, err := parsePSOutput(r, comm)
|
||||
pid, instNum, err := parsePSOutput(r, comm, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantPID, pid)
|
||||
@@ -76,7 +76,7 @@ func TestLargestLabeled(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
target := &aghio.LimitReachedError{}
|
||||
_, _, err = parsePSOutput(lr, "")
|
||||
_, _, err = parsePSOutput(lr, "", nil)
|
||||
require.ErrorAs(t, err, &target)
|
||||
|
||||
assert.EqualValues(t, 0, target.Limit)
|
||||
@@ -89,7 +89,7 @@ func TestLargestLabeled(t *testing.T) {
|
||||
`3` + comm + nl,
|
||||
))
|
||||
|
||||
pid, instances, err := parsePSOutput(r, comm, 1, 3)
|
||||
pid, instances, err := parsePSOutput(r, comm, []int{1, 3})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, pid)
|
||||
|
||||
@@ -20,17 +20,19 @@ func DiscardLogOutput(m *testing.M) {
|
||||
|
||||
// ReplaceLogWriter moves logger output to w and uses Cleanup method of t to
|
||||
// revert changes.
|
||||
func ReplaceLogWriter(t *testing.T, w io.Writer) {
|
||||
stdWriter := log.Writer()
|
||||
t.Cleanup(func() {
|
||||
log.SetOutput(stdWriter)
|
||||
})
|
||||
func ReplaceLogWriter(t testing.TB, w io.Writer) {
|
||||
t.Helper()
|
||||
|
||||
prev := log.Writer()
|
||||
t.Cleanup(func() { log.SetOutput(prev) })
|
||||
log.SetOutput(w)
|
||||
}
|
||||
|
||||
// ReplaceLogLevel sets logging level to l and uses Cleanup method of t to
|
||||
// revert changes.
|
||||
func ReplaceLogLevel(t *testing.T, l log.Level) {
|
||||
func ReplaceLogLevel(t testing.TB, l log.Level) {
|
||||
t.Helper()
|
||||
|
||||
switch l {
|
||||
case log.INFO, log.DEBUG, log.ERROR:
|
||||
// Go on.
|
||||
@@ -38,9 +40,7 @@ func ReplaceLogLevel(t *testing.T, l log.Level) {
|
||||
t.Fatalf("wrong l value (must be one of %v, %v, %v)", log.INFO, log.DEBUG, log.ERROR)
|
||||
}
|
||||
|
||||
stdLevel := log.GetLevel()
|
||||
t.Cleanup(func() {
|
||||
log.SetLevel(stdLevel)
|
||||
})
|
||||
prev := log.GetLevel()
|
||||
t.Cleanup(func() { log.SetLevel(prev) })
|
||||
log.SetLevel(l)
|
||||
}
|
||||
|
||||
23
internal/aghtest/fswatcher.go
Normal file
23
internal/aghtest/fswatcher.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package aghtest
|
||||
|
||||
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
|
||||
type FSWatcher struct {
|
||||
OnEvents func() (e <-chan struct{})
|
||||
OnAdd func(name string) (err error)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Events implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||
return w.OnEvents()
|
||||
}
|
||||
|
||||
// Add implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Add(name string) (err error) {
|
||||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Close implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Close() (err error) {
|
||||
return w.OnClose()
|
||||
}
|
||||
46
internal/aghtest/testfs.go
Normal file
46
internal/aghtest/testfs.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package aghtest
|
||||
|
||||
import "io/fs"
|
||||
|
||||
// type check
|
||||
var _ fs.FS = &FS{}
|
||||
|
||||
// FS is a mock fs.FS implementation to use in tests.
|
||||
type FS struct {
|
||||
OnOpen func(name string) (fs.File, error)
|
||||
}
|
||||
|
||||
// Open implements the fs.FS interface for *FS.
|
||||
func (fsys *FS) Open(name string) (fs.File, error) {
|
||||
return fsys.OnOpen(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.StatFS = &StatFS{}
|
||||
|
||||
// StatFS is a mock fs.StatFS implementation to use in tests.
|
||||
type StatFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnStat func(name string) (fs.FileInfo, error)
|
||||
}
|
||||
|
||||
// Stat implements the fs.StatFS interface for *StatFS.
|
||||
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
|
||||
return fsys.OnStat(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.GlobFS = &GlobFS{}
|
||||
|
||||
// GlobFS is a mock fs.GlobFS implementation to use in tests.
|
||||
type GlobFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnGlob func(pattern string) ([]string, error)
|
||||
}
|
||||
|
||||
// Glob implements the fs.GlobFS interface for *GlobFS.
|
||||
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
|
||||
return fsys.OnGlob(pattern)
|
||||
}
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// TestUpstream is a mock of real upstream.
|
||||
type TestUpstream struct {
|
||||
// Upstream is a mock implementation of upstream.Upstream.
|
||||
type Upstream struct {
|
||||
// CName is a map of hostname to canonical name.
|
||||
CName map[string]string
|
||||
CName map[string][]string
|
||||
// IPv4 is a map of hostname to IPv4.
|
||||
IPv4 map[string][]net.IP
|
||||
// IPv6 is a map of hostname to IPv6.
|
||||
@@ -25,78 +25,45 @@ type TestUpstream struct {
|
||||
Addr string
|
||||
}
|
||||
|
||||
// Exchange implements upstream.Upstream interface for *TestUpstream.
|
||||
// Exchange implements the upstream.Upstream interface for *Upstream.
|
||||
//
|
||||
// TODO(a.garipov): Split further into handlers.
|
||||
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = &dns.Msg{}
|
||||
resp.SetReply(m)
|
||||
func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = new(dns.Msg).SetReply(m)
|
||||
|
||||
if len(m.Question) == 0 {
|
||||
return nil, fmt.Errorf("question should not be empty")
|
||||
}
|
||||
|
||||
name := m.Question[0].Name
|
||||
|
||||
if cname, ok := u.CName[name]; ok {
|
||||
ans := &dns.CNAME{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeCNAME,
|
||||
},
|
||||
q := m.Question[0]
|
||||
name := q.Name
|
||||
for _, cname := range u.CName[name] {
|
||||
resp.Answer = append(resp.Answer, &dns.CNAME{
|
||||
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME},
|
||||
Target: cname,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
})
|
||||
}
|
||||
|
||||
rrType := m.Question[0].Qtype
|
||||
qtype := q.Qtype
|
||||
hdr := dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: rrType,
|
||||
Rrtype: qtype,
|
||||
}
|
||||
|
||||
var names []string
|
||||
var ips []net.IP
|
||||
switch m.Question[0].Qtype {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
ips = u.IPv4[name]
|
||||
for _, ip := range u.IPv4[name] {
|
||||
resp.Answer = append(resp.Answer, &dns.A{Hdr: hdr, A: ip})
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
ips = u.IPv6[name]
|
||||
for _, ip := range u.IPv6[name] {
|
||||
resp.Answer = append(resp.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
|
||||
}
|
||||
case dns.TypePTR:
|
||||
names = u.Reverse[name]
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
var ans dns.RR
|
||||
if rrType == dns.TypeA {
|
||||
ans = &dns.A{
|
||||
Hdr: hdr,
|
||||
A: ip,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
|
||||
continue
|
||||
for _, name := range u.Reverse[name] {
|
||||
resp.Answer = append(resp.Answer, &dns.PTR{Hdr: hdr, Ptr: name})
|
||||
}
|
||||
|
||||
ans = &dns.AAAA{
|
||||
Hdr: hdr,
|
||||
AAAA: ip,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
for _, n := range names {
|
||||
ans := &dns.PTR{
|
||||
Hdr: hdr,
|
||||
Ptr: n,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
if len(resp.Answer) == 0 {
|
||||
resp.SetRcode(m, dns.RcodeNameError)
|
||||
}
|
||||
@@ -104,8 +71,8 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Address implements upstream.Upstream interface for *TestUpstream.
|
||||
func (u *TestUpstream) Address() string {
|
||||
// Address implements upstream.Upstream interface for *Upstream.
|
||||
func (u *Upstream) Address() string {
|
||||
return u.Addr
|
||||
}
|
||||
|
||||
|
||||
30
internal/aghtls/aghtls.go
Normal file
30
internal/aghtls/aghtls.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// Package aghtls contains utilities for work with TLS.
|
||||
package aghtls
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
// SaferCipherSuites returns a set of default cipher suites with vulnerable and
|
||||
// weak cipher suites removed.
|
||||
func SaferCipherSuites() (safe []uint16) {
|
||||
for _, s := range tls.CipherSuites() {
|
||||
switch s.ID {
|
||||
case
|
||||
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
|
||||
// Less safe 3DES and CBC suites, go on.
|
||||
default:
|
||||
safe = append(safe, s.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return safe
|
||||
}
|
||||
@@ -5,9 +5,9 @@ package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
@@ -45,7 +45,7 @@ func TestDHCPConn_WriteTo_common(t *testing.T) {
|
||||
n, err := conn.WriteTo(nil, &unexpectedAddrType{})
|
||||
require.Error(t, err)
|
||||
|
||||
assert.True(t, strings.Contains(err.Error(), "peer is of unexpected type"))
|
||||
testutil.AssertErrorMsg(t, "peer is of unexpected type *dhcpd.unexpectedAddrType", err)
|
||||
assert.Zero(t, n)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -119,23 +119,28 @@ func (l *Lease) UnmarshalJSON(data []byte) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServerConfig - DHCP server configuration
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
// ServerConfig is the configuration for the DHCP server. The order of YAML
|
||||
// fields is important, since the YAML configuration file follows it.
|
||||
type ServerConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
InterfaceName string `yaml:"interface_name"`
|
||||
|
||||
Conf4 V4ServerConf `yaml:"dhcpv4"`
|
||||
Conf6 V6ServerConf `yaml:"dhcpv6"`
|
||||
|
||||
WorkDir string `yaml:"-"`
|
||||
DBFilePath string `yaml:"-"` // path to DB file
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
|
||||
|
||||
Enabled bool `yaml:"enabled"`
|
||||
InterfaceName string `yaml:"interface_name"`
|
||||
|
||||
// LocalDomainName is the domain name used for DHCP hosts. For example,
|
||||
// a DHCP client with the hostname "myhost" can be addressed as "myhost.lan"
|
||||
// when LocalDomainName is "lan".
|
||||
LocalDomainName string `yaml:"local_domain_name"`
|
||||
|
||||
Conf4 V4ServerConf `yaml:"dhcpv4"`
|
||||
Conf6 V6ServerConf `yaml:"dhcpv6"`
|
||||
|
||||
WorkDir string `yaml:"-"`
|
||||
DBFilePath string `yaml:"-"`
|
||||
}
|
||||
|
||||
// OnLeaseChangedT is a callback for lease changes.
|
||||
@@ -156,7 +161,9 @@ type Server struct {
|
||||
srv4 DHCPServer
|
||||
srv6 DHCPServer
|
||||
|
||||
conf ServerConfig
|
||||
// TODO(a.garipov): Either create a separate type for the internal config or
|
||||
// just put the config values into Server.
|
||||
conf *ServerConfig
|
||||
|
||||
// Called when the leases DB is modified
|
||||
onLeaseChanged []OnLeaseChangedT
|
||||
@@ -181,14 +188,21 @@ type ServerInterface interface {
|
||||
}
|
||||
|
||||
// Create - create object
|
||||
func Create(conf ServerConfig) (s *Server, err error) {
|
||||
s = &Server{}
|
||||
func Create(conf *ServerConfig) (s *Server, err error) {
|
||||
s = &Server{
|
||||
conf: &ServerConfig{
|
||||
ConfigModified: conf.ConfigModified,
|
||||
|
||||
s.conf.Enabled = conf.Enabled
|
||||
s.conf.InterfaceName = conf.InterfaceName
|
||||
s.conf.HTTPRegister = conf.HTTPRegister
|
||||
s.conf.ConfigModified = conf.ConfigModified
|
||||
s.conf.DBFilePath = filepath.Join(conf.WorkDir, dbFilename)
|
||||
HTTPRegister: conf.HTTPRegister,
|
||||
|
||||
Enabled: conf.Enabled,
|
||||
InterfaceName: conf.InterfaceName,
|
||||
|
||||
LocalDomainName: conf.LocalDomainName,
|
||||
|
||||
DBFilePath: filepath.Join(conf.WorkDir, dbFilename),
|
||||
},
|
||||
}
|
||||
|
||||
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
|
||||
if runtime.GOOS == "windows" {
|
||||
@@ -305,6 +319,7 @@ func (s *Server) notify(flags int) {
|
||||
func (s *Server) WriteDiskConfig(c *ServerConfig) {
|
||||
c.Enabled = s.conf.Enabled
|
||||
c.InterfaceName = s.conf.InterfaceName
|
||||
c.LocalDomainName = s.conf.LocalDomainName
|
||||
s.srv4.WriteDiskConfig4(&c.Conf4)
|
||||
s.srv6.WriteDiskConfig6(&c.Conf6)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -26,7 +27,7 @@ func testNotify(flags uint32) {
|
||||
func TestDB(t *testing.T) {
|
||||
var err error
|
||||
s := Server{
|
||||
conf: ServerConfig{
|
||||
conf: &ServerConfig{
|
||||
DBFilePath: dbFilename,
|
||||
},
|
||||
}
|
||||
@@ -67,9 +68,7 @@ func TestDB(t *testing.T) {
|
||||
err = s.dbStore()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, os.Remove(dbFilename))
|
||||
})
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(dbFilename) })
|
||||
|
||||
err = s.srv4.ResetLeases(nil)
|
||||
require.NoError(t, err)
|
||||
@@ -138,6 +137,49 @@ func TestNormalizeLeases(t *testing.T) {
|
||||
assert.Equal(t, leases[2].HWAddr, dynLeases[1].HWAddr)
|
||||
}
|
||||
|
||||
func TestV4Server_badRange(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
gatewayIP net.IP
|
||||
subnetMask net.IP
|
||||
}{{
|
||||
name: "gateway_in_range",
|
||||
wantErrMsg: "dhcpv4: gateway ip 192.168.10.120 in the ip range: " +
|
||||
"192.168.10.20-192.168.10.200",
|
||||
gatewayIP: net.IP{192, 168, 10, 120},
|
||||
subnetMask: net.IP{255, 255, 255, 0},
|
||||
}, {
|
||||
name: "outside_range_start",
|
||||
wantErrMsg: "dhcpv4: range start 192.168.10.20 is outside network " +
|
||||
"192.168.10.1/28",
|
||||
gatewayIP: net.IP{192, 168, 10, 1},
|
||||
subnetMask: net.IP{255, 255, 255, 240},
|
||||
}, {
|
||||
name: "outside_range_end",
|
||||
wantErrMsg: "dhcpv4: range end 192.168.10.200 is outside network " +
|
||||
"192.168.10.1/27",
|
||||
gatewayIP: net.IP{192, 168, 10, 1},
|
||||
subnetMask: net.IP{255, 255, 255, 224},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
conf := V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 20},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: tc.gatewayIP,
|
||||
SubnetMask: tc.subnetMask,
|
||||
notify: testNotify,
|
||||
}
|
||||
|
||||
_, err := v4Create(conf)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// cloneUDPAddr returns a deep copy of a.
|
||||
func cloneUDPAddr(a *net.UDPAddr) (clone *net.UDPAddr) {
|
||||
return &net.UDPAddr{
|
||||
|
||||
@@ -8,18 +8,15 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Info("DHCP: %s %s: %s", r.Method, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
type v4ServerConfJSON struct {
|
||||
GatewayIP net.IP `json:"gateway_ip"`
|
||||
SubnetMask net.IP `json:"subnet_mask"`
|
||||
@@ -85,8 +82,13 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(status)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to marshal DHCP status json: %s", err)
|
||||
return
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Unable to marshal DHCP status json: %s",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,36 +211,34 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(conf)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest,
|
||||
"failed to parse new dhcp config json: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse new dhcp config json: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
srv4, v4Enabled, err := s.handleDHCPSetConfigV4(conf)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "bad dhcpv4 configuration: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "bad dhcpv4 configuration: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
srv6, v6Enabled, err := s.handleDHCPSetConfigV6(conf)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "bad dhcpv6 configuration: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "bad dhcpv6 configuration: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if conf.Enabled == nbTrue && !v4Enabled && !v6Enabled {
|
||||
httpError(r, w, http.StatusBadRequest,
|
||||
"dhcpv4 or dhcpv6 configuration must be complete")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "dhcpv4 or dhcpv6 configuration must be complete")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -263,7 +263,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = s.dbLoad()
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "loading leases db: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "loading leases db: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -272,9 +272,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var code int
|
||||
code, err = s.enableDHCP(conf.InterfaceName)
|
||||
if err != nil {
|
||||
httpError(r, w, code, "enabling dhcp: %s", err)
|
||||
|
||||
return
|
||||
aghhttp.Error(r, w, code, "enabling dhcp: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -293,7 +291,8 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -310,7 +309,15 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
var addrs []net.Addr
|
||||
addrs, err = iface.Addrs()
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err)
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Failed to get addresses for interface %s: %s",
|
||||
iface.Name,
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -327,7 +334,13 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok {
|
||||
// not an IPNet, should not happen
|
||||
httpError(r, w, http.StatusInternalServerError, "SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"got iface.Addrs() element %[1]s that is not net.IPNet, it is %[1]T",
|
||||
addr)
|
||||
|
||||
return
|
||||
}
|
||||
// ignore link-local
|
||||
@@ -348,8 +361,13 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Failed to marshal json with available interfaces: %s", err)
|
||||
return
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Failed to marshal json with available interfaces: %s",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -453,9 +471,13 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(result)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Failed to marshal DHCP found json: %s", err)
|
||||
|
||||
return
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Failed to marshal DHCP found json: %s",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -463,13 +485,13 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
l := &Lease{}
|
||||
err := json.NewDecoder(r.Body).Decode(l)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if l.IP == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid IP")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
|
||||
|
||||
return
|
||||
}
|
||||
@@ -481,7 +503,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
|
||||
err = s.srv6.AddStaticLease(l)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -490,7 +512,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
l.IP = ip4
|
||||
err = s.srv4.AddStaticLease(l)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -500,13 +522,13 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
|
||||
l := &Lease{}
|
||||
err := json.NewDecoder(r.Body).Decode(l)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if l.IP == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid IP")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
|
||||
|
||||
return
|
||||
}
|
||||
@@ -518,7 +540,7 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
err = s.srv6.RemoveStaticLease(l)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -527,16 +549,23 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
|
||||
l.IP = ip4
|
||||
err = s.srv4.RemoveStaticLease(l)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultDHCPLeaseTTL is the default time-to-live for leases.
|
||||
DefaultDHCPLeaseTTL = uint32(timeutil.Day / time.Second)
|
||||
// DefaultDHCPTimeoutICMP is the default timeout for waiting ICMP responses.
|
||||
DefaultDHCPTimeoutICMP = 1000
|
||||
)
|
||||
|
||||
func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
err := s.Stop()
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -546,20 +575,28 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
log.Error("dhcp: removing db: %s", err)
|
||||
}
|
||||
|
||||
oldconf := s.conf
|
||||
s.conf = ServerConfig{}
|
||||
s.conf.WorkDir = oldconf.WorkDir
|
||||
s.conf.HTTPRegister = oldconf.HTTPRegister
|
||||
s.conf.ConfigModified = oldconf.ConfigModified
|
||||
s.conf.DBFilePath = oldconf.DBFilePath
|
||||
s.conf = &ServerConfig{
|
||||
ConfigModified: s.conf.ConfigModified,
|
||||
|
||||
v4conf := V4ServerConf{}
|
||||
v4conf.ICMPTimeout = 1000
|
||||
v4conf.notify = s.onNotify
|
||||
HTTPRegister: s.conf.HTTPRegister,
|
||||
|
||||
LocalDomainName: s.conf.LocalDomainName,
|
||||
|
||||
WorkDir: s.conf.WorkDir,
|
||||
DBFilePath: s.conf.DBFilePath,
|
||||
}
|
||||
|
||||
v4conf := V4ServerConf{
|
||||
LeaseDuration: DefaultDHCPLeaseTTL,
|
||||
ICMPTimeout: DefaultDHCPTimeoutICMP,
|
||||
notify: s.onNotify,
|
||||
}
|
||||
s.srv4, _ = v4Create(v4conf)
|
||||
|
||||
v6conf := V6ServerConf{}
|
||||
v6conf.notify = s.onNotify
|
||||
v6conf := V6ServerConf{
|
||||
LeaseDuration: DefaultDHCPLeaseTTL,
|
||||
notify: s.onNotify,
|
||||
}
|
||||
s.srv6, _ = v6Create(v6conf)
|
||||
|
||||
s.conf.ConfigModified()
|
||||
@@ -569,7 +606,7 @@ func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
|
||||
err := s.resetLeases()
|
||||
if err != nil {
|
||||
msg := "resetting leases: %s"
|
||||
httpError(r, w, http.StatusInternalServerError, msg, err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, msg, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ func TestServer_notImplemented(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest(http.MethodGet, "/unsupported", nil)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
h(w, r)
|
||||
assert.Equal(t, http.StatusNotImplemented, w.Code)
|
||||
|
||||
@@ -14,8 +14,8 @@ import (
|
||||
//
|
||||
// It is safe for concurrent use.
|
||||
//
|
||||
// TODO(a.garipov): Perhaps create an optimised version with uint32 for
|
||||
// IPv4 ranges? Or use one of uint128 packages?
|
||||
// TODO(a.garipov): Perhaps create an optimized version with uint32 for IPv4
|
||||
// ranges? Or use one of uint128 packages?
|
||||
type ipRange struct {
|
||||
start *big.Int
|
||||
end *big.Int
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -64,14 +65,8 @@ func TestNewIPRange(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r, err := newIPRange(tc.start, tc.end)
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, r)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
_, err := newIPRange(tc.start, tc.end)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -11,33 +12,33 @@ import (
|
||||
func TestNullBool_UnmarshalJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantErrMsg string
|
||||
data []byte
|
||||
want nullBool
|
||||
}{{
|
||||
name: "empty",
|
||||
data: []byte{},
|
||||
wantErrMsg: "",
|
||||
data: []byte{},
|
||||
want: nbNull,
|
||||
}, {
|
||||
name: "null",
|
||||
data: []byte("null"),
|
||||
wantErrMsg: "",
|
||||
data: []byte("null"),
|
||||
want: nbNull,
|
||||
}, {
|
||||
name: "true",
|
||||
data: []byte("true"),
|
||||
wantErrMsg: "",
|
||||
data: []byte("true"),
|
||||
want: nbTrue,
|
||||
}, {
|
||||
name: "false",
|
||||
data: []byte("false"),
|
||||
wantErrMsg: "",
|
||||
data: []byte("false"),
|
||||
want: nbFalse,
|
||||
}, {
|
||||
name: "invalid",
|
||||
data: []byte("flase"),
|
||||
wantErrMsg: `invalid nullBool value "flase"`,
|
||||
wantErrMsg: `invalid nullBool value "invalid"`,
|
||||
data: []byte("invalid"),
|
||||
want: nbNull,
|
||||
}}
|
||||
|
||||
@@ -45,13 +46,7 @@ func TestNullBool_UnmarshalJSON(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var got nullBool
|
||||
err := got.UnmarshalJSON(tc.data)
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
|
||||
@@ -130,21 +130,19 @@ func parseDHCPOption(s string) (opt dhcpv4.Option, err error) {
|
||||
// prepareOptions builds the set of DHCP options according to host requirements
|
||||
// document and values from conf.
|
||||
func prepareOptions(conf V4ServerConf) (opts dhcpv4.Options) {
|
||||
// Set default values for host configuration parameters listed in Appendix
|
||||
// A of RFC-2131. Those parameters, if requested by client, should be
|
||||
// returned with values defined by Host Requirements Document.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc2131#appendix-A.
|
||||
//
|
||||
// See also https://datatracker.ietf.org/doc/html/rfc1122,
|
||||
// https://datatracker.ietf.org/doc/html/rfc1123, and
|
||||
// https://datatracker.ietf.org/doc/html/rfc2132.
|
||||
opts = dhcpv4.Options{
|
||||
// Set default values for host configuration parameters listed
|
||||
// in Appendix A of RFC-2131. Those parameters, if requested by
|
||||
// client, should be returned with values defined by Host
|
||||
// Requirements Document.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc2131#appendix-A.
|
||||
//
|
||||
// See also https://datatracker.ietf.org/doc/html/rfc1122,
|
||||
// https://datatracker.ietf.org/doc/html/rfc1123, and
|
||||
// https://datatracker.ietf.org/doc/html/rfc2132.
|
||||
|
||||
// IP-Layer Per Host
|
||||
|
||||
dhcpv4.OptionNonLocalSourceRouting.Code(): []byte{0},
|
||||
|
||||
// Set the current recommended default time to live for the
|
||||
// Internet Protocol which is 64, see
|
||||
// https://datatracker.ietf.org/doc/html/rfc1700.
|
||||
|
||||
@@ -95,6 +95,7 @@ func TestParseOpt(t *testing.T) {
|
||||
opt, err := parseDHCPOption(tc.in)
|
||||
if tc.wantErrMsg != "" {
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
|
||||
return
|
||||
|
||||
@@ -27,7 +27,6 @@ type DHCPServer interface {
|
||||
Start() (err error)
|
||||
// Stop - stop server
|
||||
Stop() (err error)
|
||||
|
||||
getLeasesRef() []*Lease
|
||||
}
|
||||
|
||||
|
||||
@@ -293,6 +293,8 @@ func (s *v4Server) addLease(l *Lease) (err error) {
|
||||
offset, inOffset := r.offset(l.IP)
|
||||
|
||||
if l.IsStatic() {
|
||||
// TODO(a.garipov, d.seregin): Subnet can be nil when dhcp server is
|
||||
// disabled.
|
||||
if sn := s.conf.subnet; !sn.Contains(l.IP) {
|
||||
return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP)
|
||||
}
|
||||
@@ -900,9 +902,10 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
|
||||
resp.UpdateOption(dhcpv4.OptGeneric(code, configured.Get(code)))
|
||||
}
|
||||
}
|
||||
// Update the value of Domain Name Server option separately from others
|
||||
// since its value is set after server's creating.
|
||||
if requested.Has(dhcpv4.OptionDomainNameServer) {
|
||||
// Update the value of Domain Name Server option separately from others if
|
||||
// not assigned yet since its value is set after server's creating.
|
||||
if requested.Has(dhcpv4.OptionDomainNameServer) &&
|
||||
!resp.Options.Has(dhcpv4.OptionDomainNameServer) {
|
||||
resp.UpdateOption(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
|
||||
}
|
||||
|
||||
@@ -966,11 +969,10 @@ func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DH
|
||||
Port: dhcpv4.ServerPort,
|
||||
}
|
||||
if mtype == dhcpv4.MessageTypeNak {
|
||||
// Set the broadcast bit in the DHCPNAK, so that the
|
||||
// relay agent broadcasted it to the client, because the
|
||||
// client may not have a correct network address or
|
||||
// subnet mask, and the client may not be answering ARP
|
||||
// requests.
|
||||
// Set the broadcast bit in the DHCPNAK, so that the relay agent
|
||||
// broadcasts it to the client, because the client may not have
|
||||
// a correct network address or subnet mask, and the client may not
|
||||
// be answering ARP requests.
|
||||
resp.SetBroadcast()
|
||||
}
|
||||
case mtype == dhcpv4.MessageTypeNak:
|
||||
@@ -1053,8 +1055,6 @@ func (s *v4Server) Start() (err error) {
|
||||
go func() {
|
||||
if serr := s.srv.Serve(); errors.Is(serr, net.ErrClosed) {
|
||||
log.Info("dhcpv4: server is closed")
|
||||
|
||||
return
|
||||
} else if serr != nil {
|
||||
log.Error("dhcpv4: srv.Serve: %s", serr)
|
||||
}
|
||||
@@ -1124,6 +1124,29 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) {
|
||||
return s, fmt.Errorf("dhcpv4: %w", err)
|
||||
}
|
||||
|
||||
if s.conf.ipRange.contains(routerIP) {
|
||||
return s, fmt.Errorf("dhcpv4: gateway ip %v in the ip range: %v-%v",
|
||||
routerIP,
|
||||
conf.RangeStart,
|
||||
conf.RangeEnd,
|
||||
)
|
||||
}
|
||||
|
||||
if !s.conf.subnet.Contains(conf.RangeStart) {
|
||||
return s, fmt.Errorf("dhcpv4: range start %v is outside network %v",
|
||||
conf.RangeStart,
|
||||
s.conf.subnet,
|
||||
)
|
||||
}
|
||||
|
||||
if !s.conf.subnet.Contains(conf.RangeEnd) {
|
||||
return s, fmt.Errorf("dhcpv4: range end %v is outside network %v",
|
||||
conf.RangeEnd,
|
||||
s.conf.subnet,
|
||||
)
|
||||
}
|
||||
|
||||
// TODO(a.garipov, d.seregin): Check that every lease is inside the IPRange.
|
||||
s.leasedOffsets = newBitSet()
|
||||
|
||||
if conf.LeaseDuration == 0 {
|
||||
|
||||
@@ -5,8 +5,11 @@ package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/mdlayher/raw"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -16,17 +19,34 @@ import (
|
||||
func notify4(flags uint32) {
|
||||
}
|
||||
|
||||
func TestV4_AddRemove_static(t *testing.T) {
|
||||
s, err := v4Create(V4ServerConf{
|
||||
// defaultV4ServerConf returns the default configuration for *v4Server to use in
|
||||
// tests.
|
||||
func defaultV4ServerConf() (conf V4ServerConf) {
|
||||
return V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
notify: notify4,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// defaultSrv prepares the default DHCPServer to use in tests. The underlying
|
||||
// type of s is *v4Server.
|
||||
func defaultSrv(t *testing.T) (s DHCPServer) {
|
||||
t.Helper()
|
||||
|
||||
var err error
|
||||
s, err = v4Create(defaultV4ServerConf())
|
||||
require.NoError(t, err)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func TestV4_AddRemove_static(t *testing.T) {
|
||||
s := defaultSrv(t)
|
||||
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Empty(t, ls)
|
||||
|
||||
@@ -37,7 +57,7 @@ func TestV4_AddRemove_static(t *testing.T) {
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
}
|
||||
|
||||
err = s.AddStaticLease(l)
|
||||
err := s.AddStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.AddStaticLease(l)
|
||||
@@ -65,15 +85,7 @@ func TestV4_AddRemove_static(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestV4_AddReplace(t *testing.T) {
|
||||
sIface, err := v4Create(V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
notify: notify4,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
sIface := defaultSrv(t)
|
||||
|
||||
s, ok := sIface.(*v4Server)
|
||||
require.True(t, ok)
|
||||
@@ -89,7 +101,7 @@ func TestV4_AddReplace(t *testing.T) {
|
||||
}}
|
||||
|
||||
for i := range dynLeases {
|
||||
err = s.addLease(&dynLeases[i])
|
||||
err := s.addLease(&dynLeases[i])
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -104,7 +116,7 @@ func TestV4_AddReplace(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, l := range stLeases {
|
||||
err = s.AddStaticLease(l)
|
||||
err := s.AddStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -118,17 +130,80 @@ func TestV4_AddReplace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestV4StaticLease_Get(t *testing.T) {
|
||||
var err error
|
||||
sIface, err := v4Create(V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
notify: notify4,
|
||||
func TestV4Server_Process_optionsPriority(t *testing.T) {
|
||||
defaultIP := net.IP{192, 168, 1, 1}
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
|
||||
// prepareSrv creates a *v4Server and sets the opt6IPs in the initial
|
||||
// configuration of the server as the value for DHCP option 6.
|
||||
prepareSrv := func(t *testing.T, opt6IPs []net.IP) (s *v4Server) {
|
||||
t.Helper()
|
||||
|
||||
conf := defaultV4ServerConf()
|
||||
if len(opt6IPs) > 0 {
|
||||
b := &strings.Builder{}
|
||||
stringutil.WriteToBuilder(b, "6 ips ", opt6IPs[0].String())
|
||||
for _, ip := range opt6IPs[1:] {
|
||||
stringutil.WriteToBuilder(b, ",", ip.String())
|
||||
}
|
||||
conf.Options = []string{b.String()}
|
||||
}
|
||||
|
||||
ss, err := v4Create(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
var ok bool
|
||||
s, ok = ss.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
s.conf.dnsIPAddrs = []net.IP{defaultIP}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// checkResp creates a discovery message with DHCP option 6 requested amd
|
||||
// asserts the response to contain wantIPs in this option.
|
||||
checkResp := func(t *testing.T, s *v4Server, wantIPs []net.IP) {
|
||||
t.Helper()
|
||||
|
||||
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
req, err := dhcpv4.NewDiscovery(mac, dhcpv4.WithRequestedOptions(
|
||||
dhcpv4.OptionDomainNameServer,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp *dhcpv4.DHCPv4
|
||||
resp, err = dhcpv4.NewReplyFromRequest(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
res := s.process(req, resp)
|
||||
require.Equal(t, 1, res)
|
||||
|
||||
o := resp.GetOneOption(dhcpv4.OptionDomainNameServer)
|
||||
require.NotEmpty(t, o)
|
||||
|
||||
wantData := []byte{}
|
||||
for _, ip := range wantIPs {
|
||||
wantData = append(wantData, ip...)
|
||||
}
|
||||
assert.Equal(t, o, wantData)
|
||||
}
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
s := prepareSrv(t, nil)
|
||||
|
||||
checkResp(t, s, []net.IP{defaultIP})
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("explicitly_configured", func(t *testing.T) {
|
||||
s := prepareSrv(t, []net.IP{knownIP, knownIP})
|
||||
|
||||
checkResp(t, s, []net.IP{knownIP, knownIP})
|
||||
})
|
||||
}
|
||||
|
||||
func TestV4StaticLease_Get(t *testing.T) {
|
||||
sIface := defaultSrv(t)
|
||||
|
||||
s, ok := sIface.(*v4Server)
|
||||
require.True(t, ok)
|
||||
@@ -140,7 +215,7 @@ func TestV4StaticLease_Get(t *testing.T) {
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
}
|
||||
err = s.AddStaticLease(l)
|
||||
err := s.AddStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
|
||||
var req, resp *dhcpv4.DHCPv4
|
||||
@@ -208,19 +283,14 @@ func TestV4StaticLease_Get(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestV4DynamicLease_Get(t *testing.T) {
|
||||
conf := defaultV4ServerConf()
|
||||
conf.Options = []string{
|
||||
"81 hex 303132",
|
||||
"82 ip 1.2.3.4",
|
||||
}
|
||||
|
||||
var err error
|
||||
sIface, err := v4Create(V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
notify: notify4,
|
||||
Options: []string{
|
||||
"81 hex 303132",
|
||||
"82 ip 1.2.3.4",
|
||||
},
|
||||
})
|
||||
sIface, err := v4Create(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, ok := sIface.(*v4Server)
|
||||
@@ -361,14 +431,7 @@ func TestNormalizeHostname(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := normalizeHostname(tc.hostname)
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -662,8 +662,11 @@ func (s *v6Server) Start() (err error) {
|
||||
}
|
||||
|
||||
go func() {
|
||||
err = s.srv.Serve()
|
||||
log.Error("dhcpv6: srv.Serve: %s", err)
|
||||
if serr := s.srv.Serve(); errors.Is(serr, net.ErrClosed) {
|
||||
log.Info("dhcpv6: server is closed")
|
||||
} else if serr != nil {
|
||||
log.Error("dhcpv6: srv.Serve: %s", serr)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
|
||||
@@ -7,7 +7,8 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
@@ -118,8 +119,8 @@ func (a *accessCtx) allowlistMode() (ok bool) {
|
||||
func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
|
||||
allowlistMode := a.allowlistMode()
|
||||
if id == "" {
|
||||
// In allowlist mode, consider requests without client IDs
|
||||
// blocked by default.
|
||||
// In allowlist mode, consider requests without ClientIDs blocked by
|
||||
// default.
|
||||
return allowlistMode
|
||||
}
|
||||
|
||||
@@ -187,78 +188,62 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(j)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func isUniq(slice []string) (ok bool, uniqueMap map[string]unit) {
|
||||
exists := make(map[string]unit)
|
||||
for _, key := range slice {
|
||||
if _, has := exists[key]; has {
|
||||
return false, nil
|
||||
}
|
||||
exists[key] = unit{}
|
||||
}
|
||||
return true, exists
|
||||
}
|
||||
|
||||
func intersect(mapA, mapB map[string]unit) bool {
|
||||
for key := range mapA {
|
||||
if _, has := mapB[key]; has {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// validateAccessSet checks the internal accessListJSON lists. To search for
|
||||
// duplicates, we cannot compare the new stringutil.Set and []string, because
|
||||
// creating a set for a large array can be an unnecessary algorithmic complexity
|
||||
func validateAccessSet(list accessListJSON) (err error) {
|
||||
const (
|
||||
errAllowedDup errors.Error = "duplicates in allowed clients"
|
||||
errDisallowedDup errors.Error = "duplicates in disallowed clients"
|
||||
errBlockedDup errors.Error = "duplicates in blocked hosts"
|
||||
errIntersect errors.Error = "some items in allowed and " +
|
||||
"disallowed lists at the same time"
|
||||
)
|
||||
|
||||
ok, allowedClients := isUniq(list.AllowedClients)
|
||||
if !ok {
|
||||
return errAllowedDup
|
||||
func validateAccessSet(list *accessListJSON) (err error) {
|
||||
allowed, err := validateStrUniq(list.AllowedClients)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating allowed clients: %w", err)
|
||||
}
|
||||
|
||||
ok, disallowedClients := isUniq(list.DisallowedClients)
|
||||
if !ok {
|
||||
return errDisallowedDup
|
||||
disallowed, err := validateStrUniq(list.DisallowedClients)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating disallowed clients: %w", err)
|
||||
}
|
||||
|
||||
ok, _ = isUniq(list.BlockedHosts)
|
||||
if !ok {
|
||||
return errBlockedDup
|
||||
_, err = validateStrUniq(list.BlockedHosts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating blocked hosts: %w", err)
|
||||
}
|
||||
|
||||
if intersect(allowedClients, disallowedClients) {
|
||||
return errIntersect
|
||||
merged := allowed.Merge(disallowed)
|
||||
err = merged.Validate(aghalg.StringIsBefore)
|
||||
if err != nil {
|
||||
return fmt.Errorf("items in allowed and disallowed clients intersect: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateStrUniq returns an informative error if clients are not unique.
|
||||
func validateStrUniq(clients []string) (uc aghalg.UniqChecker, err error) {
|
||||
uc = make(aghalg.UniqChecker, len(clients))
|
||||
for _, c := range clients {
|
||||
uc.Add(c)
|
||||
}
|
||||
|
||||
return uc, uc.Validate(aghalg.StringIsBefore)
|
||||
}
|
||||
|
||||
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
list := accessListJSON{}
|
||||
list := &accessListJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&list)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = validateAccessSet(list)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, err.Error())
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
@@ -266,7 +251,7 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
var a *accessCtx
|
||||
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
@@ -13,12 +12,12 @@ import (
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
// ValidateClientID returns an error if clientID is not a valid client ID.
|
||||
func ValidateClientID(clientID string) (err error) {
|
||||
err = netutil.ValidateDomainNameLabel(clientID)
|
||||
// ValidateClientID returns an error if id is not a valid ClientID.
|
||||
func ValidateClientID(id string) (err error) {
|
||||
err = netutil.ValidateDomainNameLabel(id)
|
||||
if err != nil {
|
||||
// Replace the domain name label wrapper with our own.
|
||||
return fmt.Errorf("invalid client id %q: %w", clientID, errors.Unwrap(err))
|
||||
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -34,7 +33,7 @@ func hasLabelSuffix(s, suffix string) (ok bool) {
|
||||
return strings.HasSuffix(s, suffix) && s[len(s)-len(suffix)-1] == '.'
|
||||
}
|
||||
|
||||
// clientIDFromClientServerName extracts and validates a client ID. hostSrvName
|
||||
// clientIDFromClientServerName extracts and validates a ClientID. hostSrvName
|
||||
// is the server name of the host. cliSrvName is the server name as sent by the
|
||||
// client. When strict is true, and client and host server name don't match,
|
||||
// clientIDFromClientServerName will return an error.
|
||||
@@ -87,22 +86,22 @@ func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err e
|
||||
}
|
||||
|
||||
if len(parts) == 0 || parts[0] != "dns-query" {
|
||||
return "", fmt.Errorf("client id check: invalid path %q", origPath)
|
||||
return "", fmt.Errorf("clientid check: invalid path %q", origPath)
|
||||
}
|
||||
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
// Just /dns-query, no client ID.
|
||||
// Just /dns-query, no ClientID.
|
||||
return "", nil
|
||||
case 2:
|
||||
clientID = parts[1]
|
||||
default:
|
||||
return "", fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
||||
return "", fmt.Errorf("clientid check: invalid path %q: extra parts", origPath)
|
||||
}
|
||||
|
||||
err = ValidateClientID(clientID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("client id check: %w", err)
|
||||
return "", fmt.Errorf("clientid check: %w", err)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
@@ -167,24 +166,8 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
|
||||
s.conf.StrictSNICheck,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("client id check: %w", err)
|
||||
return "", fmt.Errorf("clientid check: %w", err)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// processClientID puts the clientID into the DNS context, if there is one.
|
||||
func (s *Server) processClientID(dctx *dnsContext) (rc resultCode) {
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
var key [8]byte
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
clientIDData := s.clientIDCache.Get(key[:])
|
||||
if clientIDData == nil {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
dctx.clientID = string(clientIDData)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testTLSConn is a tlsConn for tests.
|
||||
@@ -31,8 +31,8 @@ func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) {
|
||||
|
||||
// testQUICSession is a quicSession for tests.
|
||||
type testQUICSession struct {
|
||||
// Session is embedded here simply to make testQUICSession
|
||||
// a quic.Session without acctually implementing all methods.
|
||||
// Session is embedded here simply to make testQUICSession a quic.Session
|
||||
// without actually implementing all methods.
|
||||
quic.Session
|
||||
|
||||
serverName string
|
||||
@@ -65,7 +65,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantErrMsg: "",
|
||||
strictSNI: false,
|
||||
}, {
|
||||
name: "tls_no_client_id",
|
||||
name: "tls_no_clientid",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "example.com",
|
||||
@@ -78,7 +78,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: client server name "" ` +
|
||||
wantErrMsg: `clientid check: client server name "" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
@@ -90,7 +90,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantErrMsg: "",
|
||||
strictSNI: false,
|
||||
}, {
|
||||
name: "tls_client_id",
|
||||
name: "tls_clientid",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "cli.example.com",
|
||||
@@ -98,36 +98,36 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantErrMsg: "",
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_hostname_error",
|
||||
name: "tls_clientid_hostname_error",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "cli.example.net",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: client server name "cli.example.net" ` +
|
||||
wantErrMsg: `clientid check: client server name "cli.example.net" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_invalid_client_id",
|
||||
name: "tls_invalid_clientid",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "!!!.example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||
wantErrMsg: `clientid check: invalid clientid "!!!": ` +
|
||||
`bad domain name label rune '!'`,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_too_long",
|
||||
name: "tls_clientid_too_long",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: `abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmno` +
|
||||
`pqrstuvwxyz0123456789.example.com`,
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
|
||||
wantErrMsg: `clientid check: invalid clientid "abcdefghijklmno` +
|
||||
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
|
||||
`domain name label is too long: got 72, max 63`,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "quic_client_id",
|
||||
name: "quic_clientid",
|
||||
proto: proxy.ProtoQUIC,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "cli.example.com",
|
||||
@@ -135,12 +135,12 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantErrMsg: "",
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_issue3437",
|
||||
name: "tls_clientid_issue3437",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "cli.myexample.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: client server name "cli.myexample.com" ` +
|
||||
wantErrMsg: `clientid check: client server name "cli.myexample.com" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
strictSNI: true,
|
||||
}}
|
||||
@@ -179,13 +179,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
clientID, err := srv.clientIDFromDNSContext(pctx)
|
||||
assert.Equal(t, tc.wantClientID, clientID)
|
||||
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -197,22 +191,22 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||
wantClientID string
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "no_client_id",
|
||||
name: "no_clientid",
|
||||
path: "/dns-query",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "no_client_id_slash",
|
||||
name: "no_clientid_slash",
|
||||
path: "/dns-query/",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "client_id",
|
||||
name: "clientid",
|
||||
path: "/dns-query/cli",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "client_id_slash",
|
||||
name: "clientid_slash",
|
||||
path: "/dns-query/cli/",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
@@ -220,18 +214,17 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||
name: "bad_url",
|
||||
path: "/foo",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid path "/foo"`,
|
||||
wantErrMsg: `clientid check: invalid path "/foo"`,
|
||||
}, {
|
||||
name: "extra",
|
||||
path: "/dns-query/cli/foo",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
|
||||
wantErrMsg: `clientid check: invalid path "/dns-query/cli/foo": extra parts`,
|
||||
}, {
|
||||
name: "invalid_client_id",
|
||||
name: "invalid_clientid",
|
||||
path: "/dns-query/!!!",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||
`bad domain name label rune '!'`,
|
||||
wantErrMsg: `clientid check: invalid clientid "!!!": bad domain name label rune '!'`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -250,13 +243,7 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||
clientID, err := clientIDFromDNSContextHTTPS(pctx)
|
||||
assert.Equal(t, tc.wantClientID, clientID)
|
||||
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
@@ -98,10 +99,10 @@ type FilteringConfig struct {
|
||||
AllowedClients []string `yaml:"allowed_clients"` // IP addresses of whitelist clients
|
||||
DisallowedClients []string `yaml:"disallowed_clients"` // IP addresses of clients that should be blocked
|
||||
BlockedHosts []string `yaml:"blocked_hosts"` // hosts that should be blocked
|
||||
// TrustedProxies is the list of IP addresses and CIDR networks to
|
||||
// detect proxy servers addresses the DoH requests from which should be
|
||||
// handled. The value of nil or an empty slice for this field makes
|
||||
// Proxy not trust any address.
|
||||
// TrustedProxies is the list of IP addresses and CIDR networks to detect
|
||||
// proxy servers addresses the DoH requests from which should be handled.
|
||||
// The value of nil or an empty slice for this field makes Proxy not trust
|
||||
// any address.
|
||||
TrustedProxies []string `yaml:"trusted_proxies"`
|
||||
|
||||
// DNS cache settings
|
||||
@@ -118,7 +119,7 @@ type FilteringConfig struct {
|
||||
|
||||
BogusNXDomain []string `yaml:"bogus_nxdomain"` // transform responses with these IP addresses to NXDOMAIN
|
||||
AAAADisabled bool `yaml:"aaaa_disabled"` // Respond with an empty answer to all AAAA requests
|
||||
EnableDNSSEC bool `yaml:"enable_dnssec"` // Set DNSSEC flag in outcoming DNS request
|
||||
EnableDNSSEC bool `yaml:"enable_dnssec"` // Set AD flag in outcoming DNS request
|
||||
EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option
|
||||
MaxGoroutines uint32 `yaml:"max_goroutines"` // Max. number of parallel goroutines for processing incoming requests
|
||||
|
||||
@@ -149,8 +150,8 @@ type TLSConfig struct {
|
||||
CertificateChainData []byte `yaml:"-" json:"-"`
|
||||
PrivateKeyData []byte `yaml:"-" json:"-"`
|
||||
|
||||
// ServerName is the hostname of the server. Currently, it is only
|
||||
// being used for client ID checking.
|
||||
// ServerName is the hostname of the server. Currently, it is only being
|
||||
// used for ClientID checking.
|
||||
ServerName string `yaml:"-" json:"-"`
|
||||
|
||||
cert tls.Certificate
|
||||
@@ -243,15 +244,15 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
|
||||
proxyConfig.FastestPingTimeout = s.conf.FastestTimeout.Duration
|
||||
}
|
||||
|
||||
if len(s.conf.BogusNXDomain) > 0 {
|
||||
for _, s := range s.conf.BogusNXDomain {
|
||||
ip := net.ParseIP(s)
|
||||
if ip == nil {
|
||||
log.Error("Invalid bogus IP: %s", s)
|
||||
} else {
|
||||
proxyConfig.BogusNXDomain = append(proxyConfig.BogusNXDomain, ip)
|
||||
}
|
||||
for i, s := range s.conf.BogusNXDomain {
|
||||
subnet, err := netutil.ParseSubnet(s)
|
||||
if err != nil {
|
||||
log.Error("subnet at index %d: %s", i, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
proxyConfig.BogusNXDomain = append(proxyConfig.BogusNXDomain, subnet)
|
||||
}
|
||||
|
||||
// TLS settings
|
||||
@@ -426,6 +427,7 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
|
||||
|
||||
proxyConfig.TLSConfig = &tls.Config{
|
||||
GetCertificate: s.onGetCertificate,
|
||||
CipherSuites: aghtls.SaferCipherSuites(),
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -16,36 +17,45 @@ import (
|
||||
|
||||
// To transfer information between modules
|
||||
type dnsContext struct {
|
||||
// TODO(a.garipov): Remove this and rewrite processors to be methods of
|
||||
// *Server instead.
|
||||
srv *Server
|
||||
proxyCtx *proxy.DNSContext
|
||||
|
||||
// setts are the filtering settings for the client.
|
||||
setts *filtering.Settings
|
||||
startTime time.Time
|
||||
result *filtering.Result
|
||||
setts *filtering.Settings
|
||||
|
||||
result *filtering.Result
|
||||
// origResp is the response received from upstream. It is set when the
|
||||
// response is modified by filters.
|
||||
origResp *dns.Msg
|
||||
|
||||
// unreversedReqIP stores an IP address obtained from PTR request if it
|
||||
// was successfully parsed.
|
||||
// parsed successfully and belongs to one of locally-served IP ranges as per
|
||||
// RFC 6303.
|
||||
unreversedReqIP net.IP
|
||||
|
||||
// err is the error returned from a processing function.
|
||||
err error
|
||||
// clientID is the clientID from DoH, DoQ, or DoT, if provided.
|
||||
|
||||
// clientID is the ClientID from DoH, DoQ, or DoT, if provided.
|
||||
clientID string
|
||||
|
||||
// origQuestion is the question received from the client. It is set
|
||||
// when the request is modified by rewrites.
|
||||
origQuestion dns.Question
|
||||
|
||||
// startTime is the time at which the processing of the request has started.
|
||||
startTime time.Time
|
||||
|
||||
// protectionEnabled shows if the filtering is enabled, and if the
|
||||
// server's DNS filter is ready.
|
||||
protectionEnabled bool
|
||||
|
||||
// responseFromUpstream shows if the response is received from the
|
||||
// upstream servers.
|
||||
responseFromUpstream bool
|
||||
// origReqDNSSEC shows if the DNSSEC flag in the original request from
|
||||
// the client is set.
|
||||
origReqDNSSEC bool
|
||||
|
||||
// responseAD shows if the response had the AD bit set.
|
||||
responseAD bool
|
||||
|
||||
// isLocalClient shows if client's IP address is from locally-served
|
||||
// network.
|
||||
isLocalClient bool
|
||||
@@ -69,7 +79,6 @@ const (
|
||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
ctx := &dnsContext{
|
||||
srv: s,
|
||||
proxyCtx: d,
|
||||
result: &filtering.Result{},
|
||||
startTime: time.Now(),
|
||||
@@ -84,19 +93,17 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
// appropriate handler.
|
||||
mods := []modProcessFunc{
|
||||
s.processRecursion,
|
||||
processInitial,
|
||||
s.processInitial,
|
||||
s.processDetermineLocal,
|
||||
s.processInternalHosts,
|
||||
s.processRestrictLocal,
|
||||
s.processInternalIPAddrs,
|
||||
s.processClientID,
|
||||
processFilteringBeforeRequest,
|
||||
s.processFilteringBeforeRequest,
|
||||
s.processLocalPTR,
|
||||
s.processUpstream,
|
||||
processDNSSECAfterResponse,
|
||||
processFilteringAfterResponse,
|
||||
s.processFilteringAfterResponse,
|
||||
s.ipset.process,
|
||||
processQueryLogsAndStats,
|
||||
s.processQueryLogsAndStats,
|
||||
}
|
||||
for _, process := range mods {
|
||||
r := process(ctx)
|
||||
@@ -134,9 +141,11 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Perform initial checks; process WHOIS & rDNS
|
||||
func processInitial(ctx *dnsContext) (rc resultCode) {
|
||||
s := ctx.srv
|
||||
// processInitial terminates the following processing for some requests if
|
||||
// needed and enriches the ctx with some client-specific information.
|
||||
//
|
||||
// TODO(e.burkov): Decompose into less general processors.
|
||||
func (s *Server) processInitial(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
|
||||
_ = proxy.CheckDisabledAAAARequest(d, true)
|
||||
@@ -155,6 +164,16 @@ func processInitial(ctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
// Get the client's ID if any. It should be performed before getting
|
||||
// client-specific filtering settings.
|
||||
var key [8]byte
|
||||
binary.BigEndian.PutUint64(key[:], d.RequestID)
|
||||
ctx.clientID = string(s.clientIDCache.Get(key[:]))
|
||||
|
||||
// Get the client-specific filtering settings.
|
||||
ctx.protectionEnabled = s.conf.ProtectionEnabled
|
||||
ctx.setts = s.getClientRequestFilteringSettings(ctx)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -196,9 +215,8 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
ipToHost = netutil.NewIPMap(len(ll))
|
||||
|
||||
for _, l := range ll {
|
||||
// TODO(a.garipov): Remove this after we're finished
|
||||
// with the client hostname validations in the DHCP
|
||||
// server code.
|
||||
// TODO(a.garipov): Remove this after we're finished with the client
|
||||
// hostname validations in the DHCP server code.
|
||||
err = netutil.ValidateDomainName(l.Hostname)
|
||||
if err != nil {
|
||||
log.Debug(
|
||||
@@ -274,14 +292,16 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
req := dctx.proxyCtx.Req
|
||||
q := req.Question[0]
|
||||
|
||||
// Go on processing the AAAA request despite the fact that we don't
|
||||
// support it yet. The expected behavior here is to respond with an
|
||||
// empty asnwer and not NXDOMAIN.
|
||||
// Go on processing the AAAA request despite the fact that we don't support
|
||||
// it yet. The expected behavior here is to respond with an empty answer
|
||||
// and not NXDOMAIN.
|
||||
if q.Qtype != dns.TypeA && q.Qtype != dns.TypeAAAA {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
reqHost := strings.ToLower(q.Name)
|
||||
// TODO(a.garipov): Move everything related to DHCP local domain to the DHCP
|
||||
// server.
|
||||
host := strings.TrimSuffix(reqHost, s.localDomainSuffix)
|
||||
if host == reqHost {
|
||||
return resultCodeSuccess
|
||||
@@ -298,8 +318,8 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
ip, ok := s.hostToIP(host)
|
||||
if !ok {
|
||||
// TODO(e.burkov): Inspect special cases when user want to apply
|
||||
// some rules handled by other processors to the hosts with TLD.
|
||||
// TODO(e.burkov): Inspect special cases when user want to apply some
|
||||
// rules handled by other processors to the hosts with TLD.
|
||||
d.Res = s.genNXDomain(req)
|
||||
|
||||
return resultCodeFinish
|
||||
@@ -333,34 +353,51 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
|
||||
|
||||
ip, err := netutil.IPFromReversedAddr(q.Name)
|
||||
if err != nil {
|
||||
log.Debug("dns: reversed addr: %s", err)
|
||||
log.Debug("dns: parsing reversed addr: %s", err)
|
||||
|
||||
return resultCodeError
|
||||
// DNS-Based Service Discovery uses PTR records having not an ARPA
|
||||
// format of the domain name in question. Those shouldn't be
|
||||
// invalidated. See http://www.dns-sd.org/ServerStaticSetup.html and
|
||||
// RFC 2782.
|
||||
name := strings.TrimSuffix(q.Name, ".")
|
||||
if err = netutil.ValidateSRVDomainName(name); err != nil {
|
||||
log.Debug("dns: validating service domain: %s", err)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
log.Debug("dns: request is for a service domain")
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Restrict an access to local addresses for external clients. We also
|
||||
// assume that all the DHCP leases we give are locally-served or at
|
||||
// least don't need to be inaccessible externally.
|
||||
if s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
if !ctx.isLocalClient {
|
||||
log.Debug("dns: %q requests for internal ip", d.Addr)
|
||||
d.Res = s.genNXDomain(req)
|
||||
// assume that all the DHCP leases we give are locally-served or at least
|
||||
// don't need to be inaccessible externally.
|
||||
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
log.Debug("dns: addr %s is not from locally-served network", ip)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
}
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
if !ctx.isLocalClient {
|
||||
log.Debug("dns: %q requests an internal ip", d.Addr)
|
||||
d.Res = s.genNXDomain(req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
// Do not perform unreversing ever again.
|
||||
ctx.unreversedReqIP = ip
|
||||
|
||||
// Disable redundant filtering.
|
||||
filterSetts := s.getClientRequestFilteringSettings(ctx)
|
||||
filterSetts.ParentalEnabled = false
|
||||
filterSetts.SafeBrowsingEnabled = false
|
||||
filterSetts.SafeSearchEnabled = false
|
||||
filterSetts.ServicesRules = nil
|
||||
ctx.setts = filterSetts
|
||||
// There is no need to filter request from external addresses since this
|
||||
// code is only executed when the request is for locally-served ARPA
|
||||
// hostname so disable redundant filters.
|
||||
ctx.setts.ParentalEnabled = false
|
||||
ctx.setts.SafeBrowsingEnabled = false
|
||||
ctx.setts.SafeSearchEnabled = false
|
||||
ctx.setts.ServicesRules = nil
|
||||
|
||||
// Nothing to restrict.
|
||||
return resultCodeSuccess
|
||||
@@ -468,29 +505,21 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
// Apply filtering logic
|
||||
func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
||||
s := ctx.srv
|
||||
d := ctx.proxyCtx
|
||||
|
||||
if d.Res != nil {
|
||||
return resultCodeSuccess // response is already set - nothing to do
|
||||
func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
||||
if ctx.proxyCtx.Res != nil {
|
||||
// Go on since the response is already set.
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
|
||||
if !ctx.protectionEnabled {
|
||||
if s.dnsFilter == nil {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
if ctx.setts == nil {
|
||||
ctx.setts = s.getClientRequestFilteringSettings(ctx)
|
||||
}
|
||||
|
||||
var err error
|
||||
ctx.result, err = s.filterDNSRequest(ctx)
|
||||
if err != nil {
|
||||
if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
|
||||
ctx.err = err
|
||||
|
||||
return resultCodeError
|
||||
@@ -509,147 +538,105 @@ func ipStringFromAddr(addr net.Addr) (ipStr string) {
|
||||
}
|
||||
|
||||
// processUpstream passes request to upstream servers and handles the response.
|
||||
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
if d.Res != nil {
|
||||
return resultCodeSuccess // response is already set - nothing to do
|
||||
func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||
pctx := dctx.proxyCtx
|
||||
if pctx.Res != nil {
|
||||
// The response has already been set.
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
||||
// Use the clientID first, since it has a higher priority.
|
||||
id := stringutil.Coalesce(ctx.clientID, ipStringFromAddr(d.Addr))
|
||||
if pctx.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
||||
// Use the ClientID first, since it has a higher priority.
|
||||
id := stringutil.Coalesce(dctx.clientID, ipStringFromAddr(pctx.Addr))
|
||||
upsConf, err := s.conf.GetCustomUpstreamByClient(id)
|
||||
if err != nil {
|
||||
log.Error("dns: getting custom upstreams for client %s: %s", id, err)
|
||||
} else if upsConf != nil {
|
||||
log.Debug("dns: using custom upstreams for client %s", id)
|
||||
d.CustomUpstreamConfig = upsConf
|
||||
pctx.CustomUpstreamConfig = upsConf
|
||||
}
|
||||
}
|
||||
|
||||
req := d.Req
|
||||
req := pctx.Req
|
||||
origReqAD := false
|
||||
if s.conf.EnableDNSSEC {
|
||||
opt := req.IsEdns0()
|
||||
if opt == nil {
|
||||
log.Debug("dns: adding OPT record with DNSSEC flag")
|
||||
req.SetEdns0(4096, true)
|
||||
} else if !opt.Do() {
|
||||
opt.SetDo(true)
|
||||
if req.AuthenticatedData {
|
||||
origReqAD = true
|
||||
} else {
|
||||
ctx.origReqDNSSEC = true
|
||||
req.AuthenticatedData = true
|
||||
}
|
||||
}
|
||||
|
||||
// Process the request further since it wasn't filtered.
|
||||
|
||||
prx := s.proxy()
|
||||
if prx == nil {
|
||||
ctx.err = srvClosedErr
|
||||
dctx.err = srvClosedErr
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
if ctx.err = prx.Resolve(d); ctx.err != nil {
|
||||
if dctx.err = prx.Resolve(pctx); dctx.err != nil {
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
ctx.responseFromUpstream = true
|
||||
dctx.responseFromUpstream = true
|
||||
dctx.responseAD = pctx.Res.AuthenticatedData
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Process DNSSEC after response from upstream server
|
||||
func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
|
||||
if !ctx.responseFromUpstream || // don't process response if it's not from upstream servers
|
||||
!ctx.srv.conf.EnableDNSSEC {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
if !ctx.origReqDNSSEC {
|
||||
optResp := d.Res.IsEdns0()
|
||||
if optResp != nil && !optResp.Do() {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Remove RRSIG records from response
|
||||
// because there is no DO flag in the original request from client,
|
||||
// but we have EnableDNSSEC set, so we have set DO flag ourselves,
|
||||
// and now we have to clean up the DNS records our client didn't ask for.
|
||||
|
||||
answers := []dns.RR{}
|
||||
for _, a := range d.Res.Answer {
|
||||
switch a.(type) {
|
||||
case *dns.RRSIG:
|
||||
log.Debug("Removing RRSIG record from response: %v", a)
|
||||
default:
|
||||
answers = append(answers, a)
|
||||
}
|
||||
}
|
||||
d.Res.Answer = answers
|
||||
|
||||
answers = []dns.RR{}
|
||||
for _, a := range d.Res.Ns {
|
||||
switch a.(type) {
|
||||
case *dns.RRSIG:
|
||||
log.Debug("Removing RRSIG record from response: %v", a)
|
||||
default:
|
||||
answers = append(answers, a)
|
||||
}
|
||||
}
|
||||
d.Res.Ns = answers
|
||||
if s.conf.EnableDNSSEC && !origReqAD {
|
||||
pctx.Req.AuthenticatedData = false
|
||||
pctx.Res.AuthenticatedData = false
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Apply filtering logic after we have received response from upstream servers
|
||||
func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
|
||||
s := ctx.srv
|
||||
func (s *Server) processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
res := ctx.result
|
||||
var err error
|
||||
|
||||
switch res.Reason {
|
||||
case filtering.Rewritten,
|
||||
switch res := ctx.result; res.Reason {
|
||||
case filtering.NotFilteredAllowList:
|
||||
// Go on.
|
||||
case
|
||||
filtering.Rewritten,
|
||||
filtering.RewrittenRule:
|
||||
|
||||
if len(ctx.origQuestion.Name) == 0 {
|
||||
// origQuestion is set in case we get only CNAME without IP from rewrites table
|
||||
// origQuestion is set in case we get only CNAME without IP from
|
||||
// rewrites table.
|
||||
break
|
||||
}
|
||||
|
||||
d.Req.Question[0] = ctx.origQuestion
|
||||
d.Res.Question[0] = ctx.origQuestion
|
||||
|
||||
if len(d.Res.Answer) != 0 {
|
||||
answer := []dns.RR{}
|
||||
answer = append(answer, s.genAnswerCNAME(d.Req, res.CanonName))
|
||||
answer = append(answer, d.Res.Answer...)
|
||||
d.Req.Question[0], d.Res.Question[0] = ctx.origQuestion, ctx.origQuestion
|
||||
if len(d.Res.Answer) > 0 {
|
||||
answer := append([]dns.RR{s.genAnswerCNAME(d.Req, res.CanonName)}, d.Res.Answer...)
|
||||
d.Res.Answer = answer
|
||||
}
|
||||
|
||||
case filtering.NotFilteredAllowList:
|
||||
// nothing
|
||||
|
||||
default:
|
||||
if !ctx.protectionEnabled || // filters are disabled: there's nothing to check for
|
||||
!ctx.responseFromUpstream { // only check response if it's from an upstream server
|
||||
// Check the response only if it's from an upstream. Don't check the
|
||||
// response if the protection is disabled since dnsrewrite rules aren't
|
||||
// applied to it anyway.
|
||||
if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil {
|
||||
break
|
||||
}
|
||||
origResp2 := d.Res
|
||||
ctx.result, err = s.filterDNSResponse(ctx)
|
||||
|
||||
origResp := d.Res
|
||||
result, err := s.filterDNSResponse(ctx)
|
||||
if err != nil {
|
||||
ctx.err = err
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
if ctx.result != nil {
|
||||
ctx.origResp = origResp2 // matched by response
|
||||
} else {
|
||||
ctx.result = &filtering.Result{}
|
||||
|
||||
if result != nil {
|
||||
ctx.result = result
|
||||
ctx.origResp = origResp
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.result == nil {
|
||||
ctx.result = &filtering.Result{}
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -261,7 +261,7 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||
ups := &aghtest.TestUpstream{
|
||||
ups := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"251.252.253.254.in-addr.arpa.": {"host1.example.net."},
|
||||
"1.1.168.192.in-addr.arpa.": {"some.local-client."},
|
||||
@@ -339,7 +339,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
}, &aghtest.TestUpstream{
|
||||
}, &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
reqAddr: {locDomain},
|
||||
},
|
||||
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
// DefaultTimeout is the default upstream timeout
|
||||
const DefaultTimeout = 10 * time.Second
|
||||
|
||||
// defaultClientIDCacheCount is the default count of items in the LRU client ID
|
||||
// defaultClientIDCacheCount is the default count of items in the LRU ClientID
|
||||
// cache. The assumption here is that there won't be more than this many
|
||||
// requests between the BeforeRequestHandler stage and the actual processing.
|
||||
const defaultClientIDCacheCount = 1024
|
||||
@@ -79,14 +79,17 @@ type Server struct {
|
||||
sysResolvers aghnet.SystemResolvers
|
||||
recDetector *recursionDetector
|
||||
|
||||
// anonymizer masks the client's IP addresses if needed.
|
||||
anonymizer *aghnet.IPMut
|
||||
|
||||
tableHostToIP hostToIPTable
|
||||
tableHostToIPLock sync.Mutex
|
||||
|
||||
tableIPToHost *netutil.IPMap
|
||||
tableIPToHostLock sync.Mutex
|
||||
|
||||
// clientIDCache is a temporary storage for clientIDs that were
|
||||
// extracted during the BeforeRequestHandler stage.
|
||||
// clientIDCache is a temporary storage for ClientIDs that were extracted
|
||||
// during the BeforeRequestHandler stage.
|
||||
clientIDCache cache.Cache
|
||||
|
||||
// DNS proxy instance for internal usage
|
||||
@@ -113,6 +116,7 @@ type DNSCreateParams struct {
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.ServerInterface
|
||||
SubnetDetector *aghnet.SubnetDetector
|
||||
Anonymizer *aghnet.IPMut
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
@@ -150,6 +154,9 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
localDomainSuffix = domainNameToSuffix(p.LocalDomain)
|
||||
}
|
||||
|
||||
if p.Anonymizer == nil {
|
||||
p.Anonymizer = aghnet.NewIPMut(nil)
|
||||
}
|
||||
s = &Server{
|
||||
dnsFilter: p.DNSFilter,
|
||||
stats: p.Stats,
|
||||
@@ -161,6 +168,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
EnableLRU: true,
|
||||
MaxCount: defaultClientIDCacheCount,
|
||||
}),
|
||||
anonymizer: p.Anonymizer,
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Enable the refresher after the actual implementation
|
||||
@@ -435,7 +443,7 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||
&upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's ceritificates?
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
@@ -551,7 +559,7 @@ func (s *Server) IsRunning() bool {
|
||||
return s.isRunning
|
||||
}
|
||||
|
||||
// srvClosedErr is returned when the method can't complete without unacessible
|
||||
// srvClosedErr is returned when the method can't complete without inaccessible
|
||||
// data from the closing server.
|
||||
const srvClosedErr errors.Error = "server is closed"
|
||||
|
||||
|
||||
@@ -11,9 +11,10 @@ import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -44,10 +46,7 @@ func startDeferStop(t *testing.T, s *Server) {
|
||||
err := s.Start()
|
||||
require.NoErrorf(t, err, "failed to start server: %s", err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
serr := s.Stop()
|
||||
require.NoErrorf(t, serr, "dns server failed to stop: %s", serr)
|
||||
})
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
}
|
||||
|
||||
func createTestServer(
|
||||
@@ -90,7 +89,7 @@ func createTestServer(
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
if localUps != nil {
|
||||
s.localResolvers.Config.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
|
||||
s.localResolvers.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
|
||||
s.conf.UsePrivateRDNS = true
|
||||
}
|
||||
|
||||
@@ -248,7 +247,7 @@ func TestServer(t *testing.T) {
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
}, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
@@ -317,7 +316,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
}, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
@@ -340,7 +339,7 @@ func TestDoTServer(t *testing.T) {
|
||||
TLSListenAddrs: []*net.TCPAddr{{}},
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
@@ -370,7 +369,7 @@ func TestDoQServer(t *testing.T) {
|
||||
QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}},
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
@@ -414,7 +413,7 @@ func TestServerRace(t *testing.T) {
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
|
||||
},
|
||||
@@ -553,7 +552,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
|
||||
ups := &aghtest.TestUpstream{
|
||||
ups := &aghtest.Upstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"host.": {{192, 168, 0, 1}},
|
||||
},
|
||||
@@ -581,9 +580,9 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
|
||||
var testCNAMEs = map[string]string{
|
||||
"badhost.": "NULL.example.org.",
|
||||
"whitelist.example.org.": "NULL.example.org.",
|
||||
var testCNAMEs = map[string][]string{
|
||||
"badhost.": {"NULL.example.org."},
|
||||
"whitelist.example.org.": {"NULL.example.org."},
|
||||
}
|
||||
|
||||
// testIPv4 is a map of names and IPv4s necessary for the TestUpstream work.
|
||||
@@ -597,7 +596,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
}, nil)
|
||||
testUpstm := &aghtest.TestUpstream{
|
||||
testUpstm := &aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
IPv4: testIPv4,
|
||||
IPv6: nil,
|
||||
@@ -631,7 +630,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
IPv4: testIPv4,
|
||||
},
|
||||
@@ -641,14 +640,17 @@ func TestBlockCNAME(t *testing.T) {
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want bool
|
||||
}{{
|
||||
name: "block_request",
|
||||
host: "badhost.",
|
||||
// 'badhost' has a canonical name 'NULL.example.org' which is
|
||||
// blocked by filters: response is blocked.
|
||||
want: true,
|
||||
}, {
|
||||
name: "allowed",
|
||||
host: "whitelist.example.org.",
|
||||
// 'whitelist.example.org' has a canonical name
|
||||
// 'NULL.example.org' which is blocked by filters
|
||||
@@ -656,6 +658,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||
// response isn't blocked.
|
||||
want: false,
|
||||
}, {
|
||||
name: "block_response",
|
||||
host: "example.org.",
|
||||
// 'example.org' has a canonical name 'cname1' with IP
|
||||
// 127.0.0.255 which is blocked by filters: response is blocked.
|
||||
@@ -663,9 +666,9 @@ func TestBlockCNAME(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run("block_cname_"+tc.host, func(t *testing.T) {
|
||||
req := createTestMessage(tc.host)
|
||||
req := createTestMessage(tc.host)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
reply, err := dns.Exchange(req, addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -675,7 +678,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||
|
||||
ans := reply.Answer[0]
|
||||
a, ok := ans.(*dns.A)
|
||||
require.Truef(t, ok, "got %T", ans)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.True(t, a.A.IsUnspecified())
|
||||
}
|
||||
@@ -696,7 +699,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
&aghtest.Upstream{
|
||||
CName: testCNAMEs,
|
||||
IPv4: testIPv4,
|
||||
},
|
||||
@@ -893,7 +896,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
|
||||
func TestRewrite(t *testing.T) {
|
||||
c := &filtering.Config{
|
||||
Rewrites: []filtering.RewriteEntry{{
|
||||
Rewrites: []*filtering.LegacyRewrite{{
|
||||
Domain: "test.com",
|
||||
Answer: "1.2.3.4",
|
||||
Type: dns.TypeA,
|
||||
@@ -908,6 +911,7 @@ func TestRewrite(t *testing.T) {
|
||||
}},
|
||||
}
|
||||
f := filtering.New(c, nil)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
@@ -931,9 +935,9 @@ func TestRewrite(t *testing.T) {
|
||||
}))
|
||||
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
CName: map[string]string{
|
||||
"example.org": "somename",
|
||||
&aghtest.Upstream{
|
||||
CName: map[string][]string{
|
||||
"example.org": {"somename"},
|
||||
},
|
||||
IPv4: map[string][]net.IP{
|
||||
"example.org.": {{4, 3, 2, 1}},
|
||||
@@ -944,45 +948,56 @@ func TestRewrite(t *testing.T) {
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
req := createTestMessageWithType("test.com.", dns.TypeA)
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
require.NoError(t, err)
|
||||
subTestFunc := func(t *testing.T) {
|
||||
req := createTestMessageWithType("test.com.", dns.TypeA)
|
||||
reply, eerr := dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
require.Len(t, reply.Answer, 1)
|
||||
require.Len(t, reply.Answer, 1)
|
||||
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
require.True(t, ok)
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
|
||||
|
||||
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, err)
|
||||
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
assert.Empty(t, reply.Answer)
|
||||
|
||||
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, err)
|
||||
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
require.Len(t, reply.Answer, 2)
|
||||
require.Len(t, reply.Answer, 2)
|
||||
|
||||
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
|
||||
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
|
||||
|
||||
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, err)
|
||||
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
// The original question is restored.
|
||||
require.Len(t, reply.Question, 1)
|
||||
// The original question is restored.
|
||||
require.Len(t, reply.Question, 1)
|
||||
|
||||
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
|
||||
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
|
||||
|
||||
require.Len(t, reply.Answer, 2)
|
||||
require.Len(t, reply.Answer, 2)
|
||||
|
||||
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
|
||||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
|
||||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||
}
|
||||
|
||||
for _, protect := range []bool{true, false} {
|
||||
val := protect
|
||||
conf := s.getDNSConfig()
|
||||
conf.ProtectionEnabled = &val
|
||||
s.setConfig(conf)
|
||||
|
||||
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func publicKey(priv interface{}) interface{} {
|
||||
@@ -1036,9 +1051,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
s.Close()
|
||||
})
|
||||
t.Cleanup(s.Close)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR)
|
||||
@@ -1057,23 +1070,40 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPTRResponseFromHosts(t *testing.T) {
|
||||
c := filtering.Config{
|
||||
EtcHosts: &aghnet.EtcHostsContainer{},
|
||||
// Prepare test hosts file.
|
||||
|
||||
const hostsFilename = "hosts"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
hostsFilename: &fstest.MapFile{Data: []byte(`
|
||||
127.0.0.1 host # comment
|
||||
::1 localhost#comment
|
||||
`)},
|
||||
}
|
||||
|
||||
// Prepare test hosts file.
|
||||
hf, err := os.CreateTemp("", "")
|
||||
var eventsCalledCounter uint32
|
||||
hc, err := aghnet.NewHostsContainer(0, testFS, &aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) {
|
||||
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
||||
|
||||
return nil
|
||||
},
|
||||
OnAdd: func(name string) (err error) {
|
||||
assert.Equal(t, hostsFilename, name)
|
||||
|
||||
return nil
|
||||
},
|
||||
OnClose: func() (err error) { panic("not implemented") },
|
||||
}, hostsFilename)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, hf.Close())
|
||||
assert.NoError(t, os.Remove(hf.Name()))
|
||||
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
|
||||
})
|
||||
|
||||
_, _ = hf.WriteString(" 127.0.0.1 host # comment \n")
|
||||
_, _ = hf.WriteString(" ::1 localhost#comment \n")
|
||||
|
||||
c.EtcHosts.Init(hf.Name())
|
||||
t.Cleanup(c.EtcHosts.Close)
|
||||
flt := filtering.New(&filtering.Config{
|
||||
EtcHosts: hc,
|
||||
}, nil)
|
||||
flt.SetEnabled(true)
|
||||
|
||||
var snd *aghnet.SubnetDetector
|
||||
snd, err = aghnet.NewSubnetDetector()
|
||||
@@ -1083,7 +1113,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: filtering.New(&c, nil),
|
||||
DNSFilter: flt,
|
||||
SubnetDetector: snd,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -1091,32 +1121,39 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
|
||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.FilteringConfig.ProtectionEnabled = true
|
||||
|
||||
err = s.Prepare(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(s.Close)
|
||||
|
||||
t.Cleanup(func() {
|
||||
s.Close()
|
||||
})
|
||||
subTestFunc := func(t *testing.T) {
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
|
||||
resp, eerr := dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
resp, err := dns.Exchange(req, addr.String())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
|
||||
require.Len(t, resp.Answer, 1)
|
||||
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
||||
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
|
||||
|
||||
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
||||
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
|
||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "host.", ptr.Ptr)
|
||||
}
|
||||
|
||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "host.", ptr.Ptr)
|
||||
for _, protect := range []bool{true, false} {
|
||||
val := protect
|
||||
conf := s.getDNSConfig()
|
||||
conf.ProtectionEnabled = &val
|
||||
s.setConfig(conf)
|
||||
|
||||
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
@@ -1154,23 +1191,18 @@ func TestNewServer(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewServer(tc.in)
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Exchange(t *testing.T) {
|
||||
extUpstream := &aghtest.TestUpstream{
|
||||
extUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
|
||||
},
|
||||
}
|
||||
locUpstream := &aghtest.TestUpstream{
|
||||
locUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.168.192.in-addr.arpa.": {"local.domain"},
|
||||
"2.1.168.192.in-addr.arpa.": {},
|
||||
|
||||
@@ -61,8 +61,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeNameError, d.Res.Rcode)
|
||||
})
|
||||
|
||||
@@ -72,7 +72,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
assert.Empty(t, d.Res.Answer)
|
||||
})
|
||||
@@ -83,7 +84,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
@@ -96,7 +98,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
@@ -109,7 +112,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
@@ -122,7 +126,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
@@ -135,7 +140,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
@@ -152,7 +158,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
@@ -171,7 +178,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
@@ -190,7 +198,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
|
||||
@@ -52,6 +52,7 @@ func (s *Server) beforeRequestHandler(
|
||||
// the client's IP address and ID, if any, from ctx.
|
||||
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
|
||||
setts := s.dnsFilter.GetConfig()
|
||||
setts.ProtectionEnabled = ctx.protectionEnabled
|
||||
if s.conf.FilterHandler != nil {
|
||||
ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr)
|
||||
s.conf.FilterHandler(ip, ctx.clientID, &setts)
|
||||
@@ -65,42 +66,23 @@ func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.S
|
||||
func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
|
||||
d := ctx.proxyCtx
|
||||
req := d.Req
|
||||
host := strings.TrimSuffix(req.Question[0].Name, ".")
|
||||
res, err := s.dnsFilter.CheckHost(host, req.Question[0].Qtype, ctx.setts)
|
||||
if err != nil {
|
||||
// Return immediately if there's an error
|
||||
return nil, fmt.Errorf("filtering failed to check host %q: %w", host, err)
|
||||
} else if res.IsFiltered {
|
||||
log.Tracef("Host %s is filtered, reason - %q, matched rule: %q", host, res.Reason, res.Rules[0].Text)
|
||||
q := req.Question[0]
|
||||
host := strings.TrimSuffix(q.Name, ".")
|
||||
res, err := s.dnsFilter.CheckHost(host, q.Qtype, ctx.setts)
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, fmt.Errorf("failed to check host %q: %w", host, err)
|
||||
case res.IsFiltered:
|
||||
log.Tracef("host %q is filtered, reason %q, rule: %q", host, res.Reason, res.Rules[0].Text)
|
||||
d.Res = s.genDNSFilterMessage(d, &res)
|
||||
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
|
||||
case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
|
||||
res.CanonName != "" &&
|
||||
len(res.IPList) == 0 {
|
||||
// Resolve the new canonical name, not the original host
|
||||
// name. The original question is readded in
|
||||
// processFilteringAfterResponse.
|
||||
ctx.origQuestion = req.Question[0]
|
||||
len(res.IPList) == 0:
|
||||
// Resolve the new canonical name, not the original host name. The
|
||||
// original question is readded in processFilteringAfterResponse.
|
||||
ctx.origQuestion = q
|
||||
req.Question[0].Name = dns.Fqdn(res.CanonName)
|
||||
} else if res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0 {
|
||||
resp := s.makeResponse(req)
|
||||
for _, h := range res.ReverseHosts {
|
||||
hdr := dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
Ttl: s.conf.BlockedResponseTTL,
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
|
||||
ptr := &dns.PTR{
|
||||
Hdr: hdr,
|
||||
Ptr: h,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ptr)
|
||||
}
|
||||
|
||||
d.Res = resp
|
||||
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts) {
|
||||
case res.Reason == filtering.Rewritten:
|
||||
resp := s.makeResponse(req)
|
||||
|
||||
name := host
|
||||
@@ -110,11 +92,12 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
|
||||
}
|
||||
|
||||
for _, ip := range res.IPList {
|
||||
if req.Question[0].Qtype == dns.TypeA {
|
||||
switch q.Qtype {
|
||||
case dns.TypeA:
|
||||
a := s.genAnswerA(req, ip.To4())
|
||||
a.Hdr.Name = dns.Fqdn(name)
|
||||
resp.Answer = append(resp.Answer, a)
|
||||
} else if req.Question[0].Qtype == dns.TypeAAAA {
|
||||
case dns.TypeAAAA:
|
||||
a := s.genAnswerAAAA(req, ip)
|
||||
a.Hdr.Name = dns.Fqdn(name)
|
||||
resp.Answer = append(resp.Answer, a)
|
||||
@@ -122,9 +105,8 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
|
||||
}
|
||||
|
||||
d.Res = resp
|
||||
} else if res.Reason == filtering.RewrittenRule {
|
||||
err = s.filterDNSRewrite(req, res, d)
|
||||
if err != nil {
|
||||
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
|
||||
if err = s.filterDNSRewrite(req, res, d); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -134,7 +116,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
|
||||
|
||||
// checkHostRules checks the host against filters. It is safe for concurrent
|
||||
// use.
|
||||
func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Settings) (
|
||||
func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) (
|
||||
r *filtering.Result,
|
||||
err error,
|
||||
) {
|
||||
@@ -146,7 +128,7 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett
|
||||
}
|
||||
|
||||
var res filtering.Result
|
||||
res, err = s.dnsFilter.CheckHostRules(host, qtype, setts)
|
||||
res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -154,32 +136,36 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett
|
||||
return &res, err
|
||||
}
|
||||
|
||||
// If response contains CNAME, A or AAAA records, we apply filtering to each
|
||||
// canonical host name or IP address. If this is a match, we set a new response
|
||||
// in d.Res and return.
|
||||
func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) {
|
||||
// filterDNSResponse checks each resource record of the response's answer
|
||||
// section from ctx and returns a non-nil res if at least one of canonnical
|
||||
// names or IP addresses in it matches the filtering rules.
|
||||
func (s *Server) filterDNSResponse(ctx *dnsContext) (res *filtering.Result, err error) {
|
||||
d := ctx.proxyCtx
|
||||
setts := ctx.setts
|
||||
if !setts.FilteringEnabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, a := range d.Res.Answer {
|
||||
host := ""
|
||||
|
||||
switch v := a.(type) {
|
||||
var rrtype uint16
|
||||
switch a := a.(type) {
|
||||
case *dns.CNAME:
|
||||
log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name)
|
||||
host = strings.TrimSuffix(v.Target, ".")
|
||||
|
||||
host = strings.TrimSuffix(a.Target, ".")
|
||||
rrtype = dns.TypeCNAME
|
||||
case *dns.A:
|
||||
host = v.A.String()
|
||||
log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name)
|
||||
|
||||
host = a.A.String()
|
||||
rrtype = dns.TypeA
|
||||
case *dns.AAAA:
|
||||
host = v.AAAA.String()
|
||||
log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name)
|
||||
|
||||
host = a.AAAA.String()
|
||||
rrtype = dns.TypeAAAA
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
|
||||
log.Debug("dnsforward: checking %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
|
||||
|
||||
res, err = s.checkHostRules(host, rrtype, setts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if res == nil {
|
||||
|
||||
159
internal/dnsforward/filter_test.go
Normal file
159
internal/dnsforward/filter_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
rules := `
|
||||
||blocked.domain^
|
||||
@@||allowed.domain^
|
||||
||cname.specific^$dnstype=~CNAME
|
||||
||0.0.0.1^$dnstype=~A
|
||||
||::1^$dnstype=~AAAA
|
||||
`
|
||||
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: BlockingModeDefault,
|
||||
},
|
||||
}
|
||||
filters := []filtering.Filter{{
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf = forwardConf
|
||||
err = s.Prepare(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
CName: map[string][]string{
|
||||
"cname.exception.": {"cname.specific."},
|
||||
"should.block.": {"blocked.domain."},
|
||||
"allowed.first.": {"allowed.domain.", "blocked.domain."},
|
||||
"blocked.first.": {"blocked.domain.", "allowed.domain."},
|
||||
},
|
||||
IPv4: map[string][]net.IP{
|
||||
"a.exception.": {{0, 0, 0, 1}},
|
||||
},
|
||||
IPv6: map[string][]net.IP{
|
||||
"aaaa.exception.": {net.ParseIP("::1")},
|
||||
},
|
||||
},
|
||||
}
|
||||
startDeferStop(t, s)
|
||||
|
||||
testCases := []struct {
|
||||
req *dns.Msg
|
||||
name string
|
||||
wantAns []dns.RR
|
||||
}{{
|
||||
req: createTestMessage("cname.exception."),
|
||||
name: "cname_exception",
|
||||
wantAns: []dns.RR{&dns.CNAME{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "cname.exception.",
|
||||
Rrtype: dns.TypeCNAME,
|
||||
},
|
||||
Target: "cname.specific.",
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("should.block."),
|
||||
name: "blocked_by_cname",
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "should.block.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("a.exception."),
|
||||
name: "a_exception",
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "a.exception.",
|
||||
Rrtype: dns.TypeA,
|
||||
},
|
||||
A: net.IP{0, 0, 0, 1},
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA),
|
||||
name: "aaaa_exception",
|
||||
wantAns: []dns.RR{&dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "aaaa.exception.",
|
||||
Rrtype: dns.TypeAAAA,
|
||||
},
|
||||
AAAA: net.ParseIP("::1"),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("allowed.first."),
|
||||
name: "allowed_first",
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "allowed.first.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("blocked.first."),
|
||||
name: "blocked_first",
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "blocked.first.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
dctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoUDP,
|
||||
Req: tc.req,
|
||||
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.handleDNSRequest(nil, dctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, dctx.Res)
|
||||
|
||||
assert.Equal(t, tc.wantAns, dctx.Res.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,10 +5,12 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -18,12 +20,6 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Info("dns: %s %s: %s", r.Method, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
type dnsConfig struct {
|
||||
Upstreams *[]string `json:"upstream_dns"`
|
||||
UpstreamsFile *string `json:"upstream_dns_file"`
|
||||
@@ -47,7 +43,7 @@ type dnsConfig struct {
|
||||
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
|
||||
}
|
||||
|
||||
func (s *Server) getDNSConfig() dnsConfig {
|
||||
func (s *Server) getDNSConfig() (c *dnsConfig) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
@@ -76,7 +72,7 @@ func (s *Server) getDNSConfig() dnsConfig {
|
||||
upstreamMode = "parallel"
|
||||
}
|
||||
|
||||
return dnsConfig{
|
||||
return &dnsConfig{
|
||||
Upstreams: &upstreams,
|
||||
UpstreamsFile: &upstreamFile,
|
||||
Bootstraps: &bootstraps,
|
||||
@@ -112,14 +108,15 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
// since there is no need to omit it while decoding from JSON.
|
||||
DefautLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"`
|
||||
}{
|
||||
dnsConfig: s.getDNSConfig(),
|
||||
dnsConfig: *s.getDNSConfig(),
|
||||
DefautLocalPTRUpstreams: defLocalPTRUps,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err = json.NewEncoder(w).Encode(resp); err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "json.Encoder: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encoder: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -143,39 +140,63 @@ func (req *dnsConfig) checkBlockingMode() bool {
|
||||
}
|
||||
|
||||
func (req *dnsConfig) checkUpstreamsMode() bool {
|
||||
if req.UpstreamMode == nil {
|
||||
return true
|
||||
}
|
||||
valid := []string{"", "fastest_addr", "parallel"}
|
||||
|
||||
for _, valid := range []string{
|
||||
"",
|
||||
"fastest_addr",
|
||||
"parallel",
|
||||
} {
|
||||
if *req.UpstreamMode == valid {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return req.UpstreamMode == nil || stringutil.InSlice(valid, *req.UpstreamMode)
|
||||
}
|
||||
|
||||
func (req *dnsConfig) checkBootstrap() (string, error) {
|
||||
func (req *dnsConfig) checkBootstrap() (err error) {
|
||||
if req.Bootstraps == nil {
|
||||
return "", nil
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, boot := range *req.Bootstraps {
|
||||
if boot == "" {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: empty")
|
||||
var b string
|
||||
defer func() { err = errors.Annotate(err, "checking bootstrap %s: invalid address: %w", b) }()
|
||||
|
||||
for _, b = range *req.Bootstraps {
|
||||
if b == "" {
|
||||
return errors.Error("empty")
|
||||
}
|
||||
|
||||
if _, err := upstream.NewResolver(boot, nil); err != nil {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
|
||||
if _, err = upstream.NewResolver(b, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate returns an error if any field of req is invalid.
|
||||
func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
|
||||
if req.Upstreams != nil {
|
||||
err = ValidateUpstreams(*req.Upstreams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if req.LocalPTRUpstreams != nil {
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, snd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating private upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = req.checkBootstrap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case !req.checkBlockingMode():
|
||||
return errors.Error("blocking_mode: incorrect value")
|
||||
case !req.checkUpstreamsMode():
|
||||
return errors.Error("upstream_mode: incorrect value")
|
||||
case !req.checkCacheTTL():
|
||||
return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (req *dnsConfig) checkCacheTTL() bool {
|
||||
@@ -195,37 +216,18 @@ func (req *dnsConfig) checkCacheTTL() bool {
|
||||
}
|
||||
|
||||
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
req := dnsConfig{}
|
||||
dec := json.NewDecoder(r.Body)
|
||||
if err := dec.Decode(&req); err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json Encode: %s", err)
|
||||
req := &dnsConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if req.Upstreams != nil {
|
||||
if err := ValidateUpstreams(*req.Upstreams); err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
err = req.validate(s.subnetDetector)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
if errBoot, err := req.checkBootstrap(); err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", errBoot, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !req.checkBlockingMode() {
|
||||
httpError(r, w, http.StatusBadRequest, "blocking_mode: incorrect value")
|
||||
return
|
||||
}
|
||||
|
||||
if !req.checkUpstreamsMode() {
|
||||
httpError(r, w, http.StatusBadRequest, "upstream_mode: incorrect value")
|
||||
return
|
||||
}
|
||||
|
||||
if !req.checkCacheTTL() {
|
||||
httpError(r, w, http.StatusBadRequest, "cache_ttl_min must be less or equal than cache_ttl_max")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -233,14 +235,14 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
s.conf.ConfigModified()
|
||||
|
||||
if restart {
|
||||
if err := s.Reconfigure(nil); err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
err = s.Reconfigure(nil)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
|
||||
func (s *Server) setConfigRestartable(dc *dnsConfig) (restart bool) {
|
||||
if dc.Upstreams != nil {
|
||||
s.conf.UpstreamDNS = *dc.Upstreams
|
||||
restart = true
|
||||
@@ -261,9 +263,9 @@ func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
|
||||
restart = true
|
||||
}
|
||||
|
||||
if dc.RateLimit != nil {
|
||||
restart = restart || s.conf.Ratelimit != *dc.RateLimit
|
||||
if dc.RateLimit != nil && s.conf.Ratelimit != *dc.RateLimit {
|
||||
s.conf.Ratelimit = *dc.RateLimit
|
||||
restart = true
|
||||
}
|
||||
|
||||
if dc.EDNSCSEnabled != nil {
|
||||
@@ -294,7 +296,7 @@ func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
|
||||
return restart
|
||||
}
|
||||
|
||||
func (s *Server) setConfig(dc dnsConfig) (restart bool) {
|
||||
func (s *Server) setConfig(dc *dnsConfig) (restart bool) {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
@@ -341,103 +343,180 @@ type upstreamJSON struct {
|
||||
PrivateUpstreams []string `json:"private_upstream"`
|
||||
}
|
||||
|
||||
// IsCommentOrEmpty returns true of the string starts with a "#" character or is
|
||||
// an empty string. This function is useful for filtering out non-upstream
|
||||
// lines from upstream configs.
|
||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||
// This function is useful for filtering out non-upstream lines from upstream
|
||||
// configs.
|
||||
func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// LocalNetChecker is used to check if the IP address belongs to a local
|
||||
// network.
|
||||
type LocalNetChecker interface {
|
||||
// IsLocallyServedNetwork returns true if ip is contained in any of address
|
||||
// registries defined by RFC 6303.
|
||||
IsLocallyServedNetwork(ip net.IP) (ok bool)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ LocalNetChecker = (*aghnet.SubnetDetector)(nil)
|
||||
|
||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
||||
// configuration or nil if it can't be built.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
|
||||
// slice already so that this function may be considered useless.
|
||||
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
|
||||
// No need to validate comments and empty lines.
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
// Consider this case valid since it means the default server should be
|
||||
// used.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{Bootstrap: []string{}, Timeout: DefaultTimeout},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(conf.Upstreams) == 0 {
|
||||
return nil, errors.Error("no default upstreams specified")
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
_, err = validateUpstream(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// ValidateUpstreams validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified.
|
||||
//
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
func ValidateUpstreams(upstreams []string) (err error) {
|
||||
// No need to validate comments
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
_, err = newUpstreamConfig(upstreams)
|
||||
|
||||
// Consider this case valid because defaultDNS will be used
|
||||
if len(upstreams) == 0 {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// stringKeysSorted returns the sorted slice of string keys of m.
|
||||
//
|
||||
// TODO(e.burkov): Use generics in Go 1.18. Move into golibs.
|
||||
func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
|
||||
sorted = make([]string, 0, len(m))
|
||||
for s := range m {
|
||||
sorted = append(sorted, s)
|
||||
}
|
||||
|
||||
_, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: []string{},
|
||||
Timeout: DefaultTimeout,
|
||||
},
|
||||
)
|
||||
sort.Strings(sorted)
|
||||
|
||||
return sorted
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. lnc must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err error) {
|
||||
conf, err := newUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var defaultUpstreamFound bool
|
||||
for _, u := range upstreams {
|
||||
var ok bool
|
||||
ok, err = validateUpstream(u)
|
||||
if conf == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
|
||||
for _, domain := range stringKeysSorted(conf.DomainReservedUpstreams) {
|
||||
var subnet *net.IPNet
|
||||
subnet, err = netutil.SubnetFromReversedAddr(domain)
|
||||
if err != nil {
|
||||
return err
|
||||
errs = append(errs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !defaultUpstreamFound {
|
||||
defaultUpstreamFound = ok
|
||||
if !lnc.IsLocallyServedNetwork(subnet.IP) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if !defaultUpstreamFound {
|
||||
return fmt.Errorf("no default upstreams specified")
|
||||
if len(errs) > 0 {
|
||||
return errors.List("checking domain-specific upstreams", errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var protocols = []string{"tls://", "https://", "tcp://", "sdns://", "quic://"}
|
||||
var protocols = []string{"udp://", "tcp://", "tls://", "https://", "sdns://", "quic://"}
|
||||
|
||||
func validateUpstream(u string) (bool, error) {
|
||||
func validateUpstream(u string) (useDefault bool, err error) {
|
||||
// Check if the user tries to specify upstream for domain.
|
||||
u, useDefault, err := separateUpstream(u)
|
||||
var isDomainSpec bool
|
||||
u, isDomainSpec, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
return useDefault, err
|
||||
return !isDomainSpec, err
|
||||
}
|
||||
|
||||
// The special server address '#' means "use the default servers"
|
||||
if u == "#" && !useDefault {
|
||||
// The special server address '#' means that default server must be used.
|
||||
if useDefault = !isDomainSpec; u == "#" && isDomainSpec {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
// Check if the upstream has a valid protocol prefix
|
||||
// Check if the upstream has a valid protocol prefix.
|
||||
//
|
||||
// TODO(e.burkov): Validate the domain name.
|
||||
for _, proto := range protocols {
|
||||
if strings.HasPrefix(u, proto) {
|
||||
return useDefault, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Return error if the upstream contains '://' without any valid protocol
|
||||
if strings.Contains(u, "://") {
|
||||
return useDefault, fmt.Errorf("wrong protocol")
|
||||
return useDefault, errors.Error("wrong protocol")
|
||||
}
|
||||
|
||||
// Check if upstream is valid plain DNS
|
||||
return useDefault, checkPlainDNS(u)
|
||||
// Check if upstream is either an IP or IP with port.
|
||||
if net.ParseIP(u) != nil {
|
||||
return useDefault, nil
|
||||
} else if _, err = netutil.ParseIPPort(u); err != nil {
|
||||
return useDefault, err
|
||||
}
|
||||
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
// separateUpstream returns the upstream without the specified domains.
|
||||
// useDefault is true when a default upstream must be used.
|
||||
func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) {
|
||||
defer func() { err = errors.Annotate(err, "bad upstream for domain spec %q: %w", upstreamStr) }()
|
||||
|
||||
// isDomainSpec is true when the upstream is domains-specific.
|
||||
func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, err error) {
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return upstreamStr, true, nil
|
||||
return upstreamStr, false, nil
|
||||
}
|
||||
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
|
||||
|
||||
parts := strings.Split(upstreamStr[2:], "/]")
|
||||
if len(parts) != 2 {
|
||||
return "", false, errors.Error("duplicated separator")
|
||||
switch len(parts) {
|
||||
case 2:
|
||||
// Go on.
|
||||
case 1:
|
||||
return "", false, errors.Error("missing separator")
|
||||
default:
|
||||
return "", true, errors.Error("duplicated separator")
|
||||
}
|
||||
|
||||
domains := parts[0]
|
||||
upstream = parts[1]
|
||||
var domains string
|
||||
domains, upstream = parts[0], parts[1]
|
||||
for i, host := range strings.Split(domains, "/") {
|
||||
if host == "" {
|
||||
continue
|
||||
@@ -445,36 +524,11 @@ func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err
|
||||
|
||||
err = netutil.ValidateDomainName(host)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
return "", true, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return upstream, false, nil
|
||||
}
|
||||
|
||||
// checkPlainDNS checks if host is plain DNS
|
||||
func checkPlainDNS(upstream string) error {
|
||||
// Check if host is ip without port
|
||||
if net.ParseIP(upstream) != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if host is ip with port
|
||||
ip, port, err := net.SplitHostPort(upstream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if net.ParseIP(ip) == nil {
|
||||
return fmt.Errorf("%s is not a valid IP", ip)
|
||||
}
|
||||
|
||||
_, err = strconv.ParseInt(port, 0, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s is not a valid port: %w", port, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return upstream, true, nil
|
||||
}
|
||||
|
||||
// excFunc is a signature of function to check if upstream exchanges correctly.
|
||||
@@ -502,12 +556,8 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
|
||||
if len(reply.Answer) != 1 {
|
||||
return fmt.Errorf("wrong response")
|
||||
}
|
||||
|
||||
if t, ok := reply.Answer[0].(*dns.A); ok {
|
||||
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
|
||||
return fmt.Errorf("wrong response")
|
||||
}
|
||||
} else if a, ok := reply.Answer[0].(*dns.A); !ok || !a.A.Equal(net.IP{8, 8, 8, 8}) {
|
||||
return fmt.Errorf("wrong response")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -542,7 +592,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
|
||||
|
||||
// Separate upstream from domains list.
|
||||
var useDefault bool
|
||||
if input, useDefault, err = separateUpstream(input); err != nil {
|
||||
if useDefault, err = validateUpstream(input); err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
@@ -551,7 +601,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err = validateUpstream(input); err != nil {
|
||||
if input, _, err = separateUpstream(input); err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
@@ -559,7 +609,8 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
|
||||
bootstrap = defaultBootstrap
|
||||
}
|
||||
|
||||
log.Debug("checking if dns server %q works...", input)
|
||||
log.Debug("checking if upstream %s works", input)
|
||||
|
||||
var u upstream.Upstream
|
||||
u, err = upstream.AddressToUpstream(input, &upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
@@ -573,7 +624,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
|
||||
return fmt.Errorf("upstream %q fails to exchange: %w", input, err)
|
||||
}
|
||||
|
||||
log.Debug("dns %s works OK", input)
|
||||
log.Debug("upstream %s is ok", input)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -582,7 +633,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
req := &upstreamJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -607,9 +658,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
err = checkDNS(host, bootstraps, timeout, checkPrivateUpstreamExc)
|
||||
if err != nil {
|
||||
log.Info("%v", err)
|
||||
// TODO(e.burkov): If passed upstream have already
|
||||
// written an error above, we rewriting the error for
|
||||
// it. These cases should be handled properly instead.
|
||||
// TODO(e.burkov): If passed upstream have already written an error
|
||||
// above, we rewriting the error for it. These cases should be
|
||||
// handled properly instead.
|
||||
result[host] = err.Error()
|
||||
|
||||
continue
|
||||
@@ -620,7 +671,13 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
jsonVal, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Unable to marshal status json: %s",
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -628,9 +685,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
|
||||
return
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -641,12 +696,12 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
// -> dnsforward.handleDNSRequest
|
||||
func (s *Server) handleDoH(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.conf.TLSAllowUnencryptedDoH && r.TLS == nil {
|
||||
httpError(r, w, http.StatusNotFound, "Not Found")
|
||||
aghhttp.Error(r, w, http.StatusNotFound, "Not Found")
|
||||
return
|
||||
}
|
||||
|
||||
if !s.IsRunning() {
|
||||
httpError(r, w, http.StatusInternalServerError, "dns server is not running")
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "dns server is not running")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -38,9 +39,7 @@ func loadTestData(t *testing.T, casesFileName string, cases interface{}) {
|
||||
var f *os.File
|
||||
f, err := os.Open(filepath.Join("testdata", casesFileName))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, f.Close())
|
||||
})
|
||||
testutil.CleanupAndRequireSuccess(t, f.Close)
|
||||
|
||||
err = json.NewDecoder(f).Decode(cases)
|
||||
require.NoError(t, err)
|
||||
@@ -69,10 +68,8 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s.sysResolvers = &fakeSystemResolvers{}
|
||||
|
||||
require.Nil(t, s.Start())
|
||||
t.Cleanup(func() {
|
||||
require.Nil(t, s.Stop())
|
||||
})
|
||||
require.NoError(t, s.Start())
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
|
||||
defaultConf := s.conf
|
||||
|
||||
@@ -147,10 +144,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
defaultConf := s.conf
|
||||
|
||||
err := s.Start()
|
||||
assert.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.Nil(t, s.Stop())
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -189,12 +184,11 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "upstream_dns_bad",
|
||||
wantSet: `wrong upstreams specification: address !!!: ` +
|
||||
`missing port in address`,
|
||||
wantSet: `validating upstream servers: bad ipport address "!!!": ` +
|
||||
`address !!!: missing port in address`,
|
||||
}, {
|
||||
name: "bootstraps_bad",
|
||||
wantSet: `a can not be used as bootstrap dns cause: ` +
|
||||
`invalid bootstrap server address: ` +
|
||||
wantSet: `checking bootstrap a: invalid address: ` +
|
||||
`Resolver a is not eligible to be a bootstrap DNS server`,
|
||||
}, {
|
||||
name: "cache_bad_ttl",
|
||||
@@ -205,6 +199,10 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
}, {
|
||||
name: "local_ptr_upstreams_good",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "local_ptr_upstreams_bad",
|
||||
wantSet: `validating private upstream servers: checking domain-specific upstreams: ` +
|
||||
`bad arpa domain name "non.arpa": not a reversed ip network`,
|
||||
}, {
|
||||
name: "local_ptr_upstreams_null",
|
||||
wantSet: "",
|
||||
@@ -221,14 +219,12 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
s.conf = defaultConf
|
||||
})
|
||||
t.Cleanup(func() { s.conf = defaultConf })
|
||||
|
||||
rBody := io.NopCloser(bytes.NewReader(caseData.Req))
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "http://example.com", rBody)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleSetConfig(w, r)
|
||||
assert.Equal(t, tc.wantSet, strings.TrimSuffix(w.Body.String(), "\n"))
|
||||
@@ -242,130 +238,145 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIsCommentOrEmpty(t *testing.T) {
|
||||
assert.True(t, IsCommentOrEmpty(""))
|
||||
assert.True(t, IsCommentOrEmpty("# comment"))
|
||||
assert.False(t, IsCommentOrEmpty("1.2.3.4"))
|
||||
for _, tc := range []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
str string
|
||||
}{{
|
||||
want: assert.True,
|
||||
str: "",
|
||||
}, {
|
||||
want: assert.True,
|
||||
str: "# comment",
|
||||
}, {
|
||||
want: assert.False,
|
||||
str: "1.2.3.4",
|
||||
}} {
|
||||
tc.want(t, IsCommentOrEmpty(tc.str))
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Rewrite to check the actual error messages.
|
||||
func TestValidateUpstream(t *testing.T) {
|
||||
testCases := []struct {
|
||||
wantDef assert.BoolAssertionFunc
|
||||
name string
|
||||
upstream string
|
||||
valid bool
|
||||
wantDef bool
|
||||
wantErr string
|
||||
}{{
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "1.2.3.4.5",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
wantErr: `bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "123.3.7m",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
wantErr: `bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "htttps://google.com/dns-query",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
wantErr: `wrong protocol`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "[/host.com]tls://dns.adguard.com",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "[host.ru]#",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
wantErr: `bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "tls://1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "https://dns.adguard.com/dns-query",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "default_udp_host",
|
||||
upstream: "udp://dns.google",
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "default_udp_ip",
|
||||
upstream: "udp://8.8.8.8",
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host.com/]1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[//]tls://1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/www.host.com/]#",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host.com/google.com/]8.8.8.8",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "idna",
|
||||
upstream: "[/пример.рф/]8.8.8.8",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "bad_domain",
|
||||
upstream: "[/!/]8.8.8.8",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
|
||||
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
defaultUpstream, err := validateUpstream(tc.upstream)
|
||||
require.Equal(t, tc.valid, err == nil)
|
||||
if tc.valid {
|
||||
assert.Equal(t, tc.wantDef, defaultUpstream)
|
||||
}
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
tc.wantDef(t, defaultUpstream)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsSet(t *testing.T) {
|
||||
func TestValidateUpstreams(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
msg string
|
||||
wantErr string
|
||||
set []string
|
||||
wantNil bool
|
||||
}{{
|
||||
name: "empty",
|
||||
msg: "empty upstreams array should be valid",
|
||||
wantErr: ``,
|
||||
set: nil,
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "comment",
|
||||
msg: "comments should not be validated",
|
||||
wantErr: ``,
|
||||
set: []string{"# comment"},
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "valid_no_default",
|
||||
msg: "there is no default upstream",
|
||||
name: "valid_no_default",
|
||||
wantErr: `no default upstreams specified`,
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
@@ -373,10 +384,9 @@ func TestValidateUpstreamsSet(t *testing.T) {
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
},
|
||||
wantNil: false,
|
||||
}, {
|
||||
name: "valid_with_default",
|
||||
msg: "upstreams set is valid, but doesn't pass through validation cause: %s",
|
||||
name: "valid_with_default",
|
||||
wantErr: ``,
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
@@ -385,19 +395,65 @@ func TestValidateUpstreamsSet(t *testing.T) {
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"8.8.8.8",
|
||||
},
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "invalid",
|
||||
msg: "there is an invalid upstream in set, but it pass through validation",
|
||||
wantErr: `cannot prepare the upstream dhcp://fake.dns ([]): unsupported url scheme: dhcp`,
|
||||
set: []string{"dhcp://fake.dns"},
|
||||
wantNil: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateUpstreams(tc.set)
|
||||
|
||||
assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErr string
|
||||
u string
|
||||
}{{
|
||||
name: "success_address",
|
||||
wantErr: ``,
|
||||
u: "[/1.0.0.127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "success_subnet",
|
||||
wantErr: ``,
|
||||
u: "[/127.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "not_arpa_subnet",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`bad arpa domain name "hello.world": not a reversed ip network`,
|
||||
u: "[/hello.world/]#",
|
||||
}, {
|
||||
name: "non-private_arpa_address",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`,
|
||||
u: "[/1.2.3.4.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "non-private_arpa_subnet",
|
||||
wantErr: `checking domain-specific upstreams: ` +
|
||||
`arpa domain "128.in-addr.arpa." should point to a locally-served network`,
|
||||
u: "[/128.in-addr.arpa/]#",
|
||||
}, {
|
||||
name: "several_bad",
|
||||
wantErr: `checking domain-specific upstreams: 2 errors: ` +
|
||||
`"arpa domain \"1.2.3.4.in-addr.arpa.\" should point to a locally-served network", ` +
|
||||
`"bad arpa domain name \"non.arpa\": not a reversed ip network"`,
|
||||
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
set := []string{"192.168.0.1", tc.u}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = ValidateUpstreamsPrivate(set, snd)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,20 +24,19 @@ type ipsetCtx struct {
|
||||
func (c *ipsetCtx) init(ipsetConf []string) (err error) {
|
||||
c.ipsetMgr, err = aghnet.NewIpsetManager(ipsetConf)
|
||||
if errors.Is(err, os.ErrInvalid) || errors.Is(err, os.ErrPermission) {
|
||||
// ipset cannot currently be initialized if the server was
|
||||
// installed from Snap or when the user or the binary doesn't
|
||||
// have the required permissions, or when the kernel doesn't
|
||||
// support netfilter.
|
||||
// ipset cannot currently be initialized if the server was installed
|
||||
// from Snap or when the user or the binary doesn't have the required
|
||||
// permissions, or when the kernel doesn't support netfilter.
|
||||
//
|
||||
// Log and go on.
|
||||
//
|
||||
// TODO(a.garipov): The Snap problem can probably be solved if
|
||||
// we add the netlink-connector interface plug.
|
||||
log.Info("warning: cannot initialize ipset: %s", err)
|
||||
// TODO(a.garipov): The Snap problem can probably be solved if we add
|
||||
// the netlink-connector interface plug.
|
||||
log.Info("ipset: warning: cannot initialize: %s", err)
|
||||
|
||||
return nil
|
||||
} else if unsupErr := (&aghos.UnsupportedError{}); errors.As(err, &unsupErr) {
|
||||
log.Info("warning: %s", err)
|
||||
log.Info("ipset: warning: %s", err)
|
||||
|
||||
return nil
|
||||
} else if err != nil {
|
||||
|
||||
@@ -78,8 +78,8 @@ func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDet
|
||||
// msgToSignature converts msg into it's signature represented in bytes.
|
||||
func msgToSignature(msg dns.Msg) (sig []byte) {
|
||||
sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen)
|
||||
// The binary.BigEndian byte order is used everywhere except when the
|
||||
// real machine's endianess is needed.
|
||||
// The binary.BigEndian byte order is used everywhere except when the real
|
||||
// machine's endianness is needed.
|
||||
byteOrder := binary.BigEndian
|
||||
byteOrder.PutUint16(sig[0:], msg.Id)
|
||||
q := msg.Question[0]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -8,15 +9,15 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Write Stats data and logs
|
||||
func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
|
||||
elapsed := time.Since(ctx.startTime)
|
||||
s := ctx.srv
|
||||
pctx := ctx.proxyCtx
|
||||
func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
|
||||
elapsed := time.Since(dctx.startTime)
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
shouldLog := true
|
||||
msg := pctx.Req
|
||||
@@ -26,54 +27,79 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
|
||||
shouldLog = false
|
||||
}
|
||||
|
||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
||||
ip = netutil.CloneIP(ip)
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
// Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use.
|
||||
// This can happen after proxy server has been stopped, but its workers haven't yet exited.
|
||||
s.anonymizer.Load()(ip)
|
||||
|
||||
log.Debug("client ip: %s", ip)
|
||||
|
||||
// Synchronize access to s.queryLog and s.stats so they won't be suddenly
|
||||
// uninitialized while in use. This can happen after proxy server has been
|
||||
// stopped, but its workers haven't yet exited.
|
||||
if shouldLog && s.queryLog != nil {
|
||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
||||
p := querylog.AddParams{
|
||||
Question: msg,
|
||||
Answer: pctx.Res,
|
||||
OrigAnswer: ctx.origResp,
|
||||
Result: ctx.result,
|
||||
Elapsed: elapsed,
|
||||
ClientIP: ip,
|
||||
ClientID: ctx.clientID,
|
||||
}
|
||||
|
||||
switch pctx.Proto {
|
||||
case proxy.ProtoHTTPS:
|
||||
p.ClientProto = querylog.ClientProtoDoH
|
||||
case proxy.ProtoQUIC:
|
||||
p.ClientProto = querylog.ClientProtoDoQ
|
||||
case proxy.ProtoTLS:
|
||||
p.ClientProto = querylog.ClientProtoDoT
|
||||
case proxy.ProtoDNSCrypt:
|
||||
p.ClientProto = querylog.ClientProtoDNSCrypt
|
||||
default:
|
||||
// Consider this a plain DNS-over-UDP or DNS-over-TCP
|
||||
// request.
|
||||
}
|
||||
|
||||
if pctx.Upstream != nil {
|
||||
p.Upstream = pctx.Upstream.Address()
|
||||
}
|
||||
|
||||
s.queryLog.Add(p)
|
||||
s.logQuery(dctx, pctx, elapsed, ip)
|
||||
}
|
||||
|
||||
s.updateStats(ctx, elapsed, *ctx.result)
|
||||
if s.stats != nil {
|
||||
s.updateStats(dctx, elapsed, *dctx.result, ip)
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filtering.Result) {
|
||||
if s.stats == nil {
|
||||
return
|
||||
// logQuery pushes the request details into the query log.
|
||||
func (s *Server) logQuery(
|
||||
dctx *dnsContext,
|
||||
pctx *proxy.DNSContext,
|
||||
elapsed time.Duration,
|
||||
ip net.IP,
|
||||
) {
|
||||
p := &querylog.AddParams{
|
||||
Question: pctx.Req,
|
||||
ReqECS: pctx.ReqECS,
|
||||
Answer: pctx.Res,
|
||||
OrigAnswer: dctx.origResp,
|
||||
Result: dctx.result,
|
||||
Elapsed: elapsed,
|
||||
ClientID: dctx.clientID,
|
||||
ClientIP: ip,
|
||||
AuthenticatedData: dctx.responseAD,
|
||||
}
|
||||
|
||||
switch pctx.Proto {
|
||||
case proxy.ProtoHTTPS:
|
||||
p.ClientProto = querylog.ClientProtoDoH
|
||||
case proxy.ProtoQUIC:
|
||||
p.ClientProto = querylog.ClientProtoDoQ
|
||||
case proxy.ProtoTLS:
|
||||
p.ClientProto = querylog.ClientProtoDoT
|
||||
case proxy.ProtoDNSCrypt:
|
||||
p.ClientProto = querylog.ClientProtoDNSCrypt
|
||||
default:
|
||||
// Consider this a plain DNS-over-UDP or DNS-over-TCP request.
|
||||
}
|
||||
|
||||
if pctx.Upstream != nil {
|
||||
p.Upstream = pctx.Upstream.Address()
|
||||
} else if cachedUps := pctx.CachedUpstreamAddr; cachedUps != "" {
|
||||
p.Upstream = pctx.CachedUpstreamAddr
|
||||
p.Cached = true
|
||||
}
|
||||
|
||||
s.queryLog.Add(p)
|
||||
}
|
||||
|
||||
// updatesStats writes the request into statistics.
|
||||
func (s *Server) updateStats(
|
||||
ctx *dnsContext,
|
||||
elapsed time.Duration,
|
||||
res filtering.Result,
|
||||
clientIP net.IP,
|
||||
) {
|
||||
pctx := ctx.proxyCtx
|
||||
e := stats.Entry{}
|
||||
e.Domain = strings.ToLower(pctx.Req.Question[0].Name)
|
||||
@@ -81,8 +107,8 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri
|
||||
|
||||
if clientID := ctx.clientID; clientID != "" {
|
||||
e.Client = clientID
|
||||
} else if ip, _ := netutil.IPAndPortFromAddr(pctx.Addr); ip != nil {
|
||||
e.Client = ip.String()
|
||||
} else if clientIP != nil {
|
||||
e.Client = clientIP.String()
|
||||
}
|
||||
|
||||
e.Time = uint32(elapsed / 1000)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
@@ -21,11 +22,11 @@ type testQueryLog struct {
|
||||
// a querylog.QueryLog without actually implementing all methods.
|
||||
querylog.QueryLog
|
||||
|
||||
lastParams querylog.AddParams
|
||||
lastParams *querylog.AddParams
|
||||
}
|
||||
|
||||
// Add implements the querylog.QueryLog interface for *testQueryLog.
|
||||
func (l *testQueryLog) Add(p querylog.AddParams) {
|
||||
func (l *testQueryLog) Add(p *querylog.AddParams) {
|
||||
l.lastParams = p
|
||||
}
|
||||
|
||||
@@ -65,7 +66,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
reason: filtering.NotFilteredNotFound,
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_tls_client_id",
|
||||
name: "success_tls_clientid",
|
||||
proto: proxy.ProtoTLS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "cli42",
|
||||
@@ -157,9 +158,16 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
}}
|
||||
|
||||
ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
ql := &testQueryLog{}
|
||||
st := &testStats{}
|
||||
srv := &Server{
|
||||
queryLog: ql,
|
||||
stats: st,
|
||||
anonymizer: aghnet.NewIPMut(nil),
|
||||
}
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
@@ -173,14 +181,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
Addr: tc.addr,
|
||||
Upstream: ups,
|
||||
}
|
||||
|
||||
ql := &testQueryLog{}
|
||||
st := &testStats{}
|
||||
dctx := &dnsContext{
|
||||
srv: &Server{
|
||||
queryLog: ql,
|
||||
stats: st,
|
||||
},
|
||||
proxyCtx: pctx,
|
||||
startTime: time.Now(),
|
||||
result: &filtering.Result{
|
||||
@@ -189,7 +190,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
clientID: tc.clientID,
|
||||
}
|
||||
|
||||
code := processQueryLogsAndStats(dctx)
|
||||
code := srv.processQueryLogsAndStats(dctx)
|
||||
assert.Equal(t, tc.wantCode, code)
|
||||
assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto)
|
||||
assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)
|
||||
|
||||
@@ -520,6 +520,43 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"local_ptr_upstreams_bad": {
|
||||
"req": {
|
||||
"local_ptr_upstreams": [
|
||||
"123.123.123.123",
|
||||
"[/non.arpa/]#"
|
||||
]
|
||||
},
|
||||
"want": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:53",
|
||||
"8.8.4.4:53"
|
||||
],
|
||||
"upstream_dns_file": "",
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10",
|
||||
"149.112.112.10",
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
"upstream_mode": "",
|
||||
"cache_size": 0,
|
||||
"cache_ttl_min": 0,
|
||||
"cache_ttl_max": 0,
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
}
|
||||
},
|
||||
"local_ptr_upstreams_null": {
|
||||
"req": {
|
||||
"local_ptr_upstreams": null
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
)
|
||||
@@ -239,7 +240,7 @@ func initBlockedServices() {
|
||||
for _, s := range serviceRulesArray {
|
||||
netRules := []*rules.NetworkRule{}
|
||||
for _, text := range s.rules {
|
||||
rule, err := rules.NewNetworkRule(text, 0)
|
||||
rule, err := rules.NewNetworkRule(text, BlockedSvcsListID)
|
||||
if err != nil {
|
||||
log.Error("rules.NewNetworkRule: %s rule: %s", err, text)
|
||||
continue
|
||||
@@ -287,7 +288,8 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(list)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -296,7 +298,8 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
|
||||
list := []string{}
|
||||
err := json.NewDecoder(r.Body).Decode(&list)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -15,14 +15,10 @@ type DNSRewriteResult struct {
|
||||
// the server returns.
|
||||
type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue
|
||||
|
||||
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns
|
||||
// an empty result if dnsr is empty. Otherwise, the result will have
|
||||
// either CanonName or DNSRewriteResult set.
|
||||
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns an empty
|
||||
// result if dnsr is empty. Otherwise, the result will have either CanonName or
|
||||
// DNSRewriteResult set. dnsr is expected to be non-empty.
|
||||
func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
if len(dnsr) == 0 {
|
||||
return Result{}
|
||||
}
|
||||
|
||||
var rules []*ResultRule
|
||||
dnsrr := &DNSRewriteResult{
|
||||
Response: DNSRewriteResultResponse{},
|
||||
@@ -31,8 +27,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
for _, nr := range dnsr {
|
||||
dr := nr.DNSRewrite
|
||||
if dr.NewCNAME != "" {
|
||||
// NewCNAME rules have a higher priority than
|
||||
// the other rules.
|
||||
// NewCNAME rules have a higher priority than other rules.
|
||||
rules = []*ResultRule{{
|
||||
FilterListID: int64(nr.GetFilterListID()),
|
||||
Text: nr.RuleText,
|
||||
@@ -54,8 +49,8 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
Text: nr.RuleText,
|
||||
})
|
||||
default:
|
||||
// RcodeRefused and other such codes have higher
|
||||
// priority. Return immediately.
|
||||
// RcodeRefused and other such codes have higher priority. Return
|
||||
// immediately.
|
||||
rules = []*ResultRule{{
|
||||
FilterListID: int64(nr.GetFilterListID()),
|
||||
Text: nr.RuleText,
|
||||
|
||||
@@ -49,7 +49,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
|1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot.
|
||||
`
|
||||
|
||||
f := newForTest(nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||
f := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||
setts := &Settings{
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package filtering
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
@@ -24,6 +26,18 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// The IDs of built-in filter lists.
|
||||
//
|
||||
// Keep in sync with client/src/helpers/constants.js.
|
||||
const (
|
||||
CustomListID = -iota
|
||||
SysHostsListID
|
||||
BlockedSvcsListID
|
||||
ParentalListID
|
||||
SafeBrowsingListID
|
||||
SafeSearchListID
|
||||
)
|
||||
|
||||
// ServiceEntry - blocked service array element
|
||||
type ServiceEntry struct {
|
||||
Name string
|
||||
@@ -38,6 +52,7 @@ type Settings struct {
|
||||
|
||||
ServicesRules []ServiceEntry
|
||||
|
||||
ProtectionEnabled bool
|
||||
FilteringEnabled bool
|
||||
SafeSearchEnabled bool
|
||||
SafeBrowsingEnabled bool
|
||||
@@ -65,7 +80,7 @@ type Config struct {
|
||||
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
|
||||
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
|
||||
|
||||
Rewrites []RewriteEntry `yaml:"rewrites"`
|
||||
Rewrites []*LegacyRewrite `yaml:"rewrites"`
|
||||
|
||||
// Names of services to block (globally).
|
||||
// Per-client settings can override this configuration.
|
||||
@@ -73,7 +88,7 @@ type Config struct {
|
||||
|
||||
// EtcHosts is a container of IP-hostname pairs taken from the operating
|
||||
// system configuration files (e.g. /etc/hosts).
|
||||
EtcHosts *aghnet.EtcHostsContainer `yaml:"-"`
|
||||
EtcHosts *aghnet.HostsContainer `yaml:"-"`
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
@@ -124,6 +139,10 @@ type DNSFilter struct {
|
||||
parentalUpstream upstream.Upstream
|
||||
safeBrowsingUpstream upstream.Upstream
|
||||
|
||||
safebrowsingCache cache.Cache
|
||||
parentalCache cache.Cache
|
||||
safeSearchCache cache.Cache
|
||||
|
||||
Config // for direct access by library users, even a = assignment
|
||||
// confLock protects Config.
|
||||
confLock sync.RWMutex
|
||||
@@ -142,9 +161,14 @@ type DNSFilter struct {
|
||||
|
||||
// Filter represents a filter list
|
||||
type Filter struct {
|
||||
ID int64 // auto-assigned when filter is added (see nextFilterID)
|
||||
Data []byte `yaml:"-"` // List of rules divided by '\n'
|
||||
FilePath string `yaml:"-"` // Path to a filtering rules file
|
||||
// FilePath is the path to a filtering rules list file.
|
||||
FilePath string `yaml:"-"`
|
||||
|
||||
// Data is the content of the file.
|
||||
Data []byte `yaml:"-"`
|
||||
|
||||
// ID is automatically assigned when filter is added using nextFilterID.
|
||||
ID int64
|
||||
}
|
||||
|
||||
// Reason holds an enum detailing why it was filtered or not filtered
|
||||
@@ -176,8 +200,8 @@ const (
|
||||
// FilteredBlockedService - the host is blocked by "blocked services" settings
|
||||
FilteredBlockedService
|
||||
|
||||
// Rewritten is returned when there was a rewrite by a legacy DNS
|
||||
// rewrite rule.
|
||||
// Rewritten is returned when there was a rewrite by a legacy DNS rewrite
|
||||
// rule.
|
||||
Rewritten
|
||||
|
||||
// RewrittenAutoHosts is returned when there was a rewrite by autohosts
|
||||
@@ -186,8 +210,8 @@ const (
|
||||
|
||||
// RewrittenRule is returned when a $dnsrewrite filter rule was applied.
|
||||
//
|
||||
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging
|
||||
// their functionality into RewrittenRule.
|
||||
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging their
|
||||
// functionality into RewrittenRule.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2499.
|
||||
RewrittenRule
|
||||
@@ -221,12 +245,13 @@ func (r Reason) String() string {
|
||||
}
|
||||
|
||||
// In returns true if reasons include r.
|
||||
func (r Reason) In(reasons ...Reason) bool {
|
||||
func (r Reason) In(reasons ...Reason) (ok bool) {
|
||||
for _, reason := range reasons {
|
||||
if r == reason {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -245,7 +270,7 @@ func (d *DNSFilter) GetConfig() (s Settings) {
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
return Settings{
|
||||
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) == 1,
|
||||
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0,
|
||||
SafeSearchEnabled: d.Config.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled,
|
||||
ParentalEnabled: d.Config.ParentalEnabled,
|
||||
@@ -261,8 +286,14 @@ func (d *DNSFilter) WriteDiskConfig(c *Config) {
|
||||
c.Rewrites = cloneRewrites(c.Rewrites)
|
||||
}
|
||||
|
||||
func cloneRewrites(entries []RewriteEntry) (clone []RewriteEntry) {
|
||||
return append([]RewriteEntry(nil), entries...)
|
||||
// cloneRewrites returns a deep copy of entries.
|
||||
func cloneRewrites(entries []*LegacyRewrite) (clone []*LegacyRewrite) {
|
||||
clone = make([]*LegacyRewrite, len(entries))
|
||||
for i, rw := range entries {
|
||||
clone[i] = rw.clone()
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// SetFilters - set new filters (synchronously or asynchronously)
|
||||
@@ -338,14 +369,6 @@ func (d *DNSFilter) reset() {
|
||||
}
|
||||
}
|
||||
|
||||
type dnsFilterContext struct {
|
||||
safebrowsingCache cache.Cache
|
||||
parentalCache cache.Cache
|
||||
safeSearchCache cache.Cache
|
||||
}
|
||||
|
||||
var gctx dnsFilterContext
|
||||
|
||||
// ResultRule contains information about applied rules.
|
||||
type ResultRule struct {
|
||||
// Text is the text of the rule.
|
||||
@@ -371,24 +394,19 @@ type Result struct {
|
||||
// Reason is the reason for blocking or unblocking the request.
|
||||
Reason Reason `json:",omitempty"`
|
||||
|
||||
// Rules are applied rules. If Rules are not empty, each rule
|
||||
// is not nil.
|
||||
// Rules are applied rules. If Rules are not empty, each rule is not nil.
|
||||
Rules []*ResultRule `json:",omitempty"`
|
||||
|
||||
// ReverseHosts is the reverse lookup rewrite result. It is
|
||||
// empty unless Reason is set to RewrittenAutoHosts.
|
||||
ReverseHosts []string `json:",omitempty"`
|
||||
|
||||
// IPList is the lookup rewrite result. It is empty unless
|
||||
// Reason is set to RewrittenAutoHosts or Rewritten.
|
||||
// IPList is the lookup rewrite result. It is empty unless Reason is set to
|
||||
// Rewritten.
|
||||
IPList []net.IP `json:",omitempty"`
|
||||
|
||||
// CanonName is the CNAME value from the lookup rewrite result.
|
||||
// It is empty unless Reason is set to Rewritten or RewrittenRule.
|
||||
// CanonName is the CNAME value from the lookup rewrite result. It is empty
|
||||
// unless Reason is set to Rewritten or RewrittenRule.
|
||||
CanonName string `json:",omitempty"`
|
||||
|
||||
// ServiceName is the name of the blocked service. It is empty
|
||||
// unless Reason is set to FilteredBlockedService.
|
||||
// ServiceName is the name of the blocked service. It is empty unless
|
||||
// Reason is set to FilteredBlockedService.
|
||||
ServiceName string `json:",omitempty"`
|
||||
|
||||
// DNSRewriteResult is the $dnsrewrite filter rule result.
|
||||
@@ -402,14 +420,8 @@ func (r Reason) Matched() bool {
|
||||
}
|
||||
|
||||
// CheckHostRules tries to match the host against filtering rules only.
|
||||
func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *Settings) (Result, error) {
|
||||
if !setts.FilteringEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
host = strings.ToLower(host)
|
||||
|
||||
return d.matchHost(host, qtype, setts)
|
||||
func (d *DNSFilter) CheckHostRules(host string, rrtype uint16, setts *Settings) (Result, error) {
|
||||
return d.matchHost(strings.ToLower(host), rrtype, setts)
|
||||
}
|
||||
|
||||
// CheckHost tries to match the host against filtering rules, then safebrowsing
|
||||
@@ -422,14 +434,16 @@ func (d *DNSFilter) CheckHost(
|
||||
// Sometimes clients try to resolve ".", which is a request to get root
|
||||
// servers.
|
||||
if host == "" {
|
||||
return Result{Reason: NotFilteredNotFound}, nil
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
host = strings.ToLower(host)
|
||||
|
||||
res = d.processRewrites(host, qtype)
|
||||
if res.Reason == Rewritten {
|
||||
return res, nil
|
||||
if setts.FilteringEnabled {
|
||||
res = d.processRewrites(host, qtype)
|
||||
if res.Reason == Rewritten {
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, hc := range d.hostCheckers {
|
||||
@@ -446,100 +460,123 @@ func (d *DNSFilter) CheckHost(
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
// checkEtcHosts compares the host against our /etc/hosts table. The err is
|
||||
// always nil, it is only there to make this a valid hostChecker function.
|
||||
func (d *DNSFilter) checkEtcHosts(
|
||||
// matchSysHosts tries to match the host against the operating system's hosts
|
||||
// database. err is always nil.
|
||||
func (d *DNSFilter) matchSysHosts(
|
||||
host string,
|
||||
qtype uint16,
|
||||
_ *Settings,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if d.Config.EtcHosts == nil {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
ips := d.Config.EtcHosts.Process(host, qtype)
|
||||
if ips != nil {
|
||||
res = Result{
|
||||
Reason: RewrittenAutoHosts,
|
||||
IPList: ips,
|
||||
}
|
||||
|
||||
if !setts.FilteringEnabled || d.EtcHosts == nil {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
revHosts := d.Config.EtcHosts.ProcessReverse(host, qtype)
|
||||
if len(revHosts) != 0 {
|
||||
res = Result{
|
||||
Reason: RewrittenAutoHosts,
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Optimize this with a buffer.
|
||||
res.ReverseHosts = make([]string, len(revHosts))
|
||||
for i := range revHosts {
|
||||
res.ReverseHosts[i] = revHosts[i] + "."
|
||||
}
|
||||
|
||||
dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
SortedClientTags: setts.ClientTags,
|
||||
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
||||
ClientIP: setts.ClientIP.String(),
|
||||
ClientName: setts.ClientName,
|
||||
DNSType: qtype,
|
||||
})
|
||||
if dnsres == nil {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return Result{}, nil
|
||||
dnsr := dnsres.DNSRewrites()
|
||||
if len(dnsr) == 0 {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
res = d.processDNSRewrites(dnsr)
|
||||
res.Reason = RewrittenAutoHosts
|
||||
for _, r := range res.Rules {
|
||||
r.Text = stringutil.Coalesce(d.EtcHosts.Translate(r.Text), r.Text)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Process rewrites table
|
||||
// . Find CNAME for a domain name (exact match or by wildcard)
|
||||
// . if found and CNAME equals to domain name - this is an exception; exit
|
||||
// . if found, set domain name to canonical name
|
||||
// . repeat for the new domain name (Note: we return only the last CNAME)
|
||||
// . Find A or AAAA record for a domain name (exact match or by wildcard)
|
||||
// . if found, set IP addresses (IPv4 or IPv6 depending on qtype) in Result.IPList array
|
||||
// processRewrites performs filtering based on the legacy rewrite records.
|
||||
//
|
||||
// Firstly, it finds CNAME rewrites for host. If the CNAME is the same as host,
|
||||
// this query isn't filtered. If it's different, repeat the process for the new
|
||||
// CNAME, breaking loops in the process.
|
||||
//
|
||||
// Secondly, it finds A or AAAA rewrites for host and, if found, sets res.IPList
|
||||
// accordingly. If the found rewrite has a special value of "A" or "AAAA", the
|
||||
// result is an exception.
|
||||
func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
rr := findRewrites(d.Rewrites, host, qtype)
|
||||
if len(rr) != 0 {
|
||||
res.Reason = Rewritten
|
||||
rewrites, matched := findRewrites(d.Rewrites, host, qtype)
|
||||
if !matched {
|
||||
return Result{}
|
||||
}
|
||||
|
||||
res.Reason = Rewritten
|
||||
|
||||
cnames := stringutil.NewSet()
|
||||
origHost := host
|
||||
for len(rr) != 0 && rr[0].Type == dns.TypeCNAME {
|
||||
log.Debug("rewrite: CNAME for %s is %s", host, rr[0].Answer)
|
||||
for matched && len(rewrites) > 0 && rewrites[0].Type == dns.TypeCNAME {
|
||||
rw := rewrites[0]
|
||||
rwPat := rw.Domain
|
||||
rwAns := rw.Answer
|
||||
|
||||
if host == rr[0].Answer { // "host == CNAME" is an exception
|
||||
res.Reason = NotFilteredNotFound
|
||||
log.Debug("rewrite: cname for %s is %s", host, rwAns)
|
||||
|
||||
return res
|
||||
if origHost == rwAns || rwPat == rwAns {
|
||||
// Either a request for the hostname itself or a rewrite of
|
||||
// a pattern onto itself, both of which are an exception rules.
|
||||
// Return a not filtered result.
|
||||
return Result{}
|
||||
} else if host == rwAns && isWildcard(rwPat) {
|
||||
// An "*.example.com → sub.example.com" rewrite matching in a loop.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4016.
|
||||
|
||||
res.CanonName = host
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
host = rr[0].Answer
|
||||
host = rwAns
|
||||
if cnames.Has(host) {
|
||||
log.Info("rewrite: breaking CNAME redirection loop: %s. Question: %s", host, origHost)
|
||||
log.Info("rewrite: cname loop for %q on %q", origHost, host)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
cnames.Add(host)
|
||||
res.CanonName = rr[0].Answer
|
||||
rr = findRewrites(d.Rewrites, host, qtype)
|
||||
res.CanonName = host
|
||||
rewrites, matched = findRewrites(d.Rewrites, host, qtype)
|
||||
}
|
||||
|
||||
for _, r := range rr {
|
||||
if r.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
|
||||
if r.IP == nil { // IP exception
|
||||
res.Reason = NotFilteredNotFound
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
res.IPList = append(res.IPList, r.IP)
|
||||
log.Debug("rewrite: A/AAAA for %s is %s", host, r.IP)
|
||||
}
|
||||
}
|
||||
setRewriteResult(&res, host, rewrites, qtype)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// setRewriteResult sets the Reason or IPList of res if necessary. res must not
|
||||
// be nil.
|
||||
func setRewriteResult(res *Result, host string, rewrites []*LegacyRewrite, qtype uint16) {
|
||||
for _, rw := range rewrites {
|
||||
if rw.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
|
||||
if rw.IP == nil {
|
||||
// "A"/"AAAA" exception: allow getting from upstream.
|
||||
res.Reason = NotFilteredNotFound
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
res.IPList = append(res.IPList, rw.IP)
|
||||
|
||||
log.Debug("rewrite: a/aaaa for %s is %s", host, rw.IP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// matchBlockedServicesRules checks the host against the blocked services rules
|
||||
// in settings, if any. The err is always nil, it is only there to make this
|
||||
// a valid hostChecker function.
|
||||
@@ -548,6 +585,10 @@ func matchBlockedServicesRules(
|
||||
_ uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.ProtectionEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
svcs := setts.ServicesRules
|
||||
if len(svcs) == 0 {
|
||||
return Result{}, nil
|
||||
@@ -582,80 +623,82 @@ func matchBlockedServicesRules(
|
||||
// Adding rule and matching against the rules
|
||||
//
|
||||
|
||||
// fileExists returns true if file exists.
|
||||
func fileExists(fn string) bool {
|
||||
_, err := os.Stat(fn)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func createFilteringEngine(filters []Filter) (*filterlist.RuleStorage, *urlfilter.DNSEngine, error) {
|
||||
listArray := []filterlist.RuleList{}
|
||||
func newRuleStorage(filters []Filter) (rs *filterlist.RuleStorage, err error) {
|
||||
lists := make([]filterlist.RuleList, 0, len(filters))
|
||||
for _, f := range filters {
|
||||
var list filterlist.RuleList
|
||||
|
||||
if f.ID == 0 {
|
||||
list = &filterlist.StringRuleList{
|
||||
ID: 0,
|
||||
switch id := int(f.ID); {
|
||||
case len(f.Data) != 0:
|
||||
lists = append(lists, &filterlist.StringRuleList{
|
||||
ID: id,
|
||||
RulesText: string(f.Data),
|
||||
IgnoreCosmetic: true,
|
||||
}
|
||||
} else if !fileExists(f.FilePath) {
|
||||
list = &filterlist.StringRuleList{
|
||||
ID: int(f.ID),
|
||||
IgnoreCosmetic: true,
|
||||
}
|
||||
} else if runtime.GOOS == "windows" {
|
||||
// On Windows we don't pass a file to urlfilter because
|
||||
// it's difficult to update this file while it's being
|
||||
// used.
|
||||
data, err := os.ReadFile(f.FilePath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("reading filter content: %w", err)
|
||||
})
|
||||
case f.FilePath == "":
|
||||
continue
|
||||
case runtime.GOOS == "windows":
|
||||
// On Windows we don't pass a file to urlfilter because it's
|
||||
// difficult to update this file while it's being used.
|
||||
var data []byte
|
||||
data, err = os.ReadFile(f.FilePath)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("reading filter content: %w", err)
|
||||
}
|
||||
|
||||
list = &filterlist.StringRuleList{
|
||||
ID: int(f.ID),
|
||||
lists = append(lists, &filterlist.StringRuleList{
|
||||
ID: id,
|
||||
RulesText: string(data),
|
||||
IgnoreCosmetic: true,
|
||||
})
|
||||
default:
|
||||
var list *filterlist.FileRuleList
|
||||
list, err = filterlist.NewFileRuleList(id, f.FilePath, true)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("creating file rule list with %q: %w", f.FilePath, err)
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
list, err = filterlist.NewFileRuleList(int(f.ID), f.FilePath, true)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("filterlist.NewFileRuleList(): %s: %w", f.FilePath, err)
|
||||
}
|
||||
|
||||
lists = append(lists, list)
|
||||
}
|
||||
listArray = append(listArray, list)
|
||||
}
|
||||
|
||||
rulesStorage, err := filterlist.NewRuleStorage(listArray)
|
||||
rs, err = filterlist.NewRuleStorage(lists)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err)
|
||||
return nil, fmt.Errorf("creating rule storage: %w", err)
|
||||
}
|
||||
filteringEngine := urlfilter.NewDNSEngine(rulesStorage)
|
||||
return rulesStorage, filteringEngine, nil
|
||||
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// Initialize urlfilter objects.
|
||||
func (d *DNSFilter) initFiltering(allowFilters, blockFilters []Filter) error {
|
||||
rulesStorage, filteringEngine, err := createFilteringEngine(blockFilters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rulesStorageAllow, filteringEngineAllow, err := createFilteringEngine(allowFilters)
|
||||
rulesStorage, err := newRuleStorage(blockFilters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.engineLock.Lock()
|
||||
d.reset()
|
||||
d.rulesStorage = rulesStorage
|
||||
d.filteringEngine = filteringEngine
|
||||
d.rulesStorageAllow = rulesStorageAllow
|
||||
d.filteringEngineAllow = filteringEngineAllow
|
||||
d.engineLock.Unlock()
|
||||
rulesStorageAllow, err := newRuleStorage(allowFilters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure that the OS reclaims memory as soon as possible
|
||||
filteringEngine := urlfilter.NewDNSEngine(rulesStorage)
|
||||
filteringEngineAllow := urlfilter.NewDNSEngine(rulesStorageAllow)
|
||||
|
||||
func() {
|
||||
d.engineLock.Lock()
|
||||
defer d.engineLock.Unlock()
|
||||
|
||||
d.reset()
|
||||
d.rulesStorage = rulesStorage
|
||||
d.filteringEngine = filteringEngine
|
||||
d.rulesStorageAllow = rulesStorageAllow
|
||||
d.filteringEngineAllow = filteringEngineAllow
|
||||
}()
|
||||
|
||||
// Make sure that the OS reclaims memory as soon as possible.
|
||||
debug.FreeOSMemory()
|
||||
log.Debug("initialized filtering engine")
|
||||
|
||||
@@ -677,11 +720,10 @@ func hostRulesToRules(netRules []*rules.HostRule) (res []rules.Rule) {
|
||||
return res
|
||||
}
|
||||
|
||||
// matchHostProcessAllowList processes the allowlist logic of host
|
||||
// matching.
|
||||
// matchHostProcessAllowList processes the allowlist logic of host matching.
|
||||
func (d *DNSFilter) matchHostProcessAllowList(
|
||||
host string,
|
||||
dnsres urlfilter.DNSResult,
|
||||
dnsres *urlfilter.DNSResult,
|
||||
) (res Result, err error) {
|
||||
var matchedRules []rules.Rule
|
||||
if dnsres.NetworkRule != nil {
|
||||
@@ -704,7 +746,7 @@ func (d *DNSFilter) matchHostProcessAllowList(
|
||||
// matchHostProcessDNSResult processes the matched DNS filtering result.
|
||||
func (d *DNSFilter) matchHostProcessDNSResult(
|
||||
qtype uint16,
|
||||
dnsres urlfilter.DNSResult,
|
||||
dnsres *urlfilter.DNSResult,
|
||||
) (res Result) {
|
||||
if dnsres.NetworkRule != nil {
|
||||
reason := FilteredBlockList
|
||||
@@ -734,8 +776,8 @@ func (d *DNSFilter) matchHostProcessDNSResult(
|
||||
}
|
||||
|
||||
if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil {
|
||||
// Question type doesn't match the host rules. Return the first
|
||||
// matched host rule, but without an IP address.
|
||||
// Question type doesn't match the host rules. Return the first matched
|
||||
// host rule, but without an IP address.
|
||||
var matchedRules []rules.Rule
|
||||
if dnsres.HostRulesV4 != nil {
|
||||
matchedRules = []rules.Rule{dnsres.HostRulesV4[0]}
|
||||
@@ -749,32 +791,34 @@ func (d *DNSFilter) matchHostProcessDNSResult(
|
||||
return Result{}
|
||||
}
|
||||
|
||||
// matchHost is a low-level way to check only if hostname is filtered by rules,
|
||||
// matchHost is a low-level way to check only if host is filtered by rules,
|
||||
// skipping expensive safebrowsing and parental lookups.
|
||||
func (d *DNSFilter) matchHost(
|
||||
host string,
|
||||
qtype uint16,
|
||||
rrtype uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.FilteringEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
d.engineLock.RLock()
|
||||
// Keep in mind that this lock must be held no just when calling Match()
|
||||
// but also while using the rules returned by it.
|
||||
defer d.engineLock.RUnlock()
|
||||
|
||||
ureq := urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
SortedClientTags: setts.ClientTags,
|
||||
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
||||
ClientIP: setts.ClientIP.String(),
|
||||
ClientName: setts.ClientName,
|
||||
DNSType: qtype,
|
||||
DNSType: rrtype,
|
||||
}
|
||||
|
||||
if d.filteringEngineAllow != nil {
|
||||
d.engineLock.RLock()
|
||||
// Keep in mind that this lock must be held no just when calling Match() but
|
||||
// also while using the rules returned by it.
|
||||
//
|
||||
// TODO(e.burkov): Inspect if the above is true.
|
||||
defer d.engineLock.RUnlock()
|
||||
|
||||
if setts.ProtectionEnabled && d.filteringEngineAllow != nil {
|
||||
dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq)
|
||||
if ok {
|
||||
return d.matchHostProcessAllowList(host, dnsres)
|
||||
@@ -786,13 +830,12 @@ func (d *DNSFilter) matchHost(
|
||||
}
|
||||
|
||||
dnsres, ok := d.filteringEngine.MatchRequest(ureq)
|
||||
|
||||
// Check DNS rewrites first, because the API there is a bit awkward.
|
||||
if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 {
|
||||
res = d.processDNSRewrites(dnsr)
|
||||
if res.Reason == RewrittenRule && res.CanonName == host {
|
||||
// A rewrite of a host to itself. Go on and try
|
||||
// matching other things.
|
||||
// A rewrite of a host to itself. Go on and try matching other
|
||||
// things.
|
||||
} else {
|
||||
return res, nil
|
||||
}
|
||||
@@ -800,7 +843,12 @@ func (d *DNSFilter) matchHost(
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
res = d.matchHostProcessDNSResult(qtype, dnsres)
|
||||
if !setts.ProtectionEnabled {
|
||||
// Don't check non-dnsrewrite filtering results.
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
res = d.matchHostProcessDNSResult(rrtype, dnsres)
|
||||
for _, r := range res.Rules {
|
||||
log.Debug(
|
||||
"filtering: found rule %q for host %q, filter list id: %d",
|
||||
@@ -836,40 +884,33 @@ func InitModule() {
|
||||
}
|
||||
|
||||
// New creates properly initialized DNS Filter that is ready to be used.
|
||||
func New(c *Config, blockFilters []Filter) *DNSFilter {
|
||||
var resolver Resolver = net.DefaultResolver
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
||||
d = &DNSFilter{
|
||||
resolver: net.DefaultResolver,
|
||||
}
|
||||
if c != nil {
|
||||
cacheConf := cache.Config{
|
||||
|
||||
d.safebrowsingCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
}
|
||||
|
||||
if gctx.safebrowsingCache == nil {
|
||||
cacheConf.MaxSize = c.SafeBrowsingCacheSize
|
||||
gctx.safebrowsingCache = cache.New(cacheConf)
|
||||
}
|
||||
|
||||
if gctx.safeSearchCache == nil {
|
||||
cacheConf.MaxSize = c.SafeSearchCacheSize
|
||||
gctx.safeSearchCache = cache.New(cacheConf)
|
||||
}
|
||||
|
||||
if gctx.parentalCache == nil {
|
||||
cacheConf.MaxSize = c.ParentalCacheSize
|
||||
gctx.parentalCache = cache.New(cacheConf)
|
||||
}
|
||||
MaxSize: c.SafeBrowsingCacheSize,
|
||||
})
|
||||
d.safeSearchCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.SafeSearchCacheSize,
|
||||
})
|
||||
d.parentalCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.ParentalCacheSize,
|
||||
})
|
||||
|
||||
if c.CustomResolver != nil {
|
||||
resolver = c.CustomResolver
|
||||
d.resolver = c.CustomResolver
|
||||
}
|
||||
}
|
||||
|
||||
d := &DNSFilter{
|
||||
resolver: resolver,
|
||||
}
|
||||
|
||||
d.hostCheckers = []hostChecker{{
|
||||
check: d.checkEtcHosts,
|
||||
name: "etchosts",
|
||||
check: d.matchSysHosts,
|
||||
name: "hosts container",
|
||||
}, {
|
||||
check: d.matchHost,
|
||||
name: "filtering",
|
||||
@@ -890,12 +931,18 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
|
||||
err := d.initSecurityServices()
|
||||
if err != nil {
|
||||
log.Error("filtering: initialize services: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
d.Config = *c
|
||||
d.prepareRewrites()
|
||||
err = d.prepareRewrites()
|
||||
if err != nil {
|
||||
log.Error("rewrites: preparing: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
bsvcs := []string{}
|
||||
|
||||
@@ -21,15 +21,17 @@ func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
var setts Settings
|
||||
var setts = Settings{
|
||||
ProtectionEnabled: true,
|
||||
}
|
||||
|
||||
// Helpers.
|
||||
|
||||
func purgeCaches() {
|
||||
func purgeCaches(d *DNSFilter) {
|
||||
for _, c := range []cache.Cache{
|
||||
gctx.safebrowsingCache,
|
||||
gctx.parentalCache,
|
||||
gctx.safeSearchCache,
|
||||
d.safebrowsingCache,
|
||||
d.parentalCache,
|
||||
d.safeSearchCache,
|
||||
} {
|
||||
if c != nil {
|
||||
c.Clear()
|
||||
@@ -37,11 +39,11 @@ func purgeCaches() {
|
||||
}
|
||||
}
|
||||
|
||||
func newForTest(c *Config, filters []Filter) *DNSFilter {
|
||||
func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
|
||||
setts = Settings{
|
||||
FilteringEnabled: true,
|
||||
ProtectionEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
setts.FilteringEnabled = true
|
||||
if c != nil {
|
||||
c.SafeBrowsingCacheSize = 10000
|
||||
c.ParentalCacheSize = 10000
|
||||
@@ -52,7 +54,8 @@ func newForTest(c *Config, filters []Filter) *DNSFilter {
|
||||
setts.ParentalEnabled = c.ParentalEnabled
|
||||
}
|
||||
d := New(c, filters)
|
||||
purgeCaches()
|
||||
purgeCaches(d)
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -103,7 +106,7 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||
filters := []Filter{{
|
||||
ID: 0, Data: []byte(text),
|
||||
}}
|
||||
d := newForTest(nil, filters)
|
||||
d := newForTest(t, nil, filters)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.checkMatchIP(t, "google.com", addr, dns.TypeA)
|
||||
@@ -168,7 +171,7 @@ func TestSafeBrowsing(t *testing.T) {
|
||||
aghtest.ReplaceLogWriter(t, logOutput)
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
@@ -191,7 +194,7 @@ func TestSafeBrowsing(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParallelSB(t *testing.T) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
@@ -215,7 +218,7 @@ func TestParallelSB(t *testing.T) {
|
||||
// Safe Search.
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
require.True(t, ok)
|
||||
@@ -224,7 +227,9 @@ func TestSafeSearch(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{
|
||||
SafeSearchEnabled: true,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||
@@ -247,13 +252,14 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, yandexIP, res.Rules[0].IP)
|
||||
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
d := newForTest(&Config{
|
||||
d := newForTest(t, &Config{
|
||||
SafeSearchEnabled: true,
|
||||
CustomResolver: resolver,
|
||||
}, nil)
|
||||
@@ -280,12 +286,13 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, ip, res.Rules[0].IP)
|
||||
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
d := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const domain = "yandex.ru"
|
||||
|
||||
@@ -299,7 +306,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
|
||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||
|
||||
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||
@@ -310,7 +317,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
assert.Equal(t, res.Rules[0].IP, yandexIP)
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
||||
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
@@ -319,7 +326,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
|
||||
func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
d := newForTest(&Config{
|
||||
d := newForTest(t, &Config{
|
||||
CustomResolver: resolver,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
@@ -332,7 +339,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
|
||||
require.Empty(t, res.Rules)
|
||||
|
||||
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
d.resolver = resolver
|
||||
|
||||
@@ -359,7 +366,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
assert.True(t, res.Rules[0].IP.Equal(ip))
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
|
||||
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
@@ -373,7 +380,7 @@ func TestParentalControl(t *testing.T) {
|
||||
aghtest.ReplaceLogWriter(t, logOutput)
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
|
||||
d := newForTest(&Config{ParentalEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "pornhub.com"
|
||||
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
|
||||
@@ -677,7 +684,7 @@ func TestMatching(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
|
||||
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
|
||||
d := newForTest(nil, filters)
|
||||
d := newForTest(t, nil, filters)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
|
||||
@@ -703,7 +710,7 @@ func TestWhitelist(t *testing.T) {
|
||||
whiteFilters := []Filter{{
|
||||
ID: 0, Data: []byte(whiteRules),
|
||||
}}
|
||||
d := newForTest(nil, filters)
|
||||
d := newForTest(t, nil, filters)
|
||||
|
||||
err := d.SetFilters(filters, whiteFilters, false)
|
||||
require.NoError(t, err)
|
||||
@@ -748,7 +755,7 @@ func applyClientSettings(setts *Settings) {
|
||||
}
|
||||
|
||||
func TestClientSettings(t *testing.T) {
|
||||
d := newForTest(
|
||||
d := newForTest(t,
|
||||
&Config{
|
||||
ParentalEnabled: true,
|
||||
SafeBrowsingEnabled: false,
|
||||
@@ -797,7 +804,11 @@ func TestClientSettings(t *testing.T) {
|
||||
|
||||
makeTester := func(tc testCase, before bool) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
r, _ := d.CheckHost(tc.host, dns.TypeA, &setts)
|
||||
t.Helper()
|
||||
|
||||
r, err := d.CheckHost(tc.host, dns.TypeA, &setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
if before {
|
||||
assert.True(t, r.IsFiltered)
|
||||
assert.Equal(t, tc.wantReason, r.Reason)
|
||||
@@ -808,7 +819,7 @@ func TestClientSettings(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check behaviour without any per-client settings, then apply per-client
|
||||
// settings and check behaviour once again.
|
||||
// settings and check behavior once again.
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, makeTester(tc, tc.before))
|
||||
}
|
||||
@@ -823,7 +834,7 @@ func TestClientSettings(t *testing.T) {
|
||||
// Benchmarks.
|
||||
|
||||
func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
blocked := "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
@@ -839,7 +850,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
blocked := "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
@@ -857,7 +868,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkSafeSearch(b *testing.B) {
|
||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
for n := 0; n < b.N; n++ {
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
@@ -868,7 +879,7 @@ func BenchmarkSafeSearch(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkSafeSearchParallel(b *testing.B) {
|
||||
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
|
||||
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
|
||||
@@ -4,93 +4,121 @@ package filtering
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// RewriteEntry is a rewrite array element
|
||||
type RewriteEntry struct {
|
||||
// Domain is the domain for which this rewrite should work.
|
||||
// LegacyRewrite is a single legacy DNS rewrite record.
|
||||
//
|
||||
// Instances of *LegacyRewrite must never be nil.
|
||||
type LegacyRewrite struct {
|
||||
// Domain is the domain pattern for which this rewrite should work.
|
||||
Domain string `yaml:"domain"`
|
||||
|
||||
// Answer is the IP address, canonical name, or one of the special
|
||||
// values: "A" or "AAAA".
|
||||
Answer string `yaml:"answer"`
|
||||
|
||||
// IP is the IP address that should be used in the response if Type is
|
||||
// A or AAAA.
|
||||
// dns.TypeA or dns.TypeAAAA.
|
||||
IP net.IP `yaml:"-"`
|
||||
|
||||
// Type is the DNS record type: A, AAAA, or CNAME.
|
||||
Type uint16 `yaml:"-"`
|
||||
}
|
||||
|
||||
// equal returns true if the entry is considered equal to the other.
|
||||
func (e *RewriteEntry) equal(other RewriteEntry) (ok bool) {
|
||||
return e.Domain == other.Domain && e.Answer == other.Answer
|
||||
// clone returns a deep clone of rw.
|
||||
func (rw *LegacyRewrite) clone() (cloneRW *LegacyRewrite) {
|
||||
return &LegacyRewrite{
|
||||
Domain: rw.Domain,
|
||||
Answer: rw.Answer,
|
||||
IP: netutil.CloneIP(rw.IP),
|
||||
Type: rw.Type,
|
||||
}
|
||||
}
|
||||
|
||||
// matchesQType returns true if the entry matched qtype.
|
||||
func (e *RewriteEntry) matchesQType(qtype uint16) (ok bool) {
|
||||
// equal returns true if the rw is equal to the other.
|
||||
func (rw *LegacyRewrite) equal(other *LegacyRewrite) (ok bool) {
|
||||
return rw.Domain == other.Domain && rw.Answer == other.Answer
|
||||
}
|
||||
|
||||
// matchesQType returns true if the entry matches the question type qt.
|
||||
func (rw *LegacyRewrite) matchesQType(qt uint16) (ok bool) {
|
||||
// Add CNAMEs, since they match for all types requests.
|
||||
if e.Type == dns.TypeCNAME {
|
||||
if rw.Type == dns.TypeCNAME {
|
||||
return true
|
||||
}
|
||||
|
||||
// Reject types other than A and AAAA.
|
||||
if qtype != dns.TypeA && qtype != dns.TypeAAAA {
|
||||
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
return false
|
||||
}
|
||||
|
||||
// If the types match or the entry is set to allow only the other type,
|
||||
// include them.
|
||||
return e.Type == qtype || e.IP == nil
|
||||
return rw.Type == qt || rw.IP == nil
|
||||
}
|
||||
|
||||
// normalize makes sure that the a new or decoded entry is normalized with
|
||||
// regards to domain name case, IP length, and so on.
|
||||
func (e *RewriteEntry) normalize() {
|
||||
// TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix
|
||||
// and use it in matchDomainWildcard instead of using strings.ToLower
|
||||
//
|
||||
// If rw is nil, it returns an errors.
|
||||
func (rw *LegacyRewrite) normalize() (err error) {
|
||||
if rw == nil {
|
||||
return errors.Error("nil rewrite entry")
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix and
|
||||
// use it in matchDomainWildcard instead of using strings.ToLower
|
||||
// everywhere.
|
||||
e.Domain = strings.ToLower(e.Domain)
|
||||
rw.Domain = strings.ToLower(rw.Domain)
|
||||
|
||||
switch e.Answer {
|
||||
switch rw.Answer {
|
||||
case "AAAA":
|
||||
e.IP = nil
|
||||
e.Type = dns.TypeAAAA
|
||||
rw.IP = nil
|
||||
rw.Type = dns.TypeAAAA
|
||||
|
||||
return
|
||||
return nil
|
||||
case "A":
|
||||
e.IP = nil
|
||||
e.Type = dns.TypeA
|
||||
rw.IP = nil
|
||||
rw.Type = dns.TypeA
|
||||
|
||||
return
|
||||
return nil
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
ip := net.ParseIP(e.Answer)
|
||||
ip := net.ParseIP(rw.Answer)
|
||||
if ip == nil {
|
||||
e.Type = dns.TypeCNAME
|
||||
rw.Type = dns.TypeCNAME
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
ip4 := ip.To4()
|
||||
if ip4 != nil {
|
||||
e.IP = ip4
|
||||
e.Type = dns.TypeA
|
||||
rw.IP = ip4
|
||||
rw.Type = dns.TypeA
|
||||
} else {
|
||||
e.IP = ip
|
||||
e.Type = dns.TypeAAAA
|
||||
rw.IP = ip
|
||||
rw.Type = dns.TypeAAAA
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isWildcard(host string) bool {
|
||||
return len(host) > 1 && host[0] == '*' && host[1] == '.'
|
||||
// isWildcard returns true if pat is a wildcard domain pattern.
|
||||
func isWildcard(pat string) bool {
|
||||
return len(pat) > 1 && pat[0] == '*' && pat[1] == '.'
|
||||
}
|
||||
|
||||
// matchDomainWildcard returns true if host matches the wildcard pattern.
|
||||
@@ -106,16 +134,16 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
|
||||
// wildcard > exact
|
||||
// lower level wildcard > higher level wildcard
|
||||
//
|
||||
type rewritesSorted []RewriteEntry
|
||||
type rewritesSorted []*LegacyRewrite
|
||||
|
||||
// Len implements the sort.Interface interface for legacyRewritesSorted.
|
||||
func (a rewritesSorted) Len() int { return len(a) }
|
||||
func (a rewritesSorted) Len() (l int) { return len(a) }
|
||||
|
||||
// Swap implements the sort.Interface interface for legacyRewritesSorted.
|
||||
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
|
||||
// Less implements the sort.Interface interface for legacyRewritesSorted.
|
||||
func (a rewritesSorted) Less(i, j int) bool {
|
||||
func (a rewritesSorted) Less(i, j int) (less bool) {
|
||||
if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME {
|
||||
return true
|
||||
} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME {
|
||||
@@ -132,49 +160,62 @@ func (a rewritesSorted) Less(i, j int) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// both are wildcards
|
||||
// Both are wildcards.
|
||||
return len(a[i].Domain) > len(a[j].Domain)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) prepareRewrites() {
|
||||
for i := range d.Rewrites {
|
||||
d.Rewrites[i].normalize()
|
||||
// prepareRewrites normalizes and validates all legacy DNS rewrites.
|
||||
func (d *DNSFilter) prepareRewrites() (err error) {
|
||||
for i, r := range d.Rewrites {
|
||||
err = r.normalize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findRewrites returns the list of matched rewrite entries. The priority is:
|
||||
// CNAME, then A and AAAA; exact, then wildcard. If the host is matched
|
||||
// exactly, wildcard entries aren't returned. If the host matched by wildcards,
|
||||
// return the most specific for the question type.
|
||||
func findRewrites(entries []RewriteEntry, host string, qtype uint16) (matched []RewriteEntry) {
|
||||
rr := rewritesSorted{}
|
||||
// findRewrites returns the list of matched rewrite entries. If rewrites are
|
||||
// empty, but matched is true, the domain is found among the rewrite rules but
|
||||
// not for this question type.
|
||||
//
|
||||
// The result priority is: CNAME, then A and AAAA; exact, then wildcard. If the
|
||||
// host is matched exactly, wildcard entries aren't returned. If the host
|
||||
// matched by wildcards, return the most specific for the question type.
|
||||
func findRewrites(
|
||||
entries []*LegacyRewrite,
|
||||
host string,
|
||||
qtype uint16,
|
||||
) (rewrites []*LegacyRewrite, matched bool) {
|
||||
for _, e := range entries {
|
||||
if e.Domain != host && !matchDomainWildcard(host, e.Domain) {
|
||||
continue
|
||||
}
|
||||
|
||||
matched = true
|
||||
if e.matchesQType(qtype) {
|
||||
rr = append(rr, e)
|
||||
rewrites = append(rewrites, e)
|
||||
}
|
||||
}
|
||||
|
||||
if len(rr) == 0 {
|
||||
return nil
|
||||
if len(rewrites) == 0 {
|
||||
return nil, matched
|
||||
}
|
||||
|
||||
sort.Sort(rr)
|
||||
sort.Sort(rewritesSorted(rewrites))
|
||||
|
||||
for i, r := range rr {
|
||||
for i, r := range rewrites {
|
||||
if isWildcard(r.Domain) {
|
||||
// Don't use rr[:0], because we need to return at least
|
||||
// one item here.
|
||||
rr = rr[:max(1, i)]
|
||||
// Don't use rewrites[:0], because we need to return at least one
|
||||
// item here.
|
||||
rewrites = rewrites[:max(1, i)]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return rr
|
||||
return rewrites, matched
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
@@ -206,29 +247,39 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(arr)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
||||
jsent := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&jsent)
|
||||
rwJSON := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&rwJSON)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ent := RewriteEntry{
|
||||
Domain: jsent.Domain,
|
||||
Answer: jsent.Answer,
|
||||
rw := &LegacyRewrite{
|
||||
Domain: rwJSON.Domain,
|
||||
Answer: rwJSON.Answer,
|
||||
}
|
||||
ent.normalize()
|
||||
|
||||
err = rw.normalize()
|
||||
if err != nil {
|
||||
// Shouldn't happen currently, since normalize only returns a non-nil
|
||||
// error when a rewrite is nil, but be change-proof.
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "normalizing: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
d.confLock.Lock()
|
||||
d.Config.Rewrites = append(d.Config.Rewrites, ent)
|
||||
d.Config.Rewrites = append(d.Config.Rewrites, rw)
|
||||
d.confLock.Unlock()
|
||||
log.Debug("Rewrites: added element: %s -> %s [%d]",
|
||||
ent.Domain, ent.Answer, len(d.Config.Rewrites))
|
||||
log.Debug("rewrite: added element: %s -> %s [%d]", rw.Domain, rw.Answer, len(d.Config.Rewrites))
|
||||
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
@@ -237,21 +288,25 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
|
||||
jsent := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&jsent)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
entDel := RewriteEntry{
|
||||
entDel := &LegacyRewrite{
|
||||
Domain: jsent.Domain,
|
||||
Answer: jsent.Answer,
|
||||
}
|
||||
arr := []RewriteEntry{}
|
||||
arr := []*LegacyRewrite{}
|
||||
|
||||
d.confLock.Lock()
|
||||
for _, ent := range d.Config.Rewrites {
|
||||
if ent.equal(entDel) {
|
||||
log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer)
|
||||
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
arr = append(arr, ent)
|
||||
}
|
||||
d.Config.Rewrites = arr
|
||||
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
// TODO(e.burkov): All the tests in this file may and should me merged together.
|
||||
|
||||
func TestRewrites(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
d := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
// This one and below are about CNAME, A and AAAA.
|
||||
Domain: "somecname",
|
||||
Answer: "somehost.com",
|
||||
@@ -66,107 +66,132 @@ func TestRewrites(t *testing.T) {
|
||||
}, {
|
||||
Domain: "BIGHOST.COM",
|
||||
Answer: "1.2.3.7",
|
||||
}, {
|
||||
Domain: "*.issue4016.com",
|
||||
Answer: "sub.issue4016.com",
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
wantCName string
|
||||
wantVals []net.IP
|
||||
dtyp uint16
|
||||
name string
|
||||
host string
|
||||
wantCName string
|
||||
wantIPs []net.IP
|
||||
wantReason Reason
|
||||
dtyp uint16
|
||||
}{{
|
||||
name: "not_filtered_not_found",
|
||||
host: "hoost.com",
|
||||
wantCName: "",
|
||||
wantVals: nil,
|
||||
dtyp: dns.TypeA,
|
||||
name: "not_filtered_not_found",
|
||||
host: "hoost.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: NotFilteredNotFound,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "rewritten_a",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "rewritten_a",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "rewritten_aaaa",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantVals: []net.IP{net.ParseIP("1:2:3::4")},
|
||||
dtyp: dns.TypeAAAA,
|
||||
name: "rewritten_aaaa",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{net.ParseIP("1:2:3::4")},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "wildcard_match",
|
||||
host: "abc.host.com",
|
||||
wantCName: "",
|
||||
wantVals: []net.IP{{1, 2, 3, 5}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "wildcard_match",
|
||||
host: "abc.host.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_override",
|
||||
host: "a.host.com",
|
||||
wantCName: "",
|
||||
wantVals: []net.IP{{1, 2, 3, 4}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "wildcard_override",
|
||||
host: "a.host.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_cname_interaction",
|
||||
host: "www.host2.com",
|
||||
wantCName: "host.com",
|
||||
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "wildcard_cname_interaction",
|
||||
host: "www.host2.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "two_cnames",
|
||||
host: "b.host.com",
|
||||
wantCName: "somehost.com",
|
||||
wantVals: []net.IP{{0, 0, 0, 0}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "two_cnames",
|
||||
host: "b.host.com",
|
||||
wantCName: "somehost.com",
|
||||
wantIPs: []net.IP{{0, 0, 0, 0}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "two_cnames_and_wildcard",
|
||||
host: "b.host3.com",
|
||||
wantCName: "x.host.com",
|
||||
wantVals: []net.IP{{1, 2, 3, 5}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "two_cnames_and_wildcard",
|
||||
host: "b.host3.com",
|
||||
wantCName: "x.host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue3343",
|
||||
host: "www.hostboth.com",
|
||||
wantCName: "",
|
||||
wantVals: []net.IP{net.ParseIP("1234::5678")},
|
||||
dtyp: dns.TypeAAAA,
|
||||
name: "issue3343",
|
||||
host: "www.hostboth.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{net.ParseIP("1234::5678")},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "issue3351",
|
||||
host: "bighost.com",
|
||||
wantCName: "",
|
||||
wantVals: []net.IP{{1, 2, 3, 7}},
|
||||
dtyp: dns.TypeA,
|
||||
name: "issue3351",
|
||||
host: "bighost.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 7}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue4008",
|
||||
host: "somehost.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeHTTPS,
|
||||
}, {
|
||||
name: "issue4016",
|
||||
host: "www.issue4016.com",
|
||||
wantCName: "sub.issue4016.com",
|
||||
wantIPs: nil,
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue4016_self",
|
||||
host: "sub.issue4016.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: NotFilteredNotFound,
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
valsNum := len(tc.wantVals)
|
||||
|
||||
r := d.processRewrites(tc.host, tc.dtyp)
|
||||
if valsNum == 0 {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.Equalf(t, Rewritten, r.Reason, "got %s", r.Reason)
|
||||
require.Equalf(t, tc.wantReason, r.Reason, "got %s", r.Reason)
|
||||
|
||||
if tc.wantCName != "" {
|
||||
assert.Equal(t, tc.wantCName, r.CanonName)
|
||||
}
|
||||
|
||||
require.Len(t, r.IPList, valsNum)
|
||||
for i, ip := range tc.wantVals {
|
||||
assert.Equal(t, ip, r.IPList[i])
|
||||
}
|
||||
assert.Equal(t, tc.wantIPs, r.IPList)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewritesLevels(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
d := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exact host, wildcard L2, wildcard L3.
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.1.1.1",
|
||||
Type: dns.TypeA,
|
||||
@@ -179,7 +204,8 @@ func TestRewritesLevels(t *testing.T) {
|
||||
Answer: "3.3.3.3",
|
||||
Type: dns.TypeA,
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -209,10 +235,10 @@ func TestRewritesLevels(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
d := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Wildcard and exception for a sub-domain.
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "*.host.com",
|
||||
Answer: "2.2.2.2",
|
||||
}, {
|
||||
@@ -222,29 +248,32 @@ func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "*.sub.host.com",
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want net.IP
|
||||
}{{
|
||||
name: "match_sub-domain",
|
||||
name: "match_subdomain",
|
||||
host: "my.host.com",
|
||||
want: net.IP{2, 2, 2, 2},
|
||||
}, {
|
||||
name: "exception_cname",
|
||||
host: "sub.host.com",
|
||||
want: nil,
|
||||
}, {
|
||||
name: "exception_wildcard",
|
||||
host: "my.sub.host.com",
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, dns.TypeA)
|
||||
if tc.want == nil {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason, "got %s", r.Reason)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -257,10 +286,10 @@ func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRewritesExceptionIP(t *testing.T) {
|
||||
d := newForTest(nil, nil)
|
||||
d := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exception for AAAA record.
|
||||
d.Rewrites = []RewriteEntry{{
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.4",
|
||||
Type: dns.TypeA,
|
||||
@@ -281,7 +310,8 @@ func TestRewritesExceptionIP(t *testing.T) {
|
||||
Answer: "A",
|
||||
Type: dns.TypeA,
|
||||
}}
|
||||
d.prepareRewrites()
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -306,7 +307,7 @@ func (d *DNSFilter) checkSafeBrowsing(
|
||||
_ uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.SafeBrowsingEnabled {
|
||||
if !setts.ProtectionEnabled || !setts.SafeBrowsingEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
@@ -318,7 +319,7 @@ func (d *DNSFilter) checkSafeBrowsing(
|
||||
sctx := &sbCtx{
|
||||
host: host,
|
||||
svc: "SafeBrowsing",
|
||||
cache: gctx.safebrowsingCache,
|
||||
cache: d.safebrowsingCache,
|
||||
cacheTime: d.Config.CacheTime,
|
||||
}
|
||||
|
||||
@@ -326,7 +327,8 @@ func (d *DNSFilter) checkSafeBrowsing(
|
||||
IsFiltered: true,
|
||||
Reason: FilteredSafeBrowsing,
|
||||
Rules: []*ResultRule{{
|
||||
Text: "adguard-malware-shavar",
|
||||
Text: "adguard-malware-shavar",
|
||||
FilterListID: SafeBrowsingListID,
|
||||
}},
|
||||
}
|
||||
|
||||
@@ -339,7 +341,7 @@ func (d *DNSFilter) checkParental(
|
||||
_ uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.ParentalEnabled {
|
||||
if !setts.ProtectionEnabled || !setts.ParentalEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
@@ -351,7 +353,7 @@ func (d *DNSFilter) checkParental(
|
||||
sctx := &sbCtx{
|
||||
host: host,
|
||||
svc: "Parental",
|
||||
cache: gctx.parentalCache,
|
||||
cache: d.parentalCache,
|
||||
cacheTime: d.Config.CacheTime,
|
||||
}
|
||||
|
||||
@@ -359,19 +361,14 @@ func (d *DNSFilter) checkParental(
|
||||
IsFiltered: true,
|
||||
Reason: FilteredParental,
|
||||
Rules: []*ResultRule{{
|
||||
Text: "parental CATEGORY_BLACKLISTED",
|
||||
Text: "parental CATEGORY_BLACKLISTED",
|
||||
FilterListID: ParentalListID,
|
||||
}},
|
||||
}
|
||||
|
||||
return check(sctx, res, d.parentalUpstream)
|
||||
}
|
||||
|
||||
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
|
||||
d.Config.SafeBrowsingEnabled = true
|
||||
d.Config.ConfigModified()
|
||||
@@ -390,7 +387,8 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ
|
||||
Enabled: d.Config.SafeBrowsingEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -413,8 +411,7 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
|
||||
Enabled: d.Config.ParentalEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
return
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ func TestSafeBrowsingCache(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
ups := &aghtest.TestErrUpstream{}
|
||||
@@ -117,6 +117,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
d.SetParentalUpstream(ups)
|
||||
|
||||
setts := &Settings{
|
||||
ProtectionEnabled: true,
|
||||
SafeBrowsingEnabled: true,
|
||||
ParentalEnabled: true,
|
||||
}
|
||||
@@ -129,41 +130,42 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSBPC(t *testing.T) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
const hostname = "example.org"
|
||||
|
||||
setts := &Settings{
|
||||
ProtectionEnabled: true,
|
||||
SafeBrowsingEnabled: true,
|
||||
ParentalEnabled: true,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
testCache cache.Cache
|
||||
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
|
||||
name string
|
||||
block bool
|
||||
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
|
||||
testCache cache.Cache
|
||||
}{{
|
||||
testCache: d.safebrowsingCache,
|
||||
testFunc: d.checkSafeBrowsing,
|
||||
name: "sb_no_block",
|
||||
block: false,
|
||||
testFunc: d.checkSafeBrowsing,
|
||||
testCache: gctx.safebrowsingCache,
|
||||
}, {
|
||||
testCache: d.safebrowsingCache,
|
||||
testFunc: d.checkSafeBrowsing,
|
||||
name: "sb_block",
|
||||
block: true,
|
||||
testFunc: d.checkSafeBrowsing,
|
||||
testCache: gctx.safebrowsingCache,
|
||||
}, {
|
||||
testCache: d.parentalCache,
|
||||
testFunc: d.checkParental,
|
||||
name: "pc_no_block",
|
||||
block: false,
|
||||
testFunc: d.checkParental,
|
||||
testCache: gctx.parentalCache,
|
||||
}, {
|
||||
testCache: d.parentalCache,
|
||||
testFunc: d.checkParental,
|
||||
name: "pc_block",
|
||||
block: true,
|
||||
testFunc: d.checkParental,
|
||||
testCache: gctx.parentalCache,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -215,6 +217,6 @@ func TestSBPC(t *testing.T) {
|
||||
assert.Equal(t, 1, ups.RequestsCount())
|
||||
})
|
||||
|
||||
purgeCaches()
|
||||
purgeCaches(d)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
@@ -74,7 +75,7 @@ func (d *DNSFilter) checkSafeSearch(
|
||||
_ uint16,
|
||||
setts *Settings,
|
||||
) (res Result, err error) {
|
||||
if !setts.SafeSearchEnabled {
|
||||
if !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
@@ -84,7 +85,7 @@ func (d *DNSFilter) checkSafeSearch(
|
||||
}
|
||||
|
||||
// Check cache. Return cached result if it was found
|
||||
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host)
|
||||
cachedValue, isFound := getCachedResult(d.safeSearchCache, host)
|
||||
if isFound {
|
||||
// atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
|
||||
log.Tracef("SafeSearch: found in cache: %s", host)
|
||||
@@ -99,12 +100,14 @@ func (d *DNSFilter) checkSafeSearch(
|
||||
res = Result{
|
||||
IsFiltered: true,
|
||||
Reason: FilteredSafeSearch,
|
||||
Rules: []*ResultRule{{}},
|
||||
Rules: []*ResultRule{{
|
||||
FilterListID: SafeSearchListID,
|
||||
}},
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(safeHost); ip != nil {
|
||||
res.Rules[0].IP = ip
|
||||
valLen := d.setCacheResult(gctx.safeSearchCache, host, res)
|
||||
valLen := d.setCacheResult(d.safeSearchCache, host, res)
|
||||
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen)
|
||||
|
||||
return res, nil
|
||||
@@ -123,7 +126,7 @@ func (d *DNSFilter) checkSafeSearch(
|
||||
|
||||
res.Rules[0].IP = ip
|
||||
|
||||
l := d.setCacheResult(gctx.safeSearchCache, host, res)
|
||||
l := d.setCacheResult(d.safeSearchCache, host, res)
|
||||
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, l)
|
||||
|
||||
return res, nil
|
||||
@@ -150,8 +153,13 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques
|
||||
Enabled: d.Config.SafeSearchEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
return
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Unable to write response json: %s",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
@@ -403,8 +404,8 @@ func realIP(r *http.Request) (ip net.IP, err error) {
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
// When everything else fails, just return the remote address as
|
||||
// understood by the stdlib.
|
||||
// When everything else fails, just return the remote address as understood
|
||||
// by the stdlib.
|
||||
ipStr, err := netutil.SplitHost(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting ip from client addr: %w", err)
|
||||
@@ -417,19 +418,20 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
req := loginJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var remoteAddr string
|
||||
// The realIP couldn't be used here due to security issues.
|
||||
// realIP cannot be used here without taking TrustedProxies into account due
|
||||
// to security issues.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
|
||||
//
|
||||
// TODO(e.burkov): Use realIP when the issue will be fixed.
|
||||
if remoteAddr, err = netutil.SplitHost(r.RemoteAddr); err != nil {
|
||||
httpError(w, http.StatusBadRequest, "auth: getting remote address: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "auth: getting remote address: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -437,12 +439,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
if blocker := Context.auth.blocker; blocker != nil {
|
||||
if left := blocker.check(remoteAddr); left > 0 {
|
||||
w.Header().Set("Retry-After", strconv.Itoa(int(left.Seconds())))
|
||||
httpError(
|
||||
w,
|
||||
http.StatusTooManyRequests,
|
||||
"auth: blocked for %s",
|
||||
left,
|
||||
)
|
||||
aghhttp.Error(r, w, http.StatusTooManyRequests, "auth: blocked for %s", left)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -451,22 +448,23 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var cookie string
|
||||
cookie, err = Context.auth.httpCookie(req, remoteAddr)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "crypto rand reader: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "crypto rand reader: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Use realIP here, since this IP address is only used for logging.
|
||||
ip, err := realIP(r)
|
||||
if err != nil {
|
||||
log.Error("auth: getting real ip from request: %s", err)
|
||||
} else if ip == nil {
|
||||
// Technically shouldn't happen.
|
||||
log.Error("auth: unknown ip")
|
||||
}
|
||||
|
||||
if len(cookie) == 0 {
|
||||
var ip net.IP
|
||||
ip, err = realIP(r)
|
||||
if err != nil {
|
||||
log.Info("auth: getting real ip from request: %s", err)
|
||||
} else if ip == nil {
|
||||
// Technically shouldn't happen.
|
||||
log.Info("auth: failed to login user %q from unknown ip", req.Name)
|
||||
} else {
|
||||
log.Info("auth: failed to login user %q from ip %q", req.Name, ip)
|
||||
}
|
||||
log.Info("auth: failed to login user %q from ip %v", req.Name, ip)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
http.Error(w, "invalid username or password", http.StatusBadRequest)
|
||||
@@ -474,13 +472,15 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Set-Cookie", cookie)
|
||||
log.Info("auth: user %q successfully logged in from ip %v", req.Name, ip)
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, proxy-revalidate")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
w.Header().Set("Expires", "0")
|
||||
h := w.Header()
|
||||
h.Set("Set-Cookie", cookie)
|
||||
h.Set("Cache-Control", "no-store, no-cache, must-revalidate, proxy-revalidate")
|
||||
h.Set("Pragma", "no-cache")
|
||||
h.Set("Expires", "0")
|
||||
|
||||
returnOK(w)
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
|
||||
func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -528,7 +528,7 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool)
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
|
||||
if glProcessCookie(r) {
|
||||
log.Debug("auth: authentification was handled by GL-Inet submodule")
|
||||
log.Debug("auth: authentication was handled by GL-Inet submodule")
|
||||
ok = true
|
||||
} else if err == nil {
|
||||
r := Context.auth.checkSession(cookie.Value)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -24,19 +25,17 @@ func TestMain(m *testing.M) {
|
||||
func TestNewSessionToken(t *testing.T) {
|
||||
// Successful case.
|
||||
token, err := newSessionToken()
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, token, sessionTokenSize)
|
||||
|
||||
// Break the rand.Reader.
|
||||
prevReader := rand.Reader
|
||||
t.Cleanup(func() {
|
||||
rand.Reader = prevReader
|
||||
})
|
||||
t.Cleanup(func() { rand.Reader = prevReader })
|
||||
rand.Reader = &bytes.Buffer{}
|
||||
|
||||
// Unsuccessful case.
|
||||
token, err = newSessionToken()
|
||||
require.NotNil(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Empty(t, token)
|
||||
}
|
||||
|
||||
@@ -58,7 +57,7 @@ func TestAuth(t *testing.T) {
|
||||
a.RemoveSession("notfound")
|
||||
|
||||
sess, err := newSessionToken()
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
sessStr := hex.EncodeToString(sess)
|
||||
|
||||
now := time.Now().UTC().Unix()
|
||||
@@ -152,7 +151,7 @@ func TestAuthHTTP(t *testing.T) {
|
||||
|
||||
// perform login
|
||||
cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}, "")
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, cookie)
|
||||
|
||||
// get /
|
||||
@@ -251,12 +250,7 @@ func TestRealIP(t *testing.T) {
|
||||
ip, err := realIP(r)
|
||||
assert.Equal(t, tc.wantIP, ip)
|
||||
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,21 +15,19 @@ func TestAuthGL(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
GLMode = true
|
||||
t.Cleanup(func() {
|
||||
GLMode = false
|
||||
})
|
||||
t.Cleanup(func() { GLMode = false })
|
||||
glFilePrefix = dir + "/gl_token_"
|
||||
|
||||
data := make([]byte, 4)
|
||||
aghos.NativeEndian.PutUint32(data, 1)
|
||||
|
||||
require.Nil(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
|
||||
require.NoError(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
|
||||
assert.False(t, glCheckToken("test"))
|
||||
|
||||
data = make([]byte, 4)
|
||||
aghos.NativeEndian.PutUint32(data, uint32(time.Now().UTC().Unix()+60))
|
||||
|
||||
require.Nil(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
|
||||
require.NoError(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
|
||||
r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil)
|
||||
r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"})
|
||||
assert.True(t, glProcessCookie(r))
|
||||
|
||||
@@ -95,7 +95,9 @@ type clientsContainer struct {
|
||||
// dnsServer is used for checking clients IP status access list status
|
||||
dnsServer *dnsforward.Server
|
||||
|
||||
etcHosts *aghnet.EtcHostsContainer // get entries from system hosts-files
|
||||
// etcHosts contains list of rewrite rules taken from the operating system's
|
||||
// hosts database.
|
||||
etcHosts *aghnet.HostsContainer
|
||||
|
||||
testing bool // if TRUE, this object is used for internal tests
|
||||
}
|
||||
@@ -104,9 +106,9 @@ type clientsContainer struct {
|
||||
// dhcpServer: optional
|
||||
// Note: this function must be called only once
|
||||
func (clients *clientsContainer) Init(
|
||||
objects []clientObject,
|
||||
objects []*clientObject,
|
||||
dhcpServer *dhcpd.Server,
|
||||
etcHosts *aghnet.EtcHostsContainer,
|
||||
etcHosts *aghnet.HostsContainer,
|
||||
) {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
@@ -121,13 +123,22 @@ func (clients *clientsContainer) Init(
|
||||
clients.etcHosts = etcHosts
|
||||
clients.addFromConfig(objects)
|
||||
|
||||
if !clients.testing {
|
||||
clients.updateFromDHCP(true)
|
||||
if clients.dhcpServer != nil {
|
||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||
}
|
||||
if clients.etcHosts != nil {
|
||||
clients.etcHosts.SetOnChanged(clients.onHostsChanged)
|
||||
if clients.testing {
|
||||
return
|
||||
}
|
||||
|
||||
clients.updateFromDHCP(true)
|
||||
if clients.dhcpServer != nil {
|
||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||
}
|
||||
|
||||
go clients.handleHostsUpdates()
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) handleHostsUpdates() {
|
||||
if clients.etcHosts != nil {
|
||||
for upd := range clients.etcHosts.Upd() {
|
||||
clients.addFromHostsFile(upd)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -164,56 +175,65 @@ type clientObject struct {
|
||||
UseGlobalBlockedServices bool `yaml:"use_global_blocked_services"`
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) tagKnown(tag string) (ok bool) {
|
||||
return clients.allTags.Has(tag)
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) addFromConfig(objects []clientObject) {
|
||||
for _, cy := range objects {
|
||||
// addFromConfig initializes the clients container with objects from the
|
||||
// configuration file.
|
||||
func (clients *clientsContainer) addFromConfig(objects []*clientObject) {
|
||||
for _, o := range objects {
|
||||
cli := &Client{
|
||||
Name: cy.Name,
|
||||
IDs: cy.IDs,
|
||||
UseOwnSettings: !cy.UseGlobalSettings,
|
||||
FilteringEnabled: cy.FilteringEnabled,
|
||||
ParentalEnabled: cy.ParentalEnabled,
|
||||
SafeSearchEnabled: cy.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
|
||||
Name: o.Name,
|
||||
|
||||
UseOwnBlockedServices: !cy.UseGlobalBlockedServices,
|
||||
IDs: o.IDs,
|
||||
Upstreams: o.Upstreams,
|
||||
|
||||
Upstreams: cy.Upstreams,
|
||||
UseOwnSettings: !o.UseGlobalSettings,
|
||||
FilteringEnabled: o.FilteringEnabled,
|
||||
ParentalEnabled: o.ParentalEnabled,
|
||||
SafeSearchEnabled: o.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: o.SafeBrowsingEnabled,
|
||||
UseOwnBlockedServices: !o.UseGlobalBlockedServices,
|
||||
}
|
||||
|
||||
for _, s := range cy.BlockedServices {
|
||||
if !filtering.BlockedSvcKnown(s) {
|
||||
log.Debug("clients: skipping unknown blocked-service %q", s)
|
||||
continue
|
||||
for _, s := range o.BlockedServices {
|
||||
if filtering.BlockedSvcKnown(s) {
|
||||
cli.BlockedServices = append(cli.BlockedServices, s)
|
||||
} else {
|
||||
log.Info("clients: skipping unknown blocked service %q", s)
|
||||
}
|
||||
cli.BlockedServices = append(cli.BlockedServices, s)
|
||||
}
|
||||
|
||||
for _, t := range cy.Tags {
|
||||
if !clients.tagKnown(t) {
|
||||
log.Debug("clients: skipping unknown tag %q", t)
|
||||
continue
|
||||
for _, t := range o.Tags {
|
||||
if clients.allTags.Has(t) {
|
||||
cli.Tags = append(cli.Tags, t)
|
||||
} else {
|
||||
log.Info("clients: skipping unknown tag %q", t)
|
||||
}
|
||||
cli.Tags = append(cli.Tags, t)
|
||||
}
|
||||
|
||||
sort.Strings(cli.Tags)
|
||||
|
||||
_, err := clients.Add(cli)
|
||||
if err != nil {
|
||||
log.Tracef("clientAdd: %s", err)
|
||||
log.Error("clients: adding clients %s: %s", cli.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write configuration
|
||||
func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
|
||||
// forConfig returns all currently known persistent clients as objects for the
|
||||
// configuration file.
|
||||
func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
objs = make([]*clientObject, 0, len(clients.list))
|
||||
for _, cli := range clients.list {
|
||||
cy := clientObject{
|
||||
Name: cli.Name,
|
||||
o := &clientObject{
|
||||
Name: cli.Name,
|
||||
|
||||
Tags: stringutil.CloneSlice(cli.Tags),
|
||||
IDs: stringutil.CloneSlice(cli.IDs),
|
||||
BlockedServices: stringutil.CloneSlice(cli.BlockedServices),
|
||||
Upstreams: stringutil.CloneSlice(cli.Upstreams),
|
||||
|
||||
UseGlobalSettings: !cli.UseOwnSettings,
|
||||
FilteringEnabled: cli.FilteringEnabled,
|
||||
ParentalEnabled: cli.ParentalEnabled,
|
||||
@@ -222,14 +242,16 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
|
||||
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
|
||||
}
|
||||
|
||||
cy.Tags = stringutil.CloneSlice(cli.Tags)
|
||||
cy.IDs = stringutil.CloneSlice(cli.IDs)
|
||||
cy.BlockedServices = stringutil.CloneSlice(cli.BlockedServices)
|
||||
cy.Upstreams = stringutil.CloneSlice(cli.Upstreams)
|
||||
|
||||
*objects = append(*objects, cy)
|
||||
objs = append(objs, o)
|
||||
}
|
||||
clients.lock.Unlock()
|
||||
|
||||
// Maps aren't guaranteed to iterate in the same order each time, so the
|
||||
// above loop can generate different orderings when writing to the config
|
||||
// file: this produces lots of diffs in config files, so sort objects by
|
||||
// name before writing.
|
||||
sort.Slice(objs, func(i, j int) bool { return objs[i].Name < objs[j].Name })
|
||||
|
||||
return objs
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) periodicUpdate() {
|
||||
@@ -250,10 +272,6 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) onHostsChanged() {
|
||||
clients.addFromHostsFile()
|
||||
}
|
||||
|
||||
// Exists checks if client with this IP address already exists.
|
||||
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
@@ -514,12 +532,12 @@ func (clients *clientsContainer) check(c *Client) (err error) {
|
||||
} else if err = dnsforward.ValidateClientID(id); err == nil {
|
||||
c.IDs[i] = id
|
||||
} else {
|
||||
return fmt.Errorf("invalid client id at index %d: %q", i, id)
|
||||
return fmt.Errorf("invalid clientid at index %d: %q", i, id)
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range c.Tags {
|
||||
if !clients.tagKnown(t) {
|
||||
if !clients.allTags.Has(t) {
|
||||
return fmt.Errorf("invalid tag: %q", t)
|
||||
}
|
||||
}
|
||||
@@ -697,7 +715,7 @@ func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
}
|
||||
|
||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
|
||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||
// taken into account. ok is true if the pairing was added.
|
||||
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
|
||||
clients.lock.Lock()
|
||||
@@ -757,13 +775,7 @@ func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||
|
||||
// addFromHostsFile fills the client-hostname pairing index from the system's
|
||||
// hosts files.
|
||||
func (clients *clientsContainer) addFromHostsFile() {
|
||||
if clients.etcHosts == nil {
|
||||
return
|
||||
}
|
||||
|
||||
hosts := clients.etcHosts.List()
|
||||
|
||||
func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
@@ -771,17 +783,20 @@ func (clients *clientsContainer) addFromHostsFile() {
|
||||
|
||||
n := 0
|
||||
hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
names, ok := v.([]string)
|
||||
hosts, ok := v.(*stringutil.Set)
|
||||
if !ok {
|
||||
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
ok = clients.addHostLocked(ip, name, ClientSourceHostsFile)
|
||||
if ok {
|
||||
hosts.Range(func(name string) (cont bool) {
|
||||
if clients.addHostLocked(ip, name, ClientSourceHostsFile) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -3,10 +3,12 @@ package home
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -271,12 +273,18 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("complicated", func(t *testing.T) {
|
||||
// TODO(a.garipov): Properly decouple the DHCP server from the client
|
||||
// storage.
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping dhcp test on windows")
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
ip := net.IP{1, 2, 3, 4}
|
||||
|
||||
// First, init a DHCP server with a single static lease.
|
||||
config := dhcpd.ServerConfig{
|
||||
config := &dhcpd.ServerConfig{
|
||||
Enabled: true,
|
||||
DBFilePath: "leases.db",
|
||||
Conf4: dhcpd.V4ServerConf{
|
||||
@@ -290,8 +298,9 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
|
||||
clients.dhcpServer, err = dhcpd.Create(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() { _ = os.Remove("leases.db") })
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return os.Remove("leases.db")
|
||||
})
|
||||
|
||||
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
@@ -309,8 +318,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Add a new client with the IP from the first client's IP
|
||||
// range.
|
||||
// Add a new client with the IP from the first client's IP range.
|
||||
ok, err = clients.Add(&Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client3",
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@@ -58,7 +59,7 @@ type clientListJSON struct {
|
||||
}
|
||||
|
||||
// respond with information about configured clients
|
||||
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http.Request) {
|
||||
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http.Request) {
|
||||
data := clientListJSON{}
|
||||
|
||||
clients.lock.Lock()
|
||||
@@ -106,7 +107,14 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
e := json.NewEncoder(w).Encode(data)
|
||||
if e != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Failed to encode to json: %v", e)
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Failed to encode to json: %v",
|
||||
e,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -154,7 +162,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
||||
cj := clientJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&cj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -162,11 +170,14 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
||||
c := jsonToClient(cj)
|
||||
ok, err := clients.Add(c)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
httpError(w, http.StatusBadRequest, "Client already exists")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -178,19 +189,19 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
|
||||
cj := clientJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&cj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(cj.Name) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "client's name must be non-empty")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "client's name must be non-empty")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !clients.Del(cj.Name) {
|
||||
httpError(w, http.StatusBadRequest, "Client not found")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -207,20 +218,22 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
dj := updateJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&dj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(dj.Name) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "Invalid request")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Invalid request")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c := jsonToClient(dj.Data)
|
||||
err = clients.Update(dj.Name, c)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -256,7 +269,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write response: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write response: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -82,10 +83,12 @@ type configuration struct {
|
||||
WhitelistFilters []filter `yaml:"whitelist_filters"`
|
||||
UserRules []string `yaml:"user_rules"`
|
||||
|
||||
DHCP dhcpd.ServerConfig `yaml:"dhcp"`
|
||||
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
|
||||
|
||||
// Note: this array is filled only before file read/write and then it's cleared
|
||||
Clients []clientObject `yaml:"clients"`
|
||||
// Clients contains the YAML representations of the persistent clients.
|
||||
// This field is only used for reading and writing persistent client data.
|
||||
// Keep this field sorted to ensure consistent ordering.
|
||||
Clients []*clientObject `yaml:"clients"`
|
||||
|
||||
logSettings `yaml:",inline"`
|
||||
|
||||
@@ -120,11 +123,6 @@ type dnsConfig struct {
|
||||
// UpstreamTimeout is the timeout for querying upstream servers.
|
||||
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
|
||||
|
||||
// LocalDomainName is the domain name used for known internal hosts.
|
||||
// For example, a machine called "myhost" can be addressed as
|
||||
// "myhost.lan" when LocalDomainName is "lan".
|
||||
LocalDomainName string `yaml:"local_domain_name"`
|
||||
|
||||
// ResolveClients enables and disables resolving clients with RDNS.
|
||||
ResolveClients bool `yaml:"resolve_clients"`
|
||||
|
||||
@@ -140,7 +138,7 @@ type dnsConfig struct {
|
||||
type tlsConfigSettings struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status
|
||||
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
|
||||
ForceHTTPS bool `yaml:"force_https" json:"force_https,omitempty"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
|
||||
ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
|
||||
PortHTTPS int `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
|
||||
PortDNSOverTLS int `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DoT will be disabled
|
||||
PortDNSOverQUIC int `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"` // DNS-over-QUIC port. If 0, DoQ will be disabled
|
||||
@@ -163,7 +161,7 @@ type tlsConfigSettings struct {
|
||||
|
||||
// config is the global configuration structure.
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): This global is afwul and must be removed.
|
||||
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
|
||||
var config = &configuration{
|
||||
BindPort: 3000,
|
||||
BetaBindPort: 0,
|
||||
@@ -196,7 +194,6 @@ var config = &configuration{
|
||||
FilteringEnabled: true, // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours: 24,
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||
LocalDomainName: "lan",
|
||||
ResolveClients: true,
|
||||
UsePrivateRDNS: true,
|
||||
},
|
||||
@@ -205,6 +202,9 @@ var config = &configuration{
|
||||
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
|
||||
PortDNSOverQUIC: defaultPortQUIC,
|
||||
},
|
||||
DHCP: &dhcpd.ServerConfig{
|
||||
LocalDomainName: "lan",
|
||||
},
|
||||
logSettings: logSettings{
|
||||
LogCompress: false,
|
||||
LogLocalTime: false,
|
||||
@@ -232,9 +232,9 @@ func initConfig() {
|
||||
config.DNS.DnsfilterConf.CacheTime = 30
|
||||
config.Filters = defaultFilters()
|
||||
|
||||
config.DHCP.Conf4.LeaseDuration = 86400
|
||||
config.DHCP.Conf4.ICMPTimeout = 1000
|
||||
config.DHCP.Conf6.LeaseDuration = 86400
|
||||
config.DHCP.Conf4.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
|
||||
config.DHCP.Conf4.ICMPTimeout = dhcpd.DefaultDHCPTimeoutICMP
|
||||
config.DHCP.Conf6.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
|
||||
|
||||
if ch := version.Channel(); ch == version.ChannelEdge || ch == version.ChannelDevelopment {
|
||||
config.BetaBindPort = 3001
|
||||
@@ -272,20 +272,40 @@ func getLogSettings() logSettings {
|
||||
}
|
||||
|
||||
// parseConfig loads configuration from the YAML file
|
||||
func parseConfig() error {
|
||||
configFile := config.getConfigFilename()
|
||||
log.Debug("Reading config file: %s", configFile)
|
||||
yamlFile, err := readConfigFile()
|
||||
func parseConfig() (err error) {
|
||||
var fileData []byte
|
||||
fileData, err = readConfigFile()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.fileData = nil
|
||||
err = yaml.Unmarshal(yamlFile, &config)
|
||||
err = yaml.Unmarshal(fileData, &config)
|
||||
if err != nil {
|
||||
log.Error("Couldn't parse config file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
config.BindPort,
|
||||
config.BetaBindPort,
|
||||
config.DNS.Port,
|
||||
)
|
||||
|
||||
if config.TLS.Enabled {
|
||||
addPorts(
|
||||
uc,
|
||||
config.TLS.PortHTTPS,
|
||||
config.TLS.PortDNSOverTLS,
|
||||
config.TLS.PortDNSOverQUIC,
|
||||
config.TLS.PortDNSCrypt,
|
||||
)
|
||||
}
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return fmt.Errorf("validating ports: %w", err)
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||
config.DNS.FiltersUpdateIntervalHours = 24
|
||||
}
|
||||
@@ -297,18 +317,26 @@ func parseConfig() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// readConfigFile reads config file contents if it exists
|
||||
func readConfigFile() ([]byte, error) {
|
||||
if len(config.fileData) != 0 {
|
||||
// addPorts is a helper for ports validation. It skips zero ports.
|
||||
func addPorts(uc aghalg.UniqChecker, ports ...int) {
|
||||
for _, p := range ports {
|
||||
if p != 0 {
|
||||
uc.Add(p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readConfigFile reads configuration file contents.
|
||||
func readConfigFile() (fileData []byte, err error) {
|
||||
if len(config.fileData) > 0 {
|
||||
return config.fileData, nil
|
||||
}
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
d, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't read config file %s: %w", configFile, err)
|
||||
}
|
||||
return d, nil
|
||||
name := config.getConfigFilename()
|
||||
log.Debug("reading config file: %s", name)
|
||||
|
||||
// Do not wrap the error because it's informative enough as is.
|
||||
return os.ReadFile(name)
|
||||
}
|
||||
|
||||
// Saves configuration to the YAML file and also saves the user filter contents to a file
|
||||
@@ -316,8 +344,6 @@ func (c *configuration) write() error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
Context.clients.WriteDiskConfig(&config.Clients)
|
||||
|
||||
if Context.auth != nil {
|
||||
config.Users = Context.auth.GetUsers()
|
||||
}
|
||||
@@ -360,15 +386,16 @@ func (c *configuration) write() error {
|
||||
}
|
||||
|
||||
if Context.dhcpServer != nil {
|
||||
c := dhcpd.ServerConfig{}
|
||||
Context.dhcpServer.WriteDiskConfig(&c)
|
||||
c := &dhcpd.ServerConfig{}
|
||||
Context.dhcpServer.WriteDiskConfig(c)
|
||||
config.DHCP = c
|
||||
}
|
||||
|
||||
config.Clients = Context.clients.forConfig()
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
log.Debug("Writing YAML file: %s", configFile)
|
||||
yamlText, err := yaml.Marshal(&config)
|
||||
config.Clients = nil
|
||||
if err != nil {
|
||||
log.Error("Couldn't generate YAML file: %s", err)
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
@@ -17,23 +18,6 @@ import (
|
||||
"github.com/NYTimes/gziphandler"
|
||||
)
|
||||
|
||||
// ----------------
|
||||
// helper functions
|
||||
// ----------------
|
||||
|
||||
func returnOK(w http.ResponseWriter) {
|
||||
_, err := fmt.Fprintf(w, "OK\n")
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Info(text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
// appendDNSAddrs is a convenient helper for appending a formatted form of DNS
|
||||
// addresses to a slice of strings.
|
||||
func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) {
|
||||
@@ -53,7 +37,7 @@ func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) {
|
||||
|
||||
// appendDNSAddrsWithIfaces formats and appends all DNS addresses from src to
|
||||
// dst. It also adds the IP addresses of all network interfaces if src contains
|
||||
// an unspecified IP addresss.
|
||||
// an unspecified IP address.
|
||||
func appendDNSAddrsWithIfaces(dst []string, src []net.IP) (res []string, err error) {
|
||||
ifacesAdded := false
|
||||
for _, h := range src {
|
||||
@@ -125,12 +109,12 @@ type statusResponse struct {
|
||||
Language string `json:"language"`
|
||||
}
|
||||
|
||||
func handleStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
dnsAddrs, err := collectDNSAddresses()
|
||||
if err != nil {
|
||||
// Don't add a lot of formatting, since the error is already
|
||||
// wrapped by collectDNSAddresses.
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -165,7 +149,7 @@ func handleStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -182,7 +166,7 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
data, err := json.Marshal(pj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Marshal: %s", err)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(data)
|
||||
@@ -295,7 +279,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
|
||||
host, err := netutil.SplitHost(r.Host)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "bad host: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "bad host: %s", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -49,7 +50,8 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
fj := filterAddJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -63,7 +65,8 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
|
||||
// Check for duplicates
|
||||
if filterExists(fj.URL) {
|
||||
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -79,17 +82,35 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
// Download the filter contents
|
||||
ok, err := f.update(&filt)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL)
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"Couldn't fetch filter from url %s: %s",
|
||||
filt.URL,
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
|
||||
if !ok {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"Filter at the url %s is invalid (maybe it points to blank page?)",
|
||||
filt.URL,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// URL is assumed valid so append it to filters, update config, write new
|
||||
// file and reload it to engines.
|
||||
if !filterAdd(filt) {
|
||||
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -98,7 +119,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
|
||||
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,7 +132,8 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
||||
req := request{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to parse request body json: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse request body json: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -152,7 +174,7 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
_, err = fmt.Fprintf(w, "OK %d rules\n", deleted.RulesCount)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "couldn't write body: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,7 +194,8 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
||||
fj := filterURLReq{}
|
||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -228,7 +251,8 @@ func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Reque
|
||||
// This use of ReadAll is safe, because request's body is now limited.
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -250,7 +274,8 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||
req := Req{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -270,13 +295,15 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||
resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false)
|
||||
}()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
js, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -335,13 +362,14 @@ func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request
|
||||
|
||||
jsonVal, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "http write: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -350,12 +378,14 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request
|
||||
req := filteringConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(req.Interval) {
|
||||
httpError(w, http.StatusBadRequest, "Unsupported interval")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -404,10 +434,19 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
setts := Context.dnsFilter.GetConfig()
|
||||
setts.FilteringEnabled = true
|
||||
setts.ProtectionEnabled = true
|
||||
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
|
||||
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"couldn't apply filtering: %s: %s",
|
||||
host,
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -432,7 +471,8 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
js, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@@ -13,28 +13,44 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// getAddrsResponse is the response for /install/get_addresses endpoint.
|
||||
type getAddrsResponse struct {
|
||||
WebPort int `json:"web_port"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
Interfaces map[string]*aghnet.NetInterface `json:"interfaces"`
|
||||
|
||||
// Version is the version of AdGuard Home.
|
||||
//
|
||||
// TODO(a.garipov): In the new API, rename this endpoint to something more
|
||||
// general, since there will be more information here than just network
|
||||
// interfaces.
|
||||
Version string `json:"version"`
|
||||
|
||||
WebPort int `json:"web_port"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
}
|
||||
|
||||
// handleInstallGetAddresses is the handler for /install/get_addresses endpoint.
|
||||
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
||||
data := getAddrsResponse{}
|
||||
data.WebPort = defaultPortHTTP
|
||||
data.DNSPort = defaultPortDNS
|
||||
data := getAddrsResponse{
|
||||
Version: version.Version(),
|
||||
|
||||
WebPort: defaultPortHTTP,
|
||||
DNSPort: defaultPortDNS,
|
||||
}
|
||||
|
||||
ifaces, err := aghnet.GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -46,24 +62,31 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err)
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Unable to marshal default addresses to json: %s",
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type checkConfigReqEnt struct {
|
||||
Port int `json:"port"`
|
||||
type checkConfReqEnt struct {
|
||||
IP net.IP `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Autofix bool `json:"autofix"`
|
||||
}
|
||||
|
||||
type checkConfigReq struct {
|
||||
Web checkConfigReqEnt `json:"web"`
|
||||
DNS checkConfigReqEnt `json:"dns"`
|
||||
SetStaticIP bool `json:"set_static_ip"`
|
||||
type checkConfReq struct {
|
||||
Web checkConfReqEnt `json:"web"`
|
||||
DNS checkConfReqEnt `json:"dns"`
|
||||
SetStaticIP bool `json:"set_static_ip"`
|
||||
}
|
||||
|
||||
type checkConfigRespEnt struct {
|
||||
type checkConfRespEnt struct {
|
||||
Status string `json:"status"`
|
||||
CanAutofix bool `json:"can_autofix"`
|
||||
}
|
||||
@@ -74,63 +97,111 @@ type staticIPJSON struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type checkConfigResp struct {
|
||||
Web checkConfigRespEnt `json:"web"`
|
||||
DNS checkConfigRespEnt `json:"dns"`
|
||||
StaticIP staticIPJSON `json:"static_ip"`
|
||||
type checkConfResp struct {
|
||||
StaticIP staticIPJSON `json:"static_ip"`
|
||||
Web checkConfRespEnt `json:"web"`
|
||||
DNS checkConfRespEnt `json:"dns"`
|
||||
}
|
||||
|
||||
// Check if ports are available, respond with results
|
||||
// validateWeb returns error is the web part if the initial configuration can't
|
||||
// be set.
|
||||
func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := req.Web.Port
|
||||
addPorts(uc, config.BetaBindPort, port)
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
// Avoid duplicating the error into the status of DNS.
|
||||
uc[port] = 1
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
switch port {
|
||||
case 0, config.BindPort:
|
||||
return nil
|
||||
default:
|
||||
// Go on and check the port binding only if it's not zero or won't be
|
||||
// unbound after install.
|
||||
}
|
||||
|
||||
return aghnet.CheckPort("tcp", req.Web.IP, port)
|
||||
}
|
||||
|
||||
// validateDNS returns error if the DNS part of the initial configuration can't
|
||||
// be set. canAutofix is true if the port can be unbound by AdGuard Home
|
||||
// automatically.
|
||||
func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := req.DNS.Port
|
||||
addPorts(uc, port)
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch port {
|
||||
case 0:
|
||||
return false, nil
|
||||
case config.BindPort:
|
||||
// Go on and only check the UDP port since the TCP one is already bound
|
||||
// by AdGuard Home for web interface.
|
||||
default:
|
||||
// Check TCP as well.
|
||||
err = aghnet.CheckPort("tcp", req.DNS.IP, port)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", req.DNS.IP, port)
|
||||
if !aghnet.IsAddrInUse(err) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Try to fix automatically.
|
||||
canAutofix = checkDNSStubListener()
|
||||
if canAutofix && req.DNS.Autofix {
|
||||
if derr := disableDNSStubListener(); derr != nil {
|
||||
log.Error("disabling DNSStubListener: %s", err)
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", req.DNS.IP, port)
|
||||
canAutofix = false
|
||||
}
|
||||
|
||||
return canAutofix, err
|
||||
}
|
||||
|
||||
// handleInstallCheckConfig handles the /check_config endpoint.
|
||||
func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
||||
reqData := checkConfigReq{}
|
||||
respData := checkConfigResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
req := &checkConfReq{}
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "decoding the request: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort && reqData.Web.Port != config.BetaBindPort {
|
||||
err = aghnet.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port)
|
||||
if err != nil {
|
||||
respData.Web.Status = err.Error()
|
||||
}
|
||||
resp := &checkConfResp{}
|
||||
uc := aghalg.UniqChecker{}
|
||||
|
||||
if err = req.validateWeb(uc); err != nil {
|
||||
resp.Web.Status = err.Error()
|
||||
}
|
||||
|
||||
if reqData.DNS.Port != 0 {
|
||||
err = aghnet.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||
|
||||
if aghnet.ErrorIsAddrInUse(err) {
|
||||
canAutofix := checkDNSStubListener()
|
||||
if canAutofix && reqData.DNS.Autofix {
|
||||
|
||||
err = disableDNSStubListener()
|
||||
if err != nil {
|
||||
log.Error("Couldn't disable DNSStubListener: %s", err)
|
||||
}
|
||||
|
||||
err = aghnet.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||
canAutofix = false
|
||||
}
|
||||
|
||||
respData.DNS.CanAutofix = canAutofix
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
err = aghnet.CheckPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
respData.DNS.Status = err.Error()
|
||||
} else if !reqData.DNS.IP.IsUnspecified() {
|
||||
respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP)
|
||||
}
|
||||
if resp.DNS.CanAutofix, err = req.validateDNS(uc); err != nil {
|
||||
resp.DNS.Status = err.Error()
|
||||
} else if !req.DNS.IP.IsUnspecified() {
|
||||
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(respData)
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to marshal JSON: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding the response: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -251,10 +322,11 @@ type applyConfigReqEnt struct {
|
||||
}
|
||||
|
||||
type applyConfigReq struct {
|
||||
Web applyConfigReqEnt `json:"web"`
|
||||
DNS applyConfigReqEnt `json:"dns"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
|
||||
Web applyConfigReqEnt `json:"web"`
|
||||
DNS applyConfigReqEnt `json:"dns"`
|
||||
}
|
||||
|
||||
// copyInstallSettings copies the installation parameters between two
|
||||
@@ -270,40 +342,58 @@ func copyInstallSettings(dst, src *configuration) {
|
||||
// shutdownTimeout is the timeout for shutting HTTP server down operation.
|
||||
const shutdownTimeout = 5 * time.Second
|
||||
|
||||
func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) {
|
||||
func shutdownSrv(ctx context.Context, srv *http.Server) {
|
||||
defer log.OnPanic("")
|
||||
|
||||
if srv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer cancel()
|
||||
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Error("error while shutting down http server %q: %s", srv.Addr, err)
|
||||
const msgFmt = "shutting down http server %q: %s"
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debug(msgFmt, srv.Addr, err)
|
||||
} else {
|
||||
log.Error(msgFmt, srv.Addr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PasswordMinRunes is the minimum length of user's password in runes.
|
||||
const PasswordMinRunes = 8
|
||||
|
||||
// Apply new configuration, start DNS server, restart Web server
|
||||
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
req, restartHTTP, err := decodeApplyConfigReq(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = aghnet.CheckPacketPortAvailable(req.DNS.IP, req.DNS.Port)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
if utf8.RuneCountInString(req.Password) < PasswordMinRunes {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusUnprocessableEntity,
|
||||
"password must be at least %d symbols long",
|
||||
PasswordMinRunes,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = aghnet.CheckPortAvailable(req.DNS.IP, req.DNS.Port)
|
||||
err = aghnet.CheckPort("udp", req.DNS.IP, req.DNS.Port)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("tcp", req.DNS.IP, req.DNS.Port)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -317,28 +407,29 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
config.DNS.BindHosts = []net.IP{req.DNS.IP}
|
||||
config.DNS.Port = req.DNS.Port
|
||||
|
||||
// TODO(e.burkov): StartMods() should be put in a separate goroutine at
|
||||
// the moment we'll allow setting up TLS in the initial configuration or
|
||||
// the configuration itself will use HTTPS protocol, because the
|
||||
// underlying functions potentially restart the HTTPS server.
|
||||
// TODO(e.burkov): StartMods() should be put in a separate goroutine at the
|
||||
// moment we'll allow setting up TLS in the initial configuration or the
|
||||
// configuration itself will use HTTPS protocol, because the underlying
|
||||
// functions potentially restart the HTTPS server.
|
||||
err = StartMods()
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
u := User{}
|
||||
u.Name = req.Username
|
||||
Context.auth.UserAdd(&u, req.Password)
|
||||
u := &User{
|
||||
Name: req.Username,
|
||||
}
|
||||
Context.auth.UserAdd(u, req.Password)
|
||||
|
||||
err = config.write()
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write config: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -349,19 +440,27 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
registerControlHandlers()
|
||||
|
||||
returnOK(w)
|
||||
aghhttp.OK(w)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// Method http.(*Server).Shutdown needs to be called in a separate
|
||||
// goroutine and with its own context, because it waits until all
|
||||
// requests are handled and will be blocked by it's own caller.
|
||||
if restartHTTP {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
go shutdownSrv(ctx, cancel, web.httpServer)
|
||||
go shutdownSrv(ctx, cancel, web.httpServerBeta)
|
||||
if !restartHTTP {
|
||||
return
|
||||
}
|
||||
|
||||
// Method http.(*Server).Shutdown needs to be called in a separate goroutine
|
||||
// and with its own context, because it waits until all requests are handled
|
||||
// and will be blocked by it's own caller.
|
||||
go func(timeout time.Duration) {
|
||||
defer log.OnPanic("web")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
shutdownSrv(ctx, web.httpServer)
|
||||
shutdownSrv(ctx, web.httpServerBeta)
|
||||
}(shutdownTimeout)
|
||||
}
|
||||
|
||||
// decodeApplyConfigReq decodes the configuration, validates some parameters,
|
||||
@@ -380,7 +479,7 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
|
||||
|
||||
restartHTTP = !config.BindHost.Equal(req.Web.IP) || config.BindPort != req.Web.Port
|
||||
if restartHTTP {
|
||||
err = aghnet.CheckPortAvailable(req.Web.IP, req.Web.Port)
|
||||
err = aghnet.CheckPort("tcp", req.Web.IP, req.Web.Port)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf(
|
||||
"checking address %s:%d: %w",
|
||||
@@ -406,8 +505,8 @@ func (web *Web) registerInstallHandlers() {
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default checkConfigReqEnt.
|
||||
type checkConfigReqEntBeta struct {
|
||||
Port int `json:"port"`
|
||||
IP []net.IP `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Autofix bool `json:"autofix"`
|
||||
}
|
||||
|
||||
@@ -431,24 +530,26 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ
|
||||
reqData := checkConfigReqBeta{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 {
|
||||
httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
nonBetaReqData := checkConfigReq{
|
||||
Web: checkConfigReqEnt{
|
||||
Port: reqData.Web.Port,
|
||||
nonBetaReqData := checkConfReq{
|
||||
Web: checkConfReqEnt{
|
||||
IP: reqData.Web.IP[0],
|
||||
Port: reqData.Web.Port,
|
||||
Autofix: reqData.Web.Autofix,
|
||||
},
|
||||
DNS: checkConfigReqEnt{
|
||||
Port: reqData.DNS.Port,
|
||||
DNS: checkConfReqEnt{
|
||||
IP: reqData.DNS.IP[0],
|
||||
Port: reqData.DNS.Port,
|
||||
Autofix: reqData.DNS.Autofix,
|
||||
},
|
||||
SetStaticIP: reqData.SetStaticIP,
|
||||
@@ -458,7 +559,14 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err)
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"Failed to encode 'check_config' JSON data: %s",
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
body := nonBetaReqBody.String()
|
||||
@@ -484,10 +592,11 @@ type applyConfigReqEntBeta struct {
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default applyConfigReq.
|
||||
type applyConfigReqBeta struct {
|
||||
Web applyConfigReqEntBeta `json:"web"`
|
||||
DNS applyConfigReqEntBeta `json:"dns"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
|
||||
Web applyConfigReqEntBeta `json:"web"`
|
||||
DNS applyConfigReqEntBeta `json:"dns"`
|
||||
}
|
||||
|
||||
// handleInstallConfigureBeta is a substitution of /install/configure handler
|
||||
@@ -499,12 +608,14 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques
|
||||
reqData := applyConfigReqBeta{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 {
|
||||
httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -525,7 +636,14 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques
|
||||
|
||||
err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err)
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"Failed to encode 'check_config' JSON data: %s",
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
body := nonBetaReqBody.String()
|
||||
@@ -541,9 +659,9 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default firstRunData.
|
||||
type getAddrsResponseBeta struct {
|
||||
Interfaces []*aghnet.NetInterface `json:"interfaces"`
|
||||
WebPort int `json:"web_port"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
Interfaces []*aghnet.NetInterface `json:"interfaces"`
|
||||
}
|
||||
|
||||
// handleInstallConfigureBeta is a substitution of /install/get_addresses
|
||||
@@ -552,13 +670,15 @@ type getAddrsResponseBeta struct {
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default handleInstallGetAddresses.
|
||||
func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Request) {
|
||||
data := getAddrsResponseBeta{}
|
||||
data.WebPort = defaultPortHTTP
|
||||
data.DNSPort = defaultPortDNS
|
||||
data := getAddrsResponseBeta{
|
||||
WebPort: defaultPortHTTP,
|
||||
DNSPort: defaultPortDNS,
|
||||
}
|
||||
|
||||
ifaces, err := aghnet.GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -567,7 +687,14 @@ func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Req
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err)
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Unable to marshal default addresses to json: %s",
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -43,7 +44,8 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ContentLength != 0 {
|
||||
err = json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -77,7 +79,15 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
vcu := Context.updater.VersionCheckURL()
|
||||
// TODO(a.garipov): Figure out the purpose of %T verb.
|
||||
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", vcu, err, err)
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadGateway,
|
||||
"Couldn't get version check json from %s: %T %s\n",
|
||||
vcu,
|
||||
err,
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -87,24 +97,26 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpdate performs an update to the latest available version procedure.
|
||||
func handleUpdate(w http.ResponseWriter, _ *http.Request) {
|
||||
func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
if Context.updater.NewVersion() == "" {
|
||||
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "/update request isn't allowed now")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err := Context.updater.Update()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
returnOK(w)
|
||||
aghhttp.OK(w)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
@@ -36,20 +37,24 @@ func onConfigModified() {
|
||||
// initDNSServer creates an instance of the dnsforward.Server
|
||||
// Please note that we must do it even if we don't start it
|
||||
// so that we had access to the query log and the stats
|
||||
func initDNSServer() error {
|
||||
var err error
|
||||
func initDNSServer() (err error) {
|
||||
baseDir := Context.getDataDir()
|
||||
|
||||
var anonFunc aghnet.IPMutFunc
|
||||
if config.DNS.AnonymizeClientIP {
|
||||
anonFunc = querylog.AnonymizeIP
|
||||
}
|
||||
anonymizer := aghnet.NewIPMut(anonFunc)
|
||||
|
||||
statsConf := stats.Config{
|
||||
Filename: filepath.Join(baseDir, "stats.db"),
|
||||
LimitDays: config.DNS.StatsInterval,
|
||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
Filename: filepath.Join(baseDir, "stats.db"),
|
||||
LimitDays: config.DNS.StatsInterval,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
}
|
||||
Context.stats, err = stats.New(statsConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't initialize statistics module")
|
||||
return fmt.Errorf("init stats: %w", err)
|
||||
}
|
||||
|
||||
conf := querylog.Config{
|
||||
@@ -62,6 +67,7 @@ func initDNSServer() error {
|
||||
Enabled: config.DNS.QueryLogEnabled,
|
||||
FileEnabled: config.DNS.QueryLogFileEnabled,
|
||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||
Anonymizer: anonymizer,
|
||||
}
|
||||
Context.queryLog = querylog.New(conf)
|
||||
|
||||
@@ -76,7 +82,8 @@ func initDNSServer() error {
|
||||
Stats: Context.stats,
|
||||
QueryLog: Context.queryLog,
|
||||
SubnetDetector: Context.subnetDetector,
|
||||
LocalDomain: config.DNS.LocalDomainName,
|
||||
Anonymizer: anonymizer,
|
||||
LocalDomain: config.DHCP.LocalDomainName,
|
||||
}
|
||||
if Context.dhcpServer != nil {
|
||||
p.DHCPServer = Context.dhcpServer
|
||||
@@ -90,7 +97,8 @@ func initDNSServer() error {
|
||||
}
|
||||
|
||||
Context.clients.dnsServer = Context.dnsServer
|
||||
dnsConfig, err := generateServerConfig()
|
||||
var dnsConfig dnsforward.ServerConfig
|
||||
dnsConfig, err = generateServerConfig()
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
@@ -100,6 +108,7 @@ func initDNSServer() error {
|
||||
err = Context.dnsServer.Prepare(&dnsConfig)
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
return fmt.Errorf("dnsServer.Prepare: %w", err)
|
||||
}
|
||||
|
||||
@@ -202,7 +211,6 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
}
|
||||
|
||||
newConf.TLSv12Roots = Context.tlsRoots
|
||||
newConf.TLSCiphers = Context.tlsCiphers
|
||||
newConf.TLSAllowUnencryptedDoH = tlsConf.AllowUnencryptedDoH
|
||||
|
||||
newConf.FilterHandler = applyAdditionalFiltering
|
||||
@@ -310,7 +318,7 @@ func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *filteri
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("using settings for client %s with ip %s and id %q", c.Name, clientAddr, clientID)
|
||||
log.Debug("using settings for client %s with ip %s and clientid %q", c.Name, clientAddr, clientID)
|
||||
|
||||
if c.UseOwnBlockedServices {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
|
||||
|
||||
@@ -710,8 +710,8 @@ func enableFilters(async bool) {
|
||||
}
|
||||
|
||||
func enableFiltersLocked(async bool) {
|
||||
var whiteFilters []filtering.Filter
|
||||
filters := []filtering.Filter{{
|
||||
ID: filtering.CustomListID,
|
||||
Data: []byte(strings.Join(config.UserRules, "\n")),
|
||||
}}
|
||||
|
||||
@@ -725,18 +725,20 @@ func enableFiltersLocked(async bool) {
|
||||
FilePath: filter.Path(),
|
||||
})
|
||||
}
|
||||
|
||||
var allowFilters []filtering.Filter
|
||||
for _, filter := range config.WhitelistFilters {
|
||||
if !filter.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
whiteFilters = append(whiteFilters, filtering.Filter{
|
||||
allowFilters = append(allowFilters, filtering.Filter{
|
||||
ID: filter.ID,
|
||||
FilePath: filter.Path(),
|
||||
})
|
||||
}
|
||||
|
||||
if err := Context.dnsFilter.SetFilters(filters, whiteFilters, async); err != nil {
|
||||
if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil {
|
||||
log.Debug("enabling filters: %s", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -21,7 +22,7 @@ const testFltsFileName = "1.txt"
|
||||
func testStartFilterListener(t *testing.T, fltContent *[]byte) (l net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
n, werr := w.Write(*fltContent)
|
||||
require.NoError(t, werr)
|
||||
require.Equal(t, len(*fltContent), n)
|
||||
@@ -34,9 +35,7 @@ func testStartFilterListener(t *testing.T, fltContent *[]byte) (l net.Listener)
|
||||
go func() {
|
||||
_ = http.Serve(l, h)
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, l.Close())
|
||||
})
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
return l
|
||||
}
|
||||
@@ -100,9 +99,7 @@ func TestFilters(t *testing.T) {
|
||||
|
||||
t.Run("refresh_actually", func(t *testing.T) {
|
||||
fltContent = []byte(`||example.com^`)
|
||||
t.Cleanup(func() {
|
||||
fltContent = []byte(content)
|
||||
})
|
||||
t.Cleanup(func() { fltContent = []byte(content) })
|
||||
|
||||
updateAndAssert(t, require.True, 1)
|
||||
})
|
||||
|
||||
@@ -19,8 +19,10 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -44,20 +46,25 @@ type homeContext struct {
|
||||
// Modules
|
||||
// --
|
||||
|
||||
clients clientsContainer // per-client-settings module
|
||||
stats stats.Stats // statistics module
|
||||
queryLog querylog.QueryLog // query log module
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
whois *WHOIS // WHOIS module
|
||||
dnsFilter *filtering.DNSFilter // DNS filtering module
|
||||
dhcpServer *dhcpd.Server // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters Filtering // DNS filtering module
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
tls *TLSMod // TLS module
|
||||
etcHosts *aghnet.EtcHostsContainer // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files
|
||||
updater *updater.Updater
|
||||
clients clientsContainer // per-client-settings module
|
||||
stats stats.Stats // statistics module
|
||||
queryLog querylog.QueryLog // query log module
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
whois *WHOIS // WHOIS module
|
||||
dnsFilter *filtering.DNSFilter // DNS filtering module
|
||||
dhcpServer *dhcpd.Server // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters Filtering // DNS filtering module
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
tls *TLSMod // TLS module
|
||||
// etcHosts is an IP-hostname pairs set taken from system configuration
|
||||
// (e.g. /etc/hosts) files.
|
||||
etcHosts *aghnet.HostsContainer
|
||||
// hostsWatcher is the watcher to detect changes in the hosts files.
|
||||
hostsWatcher aghos.FSWatcher
|
||||
|
||||
updater *updater.Updater
|
||||
|
||||
subnetDetector *aghnet.SubnetDetector
|
||||
|
||||
@@ -74,7 +81,6 @@ type homeContext struct {
|
||||
disableUpdate bool // If set, don't check for updates
|
||||
controlLock sync.Mutex
|
||||
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
|
||||
tlsCiphers []uint16 // list of TLS ciphers to use
|
||||
transport *http.Transport
|
||||
client *http.Client
|
||||
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
|
||||
@@ -139,13 +145,13 @@ func setupContext(args options) {
|
||||
initConfig()
|
||||
|
||||
Context.tlsRoots = LoadSystemRootCAs()
|
||||
Context.tlsCiphers = InitTLSCiphers()
|
||||
Context.transport = &http.Transport{
|
||||
DialContext: customDialContext,
|
||||
Proxy: getHTTPProxy,
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: Context.tlsRoots,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
RootCAs: Context.tlsRoots,
|
||||
CipherSuites: aghtls.SaferCipherSuites(),
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
Context.client = &http.Client{
|
||||
@@ -154,14 +160,11 @@ func setupContext(args options) {
|
||||
}
|
||||
|
||||
if !Context.firstRun {
|
||||
// Do the upgrade if necessary
|
||||
// Do the upgrade if necessary.
|
||||
err := upgradeConfig()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fatalOnError(err)
|
||||
|
||||
err = parseConfig()
|
||||
if err != nil {
|
||||
if err = parseConfig(); err != nil {
|
||||
log.Error("parsing configuration file: %s", err)
|
||||
|
||||
os.Exit(1)
|
||||
@@ -179,15 +182,15 @@ func setupContext(args options) {
|
||||
|
||||
// logIfUnsupported logs a formatted warning if the error is one of the
|
||||
// unsupported errors and returns nil. If err is nil, logIfUnsupported returns
|
||||
// nil. Otherise, it returns err.
|
||||
// nil. Otherwise, it returns err.
|
||||
func logIfUnsupported(msg string, err error) (outErr error) {
|
||||
if unsupErr := (&aghos.UnsupportedError{}); errors.As(err, &unsupErr) {
|
||||
if errors.As(err, new(*aghos.UnsupportedError)) {
|
||||
log.Debug(msg, err)
|
||||
} else if err != nil {
|
||||
return err
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// configureOS sets the OS-related configuration.
|
||||
@@ -230,6 +233,34 @@ func configureOS(conf *configuration) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupHostsContainer initializes the structures to keep up-to-date the hosts
|
||||
// provided by the OS.
|
||||
func setupHostsContainer() (err error) {
|
||||
Context.hostsWatcher, err = aghos.NewOSWritesWatcher()
|
||||
if err != nil {
|
||||
return fmt.Errorf("initing hosts watcher: %w", err)
|
||||
}
|
||||
|
||||
Context.etcHosts, err = aghnet.NewHostsContainer(
|
||||
filtering.SysHostsListID,
|
||||
aghos.RootDirFS(),
|
||||
Context.hostsWatcher,
|
||||
aghnet.DefaultHostsPaths()...,
|
||||
)
|
||||
if err != nil {
|
||||
cerr := Context.hostsWatcher.Close()
|
||||
if errors.Is(err, aghnet.ErrNoHostsPaths) && cerr == nil {
|
||||
log.Info("warning: initing hosts container: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.WithDeferred(fmt.Errorf("initing hosts container: %w", err), cerr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupConfig(args options) (err error) {
|
||||
config.DHCP.WorkDir = Context.workDir
|
||||
config.DHCP.HTTPRegister = httpRegister
|
||||
@@ -257,19 +288,41 @@ func setupConfig(args options) (err error) {
|
||||
})
|
||||
|
||||
if !args.noEtcHosts {
|
||||
Context.etcHosts = &aghnet.EtcHostsContainer{}
|
||||
Context.etcHosts.Init("")
|
||||
if err = setupHostsContainer(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts)
|
||||
config.Clients = nil
|
||||
|
||||
if args.bindPort != 0 {
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
args.bindPort,
|
||||
config.BetaBindPort,
|
||||
config.DNS.Port,
|
||||
)
|
||||
if config.TLS.Enabled {
|
||||
addPorts(
|
||||
uc,
|
||||
config.TLS.PortHTTPS,
|
||||
config.TLS.PortDNSOverTLS,
|
||||
config.TLS.PortDNSOverQUIC,
|
||||
config.TLS.PortDNSCrypt,
|
||||
)
|
||||
}
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return fmt.Errorf("validating ports: %w", err)
|
||||
}
|
||||
|
||||
config.BindPort = args.bindPort
|
||||
}
|
||||
|
||||
// override bind host/port from the console
|
||||
if args.bindHost != nil {
|
||||
config.BindHost = args.bindHost
|
||||
}
|
||||
if args.bindPort != 0 {
|
||||
config.BindPort = args.bindPort
|
||||
}
|
||||
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
|
||||
Context.pidFileName = args.pidFile
|
||||
}
|
||||
@@ -324,7 +377,7 @@ func fatalOnError(err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// run performs configurating and starts AdGuard Home.
|
||||
// run configures and starts AdGuard Home.
|
||||
func run(args options, clientBuildFS fs.FS) {
|
||||
var err error
|
||||
|
||||
@@ -340,9 +393,9 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
// Go memory hacks
|
||||
memoryUsage(args)
|
||||
|
||||
// print the first message after logger is configured
|
||||
// Print the first message after logger is configured.
|
||||
log.Println(version.Full())
|
||||
log.Debug("Current working directory is %s", Context.workDir)
|
||||
log.Debug("current working directory is %s", Context.workDir)
|
||||
if args.runningAsService {
|
||||
log.Info("AdGuard Home is running as a service")
|
||||
}
|
||||
@@ -424,7 +477,6 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
fatalOnError(err)
|
||||
|
||||
Context.tls.Start()
|
||||
Context.etcHosts.Start()
|
||||
|
||||
go func() {
|
||||
serr := startDNSServer()
|
||||
@@ -579,13 +631,13 @@ func configureLogger(args options) {
|
||||
log.SetLevel(log.DEBUG)
|
||||
}
|
||||
|
||||
// Make sure that we see the microseconds in logs, as networking stuff
|
||||
// can happen pretty quickly.
|
||||
// Make sure that we see the microseconds in logs, as networking stuff can
|
||||
// happen pretty quickly.
|
||||
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
|
||||
|
||||
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if nothing else is configured
|
||||
// Otherwise, we'll simply loose the log output
|
||||
// When running as a Windows service, use eventlog by default if nothing
|
||||
// else is configured. Otherwise, we'll simply lose the log output.
|
||||
ls.LogFile = configSyslog
|
||||
}
|
||||
|
||||
@@ -647,7 +699,18 @@ func cleanup(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
Context.etcHosts.Close()
|
||||
if Context.etcHosts != nil {
|
||||
// Currently Context.hostsWatcher is only used in Context.etcHosts and
|
||||
// needs closing only in case of the successful initialization of
|
||||
// Context.etcHosts.
|
||||
if err = Context.hostsWatcher.Close(); err != nil {
|
||||
log.Error("closing hosts watcher: %s", err)
|
||||
}
|
||||
|
||||
if err = Context.etcHosts.Close(); err != nil {
|
||||
log.Error("closing hosts container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if Context.tls != nil {
|
||||
Context.tls.Close()
|
||||
@@ -722,8 +785,7 @@ func printHTTPAddresses(proto string) {
|
||||
port = tlsConf.PortHTTPS
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Inspect and perhaps merge with the previous
|
||||
// condition.
|
||||
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
|
||||
if proto == schemeHTTPS && tlsConf.ServerName != "" {
|
||||
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0)
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Get rid of a global variable?
|
||||
// TODO(a.garipov): Get rid of a global or generate from .twosky.json.
|
||||
var allowedLanguages = stringutil.NewSet(
|
||||
"be",
|
||||
"bg",
|
||||
@@ -20,6 +21,7 @@ var allowedLanguages = stringutil.NewSet(
|
||||
"en",
|
||||
"es",
|
||||
"fa",
|
||||
"fi",
|
||||
"fr",
|
||||
"hr",
|
||||
"hu",
|
||||
@@ -41,6 +43,7 @@ var allowedLanguages = stringutil.NewSet(
|
||||
"sv",
|
||||
"th",
|
||||
"tr",
|
||||
"uk",
|
||||
"vi",
|
||||
"zh-cn",
|
||||
"zh-hk",
|
||||
@@ -94,5 +97,5 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
}()
|
||||
|
||||
onConfigModified()
|
||||
returnOK(w)
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func TestLimitRequestBody(t *testing.T) {
|
||||
var b []byte
|
||||
b, *err = io.ReadAll(r.Body)
|
||||
_, werr := w.Write(b)
|
||||
require.Nil(t, werr)
|
||||
require.NoError(t, werr)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user