Merge branch 'master' into 3717-fix-qq-blocked

This commit is contained in:
Ainar Garipov
2022-03-07 19:11:03 +03:00
232 changed files with 12754 additions and 8110 deletions

75
internal/aghalg/aghalg.go Normal file
View File

@@ -0,0 +1,75 @@
// Package aghalg contains common generic algorithms and data structures.
//
// TODO(a.garipov): Update to use type parameters in Go 1.18.
package aghalg
import (
"fmt"
"sort"
)
// comparable is an alias for interface{}. Values passed as arguments of this
// type alias must be comparable.
//
// TODO(a.garipov): Remove in Go 1.18.
type comparable = interface{}
// UniqChecker allows validating uniqueness of comparable items.
type UniqChecker map[comparable]int64
// Add adds a value to the validator. v must not be nil.
func (uc UniqChecker) Add(elems ...comparable) {
for _, e := range elems {
uc[e]++
}
}
// Merge returns a checker containing data from both uc and other.
func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
merged = make(UniqChecker, len(uc)+len(other))
for elem, num := range uc {
merged[elem] += num
}
for elem, num := range other {
merged[elem] += num
}
return merged
}
// Validate returns an error enumerating all elements that aren't unique.
// isBefore is an optional sorting function to make the error message
// deterministic.
func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err error) {
var dup []comparable
for elem, num := range uc {
if num > 1 {
dup = append(dup, elem)
}
}
if len(dup) == 0 {
return nil
}
if isBefore != nil {
sort.Slice(dup, func(i, j int) (less bool) {
return isBefore(dup[i], dup[j])
})
}
return fmt.Errorf("duplicated values: %v", dup)
}
// IntIsBefore is a helper sort function for UniqChecker.Validate.
// a and b must be of type int.
func IntIsBefore(a, b comparable) (less bool) {
return a.(int) < b.(int)
}
// StringIsBefore is a helper sort function for UniqChecker.Validate.
// a and b must be of type string.
func StringIsBefore(a, b comparable) (less bool) {
return a.(string) < b.(string)
}

View File

@@ -0,0 +1,24 @@
// Package aghhttp provides some common methods to work with HTTP.
package aghhttp
import (
"fmt"
"io"
"net/http"
"github.com/AdguardTeam/golibs/log"
)
// OK responds with word OK.
func OK(w http.ResponseWriter) {
if _, err := io.WriteString(w, "OK\n"); err != nil {
log.Error("couldn't write body: %s", err)
}
}
// Error writes formatted message to w and also logs it.
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Error("%s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}

View File

@@ -11,7 +11,7 @@ type LimitReachedError struct {
Limit int64
}
// Error implements error interface for LimitReachedError.
// Error implements the error interface for LimitReachedError.
//
// TODO(a.garipov): Think about error string format.
func (lre *LimitReachedError) Error() string {
@@ -35,7 +35,7 @@ func (lr *limitedReader) Read(p []byte) (n int, err error) {
}
if int64(len(p)) > lr.n {
p = p[0:lr.n]
p = p[:lr.n]
}
n, err = lr.r.Read(p)

View File

@@ -1,38 +1,38 @@
package aghio
import (
"fmt"
"io"
"strings"
"testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLimitReader(t *testing.T) {
testCases := []struct {
want error
name string
n int64
wantErrMsg string
name string
n int64
}{{
want: nil,
name: "positive",
n: 1,
wantErrMsg: "",
name: "positive",
n: 1,
}, {
want: nil,
name: "zero",
n: 0,
wantErrMsg: "",
name: "zero",
n: 0,
}, {
want: fmt.Errorf("aghio: invalid n in LimitReader: -1"),
name: "negative",
n: -1,
wantErrMsg: "aghio: invalid n in LimitReader: -1",
name: "negative",
n: -1,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := LimitReader(nil, tc.n)
assert.Equal(t, tc.want, err)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
@@ -73,36 +73,23 @@ func TestLimitedReader_Read(t *testing.T) {
}}
for _, tc := range testCases {
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
lreader, err := LimitReader(readCloser, tc.limit)
require.NoError(t, err)
require.NotNil(t, lreader)
t.Run(tc.name, func(t *testing.T) {
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
buf := make([]byte, tc.limit+1)
n, rerr := lreader.Read(buf)
require.Equal(t, rerr, tc.err)
lreader, err := LimitReader(readCloser, tc.limit)
require.NoError(t, err)
n, err := lreader.Read(buf)
require.Equal(t, tc.err, err)
assert.Equal(t, tc.want, n)
})
}
}
func TestLimitedReader_LimitReachedError(t *testing.T) {
testCases := []struct {
err error
name string
want string
}{{
err: &LimitReachedError{
Limit: 0,
},
name: "simplest",
want: "attempted to read more than 0 bytes",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, tc.err.Error())
})
}
testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &LimitReachedError{
Limit: 0,
})
}

View File

@@ -19,7 +19,8 @@ import (
"github.com/insomniacslk/dhcp/iana"
)
// defaultDiscoverTime is the
// defaultDiscoverTime is the default timeout of checking another DHCP server
// response.
const defaultDiscoverTime = 3 * time.Second
func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
@@ -110,7 +111,7 @@ func discover4(iface *net.Interface, dstAddr *net.UDPAddr, hostname string) (ok
// is spoiled.
//
// It's also known that listening on the specified interface's address
// ignores broadcasted packets when reading.
// ignores broadcast packets when reading.
var c net.PacketConn
if c, err = listenPacketReusable(iface.Name, "udp4", ":68"); err != nil {
return false, fmt.Errorf("couldn't listen on :68: %w", err)

View File

@@ -1,387 +0,0 @@
package aghnet
import (
"bufio"
"net"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/fsnotify/fsnotify"
"github.com/miekg/dns"
)
type onChangedT func()
// EtcHostsContainer - automatic DNS records
//
// TODO(e.burkov): Move the logic under interface. Refactor. Probably remove
// the resolving logic.
type EtcHostsContainer struct {
// lock protects table and tableReverse.
lock sync.RWMutex
// table is the host-to-IPs map.
table map[string][]net.IP
// tableReverse is the IP-to-hosts map. The type of the values in the
// map is []string.
tableReverse *netutil.IPMap
hostsFn string // path to the main hosts-file
hostsDirs []string // paths to OS-specific directories with hosts-files
watcher *fsnotify.Watcher // file and directory watcher object
// onlyWritesChan used to contain only writing events from watcher.
onlyWritesChan chan fsnotify.Event
onChanged onChangedT // notification to other modules
}
// SetOnChanged - set callback function that will be called when the data is changed
func (ehc *EtcHostsContainer) SetOnChanged(onChanged onChangedT) {
ehc.onChanged = onChanged
}
// Notify other modules
func (ehc *EtcHostsContainer) notify() {
if ehc.onChanged == nil {
return
}
ehc.onChanged()
}
// Init - initialize
// hostsFn: Override default name for the hosts-file (optional)
func (ehc *EtcHostsContainer) Init(hostsFn string) {
ehc.table = make(map[string][]net.IP)
ehc.onlyWritesChan = make(chan fsnotify.Event, 2)
ehc.hostsFn = "/etc/hosts"
if runtime.GOOS == "windows" {
ehc.hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
}
if len(hostsFn) != 0 {
ehc.hostsFn = hostsFn
}
if aghos.IsOpenWrt() {
// OpenWrt: "/tmp/hosts/dhcp.cfg01411c".
ehc.hostsDirs = append(ehc.hostsDirs, "/tmp/hosts")
}
// Load hosts initially
ehc.updateHosts()
var err error
ehc.watcher, err = fsnotify.NewWatcher()
if err != nil {
log.Error("etchosts: %s", err)
}
}
// Start - start module
func (ehc *EtcHostsContainer) Start() {
if ehc == nil {
return
}
log.Debug("Start etchostscontainer module")
ehc.updateHosts()
if ehc.watcher != nil {
go ehc.watcherLoop()
err := ehc.watcher.Add(ehc.hostsFn)
if err != nil {
log.Error("Error while initializing watcher for a file %s: %s", ehc.hostsFn, err)
}
for _, dir := range ehc.hostsDirs {
err = ehc.watcher.Add(dir)
if err != nil {
log.Error("Error while initializing watcher for a directory %s: %s", dir, err)
}
}
}
}
// Close - close module
func (ehc *EtcHostsContainer) Close() {
if ehc == nil {
return
}
if ehc.watcher != nil {
_ = ehc.watcher.Close()
}
// Don't close onlyWritesChan here and let onlyWrites close it after
// watcher.Events is closed to prevent close races.
}
// Process returns the list of IP addresses for the hostname or nil if nothing
// found.
func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP {
if qtype == dns.TypePTR {
return nil
}
var ipsCopy []net.IP
ehc.lock.RLock()
defer ehc.lock.RUnlock()
if ips, ok := ehc.table[host]; ok {
ipsCopy = make([]net.IP, len(ips))
copy(ipsCopy, ips)
}
log.Debug("etchosts: answer: %s -> %v", host, ipsCopy)
return ipsCopy
}
// ProcessReverse processes a PTR request. It returns nil if nothing is found.
func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts []string) {
if qtype != dns.TypePTR {
return nil
}
ip, err := netutil.IPFromReversedAddr(addr)
if err != nil {
log.Error("etchosts: reversed addr: %s", err)
return nil
}
ehc.lock.RLock()
defer ehc.lock.RUnlock()
v, ok := ehc.tableReverse.Get(ip)
if !ok {
return nil
}
hosts, ok = v.([]string)
if !ok {
log.Error("etchosts: bad type %T in tableReverse for %s", v, ip)
return nil
} else if len(hosts) == 0 {
return nil
}
log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts)
return hosts
}
// List returns an IP-to-hostnames table. The type of the values in the map is
// []string. It is safe for concurrent use.
func (ehc *EtcHostsContainer) List() (ipToHosts *netutil.IPMap) {
ehc.lock.RLock()
defer ehc.lock.RUnlock()
return ehc.tableReverse.ShallowClone()
}
// update table
func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string, ipAddr net.IP) {
ips, ok := table[host]
if ok {
for _, ip := range ips {
if ip.Equal(ipAddr) {
// IP already exists: don't add duplicates
ok = false
break
}
}
if !ok {
ips = append(ips, ipAddr)
table[host] = ips
}
} else {
table[host] = []net.IP{ipAddr}
ok = true
}
if ok {
log.Debug("etchosts: added %s -> %s", ipAddr, host)
}
}
// updateTableRev updates the reverse address table.
func (ehc *EtcHostsContainer) updateTableRev(tableRev *netutil.IPMap, newHost string, ip net.IP) {
v, ok := tableRev.Get(ip)
if !ok {
tableRev.Set(ip, []string{newHost})
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
return
}
hosts, _ := v.([]string)
for _, host := range hosts {
if host == newHost {
return
}
}
hosts = append(hosts, newHost)
tableRev.Set(ip, hosts)
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
}
// parseHostsLine parses hosts from the fields.
func parseHostsLine(fields []string) (hosts []string) {
for _, f := range fields {
hashIdx := strings.IndexByte(f, '#')
if hashIdx == 0 {
// The rest of the fields are a part of the comment.
// Skip immediately.
return
} else if hashIdx > 0 {
// Only a part of the field is a comment.
hosts = append(hosts, f[:hashIdx])
return hosts
}
hosts = append(hosts, f)
}
return hosts
}
// load reads IP-hostname pairs from the hosts file. Multiple hostnames per
// line for one IP are supported.
func (ehc *EtcHostsContainer) load(
table map[string][]net.IP,
tableRev *netutil.IPMap,
fn string,
) {
f, err := os.Open(fn)
if err != nil {
log.Error("etchosts: %s", err)
return
}
defer func() {
derr := f.Close()
if derr != nil {
log.Error("etchosts: closing file: %s", err)
}
}()
log.Debug("etchosts: loading hosts from file %s", fn)
s := bufio.NewScanner(f)
for s.Scan() {
line := strings.TrimSpace(s.Text())
fields := strings.Fields(line)
if len(fields) < 2 {
continue
}
ip := net.ParseIP(fields[0])
if ip == nil {
continue
}
hosts := parseHostsLine(fields[1:])
for _, host := range hosts {
ehc.updateTable(table, host, ip)
ehc.updateTableRev(tableRev, host, ip)
}
}
err = s.Err()
if err != nil {
log.Error("etchosts: %s", err)
}
}
// onlyWrites is a filter for (*fsnotify.Watcher).Events.
func (ehc *EtcHostsContainer) onlyWrites() {
for event := range ehc.watcher.Events {
if event.Op&fsnotify.Write == fsnotify.Write {
ehc.onlyWritesChan <- event
}
}
close(ehc.onlyWritesChan)
}
// Receive notifications from fsnotify package
func (ehc *EtcHostsContainer) watcherLoop() {
go ehc.onlyWrites()
for {
select {
case event, ok := <-ehc.onlyWritesChan:
if !ok {
return
}
// Assume that we sometimes have the same event occurred
// several times.
repeat := true
for repeat {
select {
case _, ok = <-ehc.onlyWritesChan:
repeat = ok
default:
repeat = false
}
}
if event.Op&fsnotify.Write == fsnotify.Write {
log.Debug("etchosts: modified: %s", event.Name)
ehc.updateHosts()
}
case err, ok := <-ehc.watcher.Errors:
if !ok {
return
}
log.Error("etchosts: %s", err)
}
}
}
// updateHosts - loads system hosts
func (ehc *EtcHostsContainer) updateHosts() {
table := make(map[string][]net.IP)
tableRev := netutil.NewIPMap(0)
ehc.load(table, tableRev, ehc.hostsFn)
for _, dir := range ehc.hostsDirs {
des, err := os.ReadDir(dir)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
log.Error("etchosts: Opening directory: %q: %s", dir, err)
}
continue
}
for _, de := range des {
ehc.load(table, tableRev, filepath.Join(dir, de.Name()))
}
}
func() {
ehc.lock.Lock()
defer ehc.lock.Unlock()
ehc.table = table
ehc.tableReverse = tableRev
}()
ehc.notify()
}

View File

@@ -1,130 +0,0 @@
package aghnet
import (
"net"
"os"
"strings"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
func prepareTestFile(t *testing.T) (f *os.File) {
t.Helper()
dir := t.TempDir()
f, err := os.CreateTemp(dir, "")
require.NoError(t, err)
require.NotNil(t, f)
t.Cleanup(func() {
assert.NoError(t, f.Close())
})
return f
}
func assertWriting(t *testing.T, f *os.File, strs ...string) {
t.Helper()
for _, str := range strs {
n, err := f.WriteString(str)
require.NoError(t, err)
assert.Equal(t, n, len(str))
}
}
func TestEtcHostsContainerResolution(t *testing.T) {
ehc := &EtcHostsContainer{}
f := prepareTestFile(t)
assertWriting(t, f,
" 127.0.0.1 host localhost # comment \n",
" ::1 localhost#comment \n",
)
ehc.Init(f.Name())
t.Run("existing_host", func(t *testing.T) {
ips := ehc.Process("localhost", dns.TypeA)
require.Len(t, ips, 1)
assert.Equal(t, net.IPv4(127, 0, 0, 1), ips[0])
})
t.Run("unknown_host", func(t *testing.T) {
ips := ehc.Process("newhost", dns.TypeA)
assert.Nil(t, ips)
// Comment.
ips = ehc.Process("comment", dns.TypeA)
assert.Nil(t, ips)
})
t.Run("hosts_file", func(t *testing.T) {
names, ok := ehc.List().Get(net.IP{127, 0, 0, 1})
require.True(t, ok)
assert.Equal(t, []string{"host", "localhost"}, names)
})
t.Run("ptr", func(t *testing.T) {
testCases := []struct {
wantIP string
wantHost string
wantLen int
}{
{wantIP: "127.0.0.1", wantHost: "host", wantLen: 2},
{wantIP: "::1", wantHost: "localhost", wantLen: 1},
}
for _, tc := range testCases {
a, err := dns.ReverseAddr(tc.wantIP)
require.NoError(t, err)
a = strings.TrimSuffix(a, ".")
hosts := ehc.ProcessReverse(a, dns.TypePTR)
require.Len(t, hosts, tc.wantLen)
assert.Equal(t, tc.wantHost, hosts[0])
}
})
}
func TestEtcHostsContainerFSNotify(t *testing.T) {
ehc := &EtcHostsContainer{}
f := prepareTestFile(t)
assertWriting(t, f, " 127.0.0.1 host localhost \n")
ehc.Init(f.Name())
t.Run("unknown_host", func(t *testing.T) {
ips := ehc.Process("newhost", dns.TypeA)
assert.Nil(t, ips)
})
// Start monitoring for changes.
ehc.Start()
t.Cleanup(ehc.Close)
assertWriting(t, f, "127.0.0.2 newhost\n")
require.NoError(t, f.Sync())
// Wait until fsnotify has triggered and processed the file-modification
// event.
time.Sleep(50 * time.Millisecond)
t.Run("notified", func(t *testing.T) {
ips := ehc.Process("newhost", dns.TypeA)
assert.NotNil(t, ips)
require.Len(t, ips, 1)
assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0]))
})
}

View File

@@ -0,0 +1,504 @@
package aghnet
import (
"bufio"
"fmt"
"io"
"io/fs"
"net"
"path"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
)
// DefaultHostsPaths returns the slice of paths default for the operating system
// to files and directories which are containing the hosts database. The result
// is intended to be used within fs.FS so the initial slash is omitted.
func DefaultHostsPaths() (paths []string) {
return defaultHostsPaths()
}
// requestMatcher combines the logic for matching requests and translating the
// appropriate rules.
type requestMatcher struct {
// stateLock protects all the fields of requestMatcher.
stateLock *sync.RWMutex
// rulesStrg stores the rules obtained from the hosts' file.
rulesStrg *filterlist.RuleStorage
// engine serves rulesStrg.
engine *urlfilter.DNSEngine
// translator maps generated $dnsrewrite rules into hosts-syntax rules.
//
// TODO(e.burkov): Store the filename from which the rule was parsed.
translator map[string]string
}
// MatchRequest processes the request rewriting hostnames and addresses read
// from the operating system's hosts files. res is nil for any request having
// not an A/AAAA or PTR type, see man 5 hosts.
//
// It's safe for concurrent use.
func (rm *requestMatcher) MatchRequest(
req urlfilter.DNSRequest,
) (res *urlfilter.DNSResult, ok bool) {
switch req.DNSType {
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
log.Debug("%s: handling the request", hostsContainerPref)
default:
return nil, false
}
rm.stateLock.RLock()
defer rm.stateLock.RUnlock()
return rm.engine.MatchRequest(req)
}
// Translate returns the source hosts-syntax rule for the generated dnsrewrite
// rule or an empty string if the last doesn't exist. The returned rules are in
// a processed format like:
//
// ip host1 host2 ...
//
func (rm *requestMatcher) Translate(rule string) (hostRule string) {
rm.stateLock.RLock()
defer rm.stateLock.RUnlock()
return rm.translator[rule]
}
// resetEng updates container's engine and the translation map.
func (rm *requestMatcher) resetEng(rulesStrg *filterlist.RuleStorage, tr map[string]string) {
rm.stateLock.Lock()
defer rm.stateLock.Unlock()
rm.rulesStrg = rulesStrg
rm.engine = urlfilter.NewDNSEngine(rm.rulesStrg)
rm.translator = tr
}
// hostsContainerPref is a prefix for logging and wrapping errors in
// HostsContainer's methods.
const hostsContainerPref = "hosts container"
// HostsContainer stores the relevant hosts database provided by the OS and
// processes both A/AAAA and PTR DNS requests for those.
type HostsContainer struct {
// requestMatcher matches the requests and translates the rules. It's
// embedded to implement MatchRequest and Translate for *HostsContainer.
//
// TODO(a.garipov, e.burkov): Consider fully merging into HostsContainer.
requestMatcher
// done is the channel to sign closing the container.
done chan struct{}
// updates is the channel for receiving updated hosts.
updates chan *netutil.IPMap
// last is the set of hosts that was cached within last detected change.
last *netutil.IPMap
// fsys is the working file system to read hosts files from.
fsys fs.FS
// w tracks the changes in specified files and directories.
w aghos.FSWatcher
// patterns stores specified paths in the fs.Glob-compatible form.
patterns []string
// listID is the identifier for the list of generated rules.
listID int
}
// ErrNoHostsPaths is returned when there are no valid paths to watch passed to
// the HostsContainer.
const ErrNoHostsPaths errors.Error = "no valid paths to hosts files provided"
// NewHostsContainer creates a container of hosts, that watches the paths with
// w. listID is used as an identifier of the underlying rules list. paths
// shouldn't be empty and each of paths should locate either a file or a
// directory in fsys. fsys and w must be non-nil.
func NewHostsContainer(
listID int,
fsys fs.FS,
w aghos.FSWatcher,
paths ...string,
) (hc *HostsContainer, err error) {
defer func() { err = errors.Annotate(err, "%s: %w", hostsContainerPref) }()
if len(paths) == 0 {
return nil, ErrNoHostsPaths
}
var patterns []string
patterns, err = pathsToPatterns(fsys, paths)
if err != nil {
return nil, err
} else if len(patterns) == 0 {
return nil, ErrNoHostsPaths
}
hc = &HostsContainer{
requestMatcher: requestMatcher{
stateLock: &sync.RWMutex{},
},
listID: listID,
done: make(chan struct{}, 1),
updates: make(chan *netutil.IPMap, 1),
fsys: fsys,
w: w,
patterns: patterns,
}
log.Debug("%s: starting", hostsContainerPref)
// Load initially.
if err = hc.refresh(); err != nil {
return nil, err
}
for _, p := range paths {
if err = w.Add(p); err != nil {
if !errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("adding path: %w", err)
}
log.Debug("%s: %s is expected to exist but doesn't", hostsContainerPref, p)
}
}
go hc.handleEvents()
return hc, nil
}
// Close implements the io.Closer interface for *HostsContainer. Close must
// only be called once. The returned err is always nil.
func (hc *HostsContainer) Close() (err error) {
log.Debug("%s: closing", hostsContainerPref)
close(hc.done)
return nil
}
// Upd returns the channel into which the updates are sent. The receivable
// map's values are guaranteed to be of type of *stringutil.Set.
func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) {
return hc.updates
}
// pathsToPatterns converts paths into patterns compatible with fs.Glob.
func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error) {
for i, p := range paths {
var fi fs.FileInfo
fi, err = fs.Stat(fsys, p)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
continue
}
// Don't put a filename here since it's already added by fs.Stat.
return nil, fmt.Errorf("path at index %d: %w", i, err)
}
if fi.IsDir() {
p = path.Join(p, "*")
}
patterns = append(patterns, p)
}
return patterns, nil
}
// handleEvents concurrently handles the file system events. It closes the
// update channel of HostsContainer when finishes. It's used to be called
// within a separate goroutine.
func (hc *HostsContainer) handleEvents() {
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPref))
defer close(hc.updates)
ok, eventsCh := true, hc.w.Events()
for ok {
select {
case _, ok = <-eventsCh:
if !ok {
log.Debug("%s: watcher closed the events channel", hostsContainerPref)
continue
}
if err := hc.refresh(); err != nil {
log.Error("%s: %s", hostsContainerPref, err)
}
case _, ok = <-hc.done:
// Go on.
}
}
}
// hostsParser is a helper type to parse rules from the operating system's hosts
// file. It exists for only a single refreshing session.
type hostsParser struct {
// rulesBuilder builds the resulting rules list content.
rulesBuilder *strings.Builder
// translations maps generated rules into actual hosts file lines.
translations map[string]string
// table stores only the unique IP-hostname pairs. It's also sent to the
// updates channel afterwards.
table *netutil.IPMap
}
// newHostsParser creates a new *hostsParser with buffers of size taken from the
// previous parse.
func (hc *HostsContainer) newHostsParser() (hp *hostsParser) {
return &hostsParser{
rulesBuilder: &strings.Builder{},
translations: map[string]string{},
table: netutil.NewIPMap(hc.last.Len()),
}
}
// parseFile is a aghos.FileWalker for parsing the files with hosts syntax. It
// never signs to stop walking and never returns any additional patterns.
//
// See man hosts(5).
func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
ip, hosts := hp.parseLine(s.Text())
if ip == nil || len(hosts) == 0 {
continue
}
hp.addPairs(ip, hosts)
}
return nil, true, s.Err()
}
// parseLine parses the line having the hosts syntax ignoring invalid ones.
func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
fields := strings.Fields(line)
if len(fields) < 2 {
return nil, nil
}
if ip = net.ParseIP(fields[0]); ip == nil {
return nil, nil
}
for _, f := range fields[1:] {
hashIdx := strings.IndexByte(f, '#')
if hashIdx == 0 {
// The rest of the fields are a part of the comment so return.
break
} else if hashIdx > 0 {
// Only a part of the field is a comment.
f = f[:hashIdx]
}
// Make sure that invalid hosts aren't turned into rules.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/3946.
//
// TODO(e.burkov): Investigate if hosts may contain DNS-SD domains.
err := netutil.ValidateDomainName(f)
if err != nil {
log.Error("%s: host %q is invalid, ignoring", hostsContainerPref, f)
continue
}
hosts = append(hosts, f)
}
return ip, hosts
}
// addPair puts the pair of ip and host to the rules builder if needed. For
// each ip the first member of hosts will become the main one.
func (hp *hostsParser) addPairs(ip net.IP, hosts []string) {
v, ok := hp.table.Get(ip)
if !ok {
// This ip is added at the first time.
v = stringutil.NewSet()
hp.table.Set(ip, v)
}
var set *stringutil.Set
set, ok = v.(*stringutil.Set)
if !ok {
log.Debug("%s: adding pairs: unexpected value type %T", hostsContainerPref, v)
return
}
processed := strings.Join(append([]string{ip.String()}, hosts...), " ")
for _, h := range hosts {
if set.Has(h) {
continue
}
set.Add(h)
rule, rulePtr := hp.writeRules(h, ip)
hp.translations[rule], hp.translations[rulePtr] = processed, processed
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, h)
}
}
// writeRules writes the actual rule for the qtype and the PTR for the
// host-ip pair into internal builders.
func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) {
arpa, err := netutil.IPToReversedAddr(ip)
if err != nil {
return "", ""
}
const (
nl = "\n"
rwSuccess = "^$dnsrewrite=NOERROR;"
rwSuccessPTR = "^$dnsrewrite=NOERROR;PTR;"
modLen = len(rules.MaskPipe) + len(rwSuccess) + len(";")
modLenPTR = len(rules.MaskPipe) + len(rwSuccessPTR)
)
var qtype string
// The validation of the IP address has been performed earlier so it is
// guaranteed to be either an IPv4 or an IPv6.
if ip.To4() != nil {
qtype = "A"
} else {
qtype = "AAAA"
}
ipStr := ip.String()
fqdn := dns.Fqdn(host)
ruleBuilder := &strings.Builder{}
ruleBuilder.Grow(modLen + len(host) + len(qtype) + len(ipStr))
stringutil.WriteToBuilder(ruleBuilder, rules.MaskPipe, host, rwSuccess, qtype, ";", ipStr)
rule = ruleBuilder.String()
ruleBuilder.Reset()
ruleBuilder.Grow(modLenPTR + len(arpa) + len(fqdn))
stringutil.WriteToBuilder(ruleBuilder, rules.MaskPipe, arpa, rwSuccessPTR, fqdn)
rulePtr = ruleBuilder.String()
hp.rulesBuilder.Grow(len(rule) + len(rulePtr) + 2*len(nl))
stringutil.WriteToBuilder(hp.rulesBuilder, rule, nl, rulePtr, nl)
return rule, rulePtr
}
// equalSet returns true if the internal hosts table just parsed equals target.
func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) {
if target == nil {
// hp.table shouldn't appear nil since it's initialized on each refresh.
return target == hp.table
}
if hp.table.Len() != target.Len() {
return false
}
hp.table.Range(func(ip net.IP, b interface{}) (cont bool) {
// ok is set to true if the target doesn't contain ip or if the
// appropriate hosts set isn't equal to the checked one.
if a, hasIP := target.Get(ip); !hasIP {
ok = true
} else if hosts, aok := a.(*stringutil.Set); aok {
ok = !hosts.Equal(b.(*stringutil.Set))
}
// Continue only if maps has no discrepancies.
return !ok
})
// Return true if every value from the IP map has no discrepancies with the
// appropriate one from the target.
return !ok
}
// sendUpd tries to send the parsed data to the ch.
func (hp *hostsParser) sendUpd(ch chan *netutil.IPMap) {
log.Debug("%s: sending upd", hostsContainerPref)
upd := hp.table
select {
case ch <- upd:
// Updates are delivered. Go on.
case <-ch:
ch <- upd
log.Debug("%s: replaced the last update", hostsContainerPref)
case ch <- upd:
// The previous update was just read and the next one pushed. Go on.
default:
log.Error("%s: the updates channel is broken", hostsContainerPref)
}
}
// newStrg creates a new rules storage from parsed data.
func (hp *hostsParser) newStrg(id int) (s *filterlist.RuleStorage, err error) {
return filterlist.NewRuleStorage([]filterlist.RuleList{&filterlist.StringRuleList{
ID: id,
RulesText: hp.rulesBuilder.String(),
IgnoreCosmetic: true,
}})
}
// refresh gets the data from specified files and propagates the updates if
// needed.
//
// TODO(e.burkov): Accept a parameter to specify the files to refresh.
func (hc *HostsContainer) refresh() (err error) {
log.Debug("%s: refreshing", hostsContainerPref)
hp := hc.newHostsParser()
if _, err = aghos.FileWalker(hp.parseFile).Walk(hc.fsys, hc.patterns...); err != nil {
return fmt.Errorf("refreshing : %w", err)
}
if hp.equalSet(hc.last) {
log.Debug("%s: no changes detected", hostsContainerPref)
return nil
}
defer hp.sendUpd(hc.updates)
hc.last = hp.table.ShallowClone()
var rulesStrg *filterlist.RuleStorage
if rulesStrg, err = hp.newStrg(hc.listID); err != nil {
return fmt.Errorf("initializing rules storage: %w", err)
}
hc.resetEng(rulesStrg, hp.translations)
return nil
}

View File

@@ -0,0 +1,18 @@
//go:build linux
// +build linux
package aghnet
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func defaultHostsPaths() (paths []string) {
paths = []string{"etc/hosts"}
if aghos.IsOpenWrt() {
paths = append(paths, "tmp/hosts")
}
return paths
}

View File

@@ -0,0 +1,8 @@
//go:build !(windows || linux)
// +build !windows,!linux
package aghnet
func defaultHostsPaths() (paths []string) {
return []string{"etc/hosts"}
}

View File

@@ -0,0 +1,612 @@
package aghnet
import (
"io/fs"
"net"
"os"
"path"
"strings"
"sync/atomic"
"testing"
"testing/fstest"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
nl = "\n"
sp = " "
)
func TestNewHostsContainer(t *testing.T) {
const dirname = "dir"
const filename = "file1"
p := path.Join(dirname, filename)
testFS := fstest.MapFS{
p: &fstest.MapFile{Data: []byte("127.0.0.1 localhost")},
}
testCases := []struct {
wantErr error
name string
paths []string
}{{
wantErr: nil,
name: "one_file",
paths: []string{p},
}, {
wantErr: ErrNoHostsPaths,
name: "no_files",
paths: []string{},
}, {
wantErr: ErrNoHostsPaths,
name: "non-existent_file",
paths: []string{path.Join(dirname, filename+"2")},
}, {
wantErr: nil,
name: "whole_dir",
paths: []string{dirname},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
onAdd := func(name string) (err error) {
assert.Contains(t, tc.paths, name)
return nil
}
var eventsCalledCounter uint32
eventsCh := make(chan struct{})
onEvents := func() (e <-chan struct{}) {
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
return eventsCh
}
hc, err := NewHostsContainer(0, testFS, &aghtest.FSWatcher{
OnEvents: onEvents,
OnAdd: onAdd,
OnClose: func() (err error) { panic("not implemented") },
}, tc.paths...)
if tc.wantErr != nil {
require.ErrorIs(t, err, tc.wantErr)
assert.Nil(t, hc)
return
}
testutil.CleanupAndRequireSuccess(t, hc.Close)
require.NoError(t, err)
require.NotNil(t, hc)
assert.NotNil(t, <-hc.Upd())
eventsCh <- struct{}{}
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
})
}
t.Run("nil_fs", func(t *testing.T) {
require.Panics(t, func() {
_, _ = NewHostsContainer(0, nil, &aghtest.FSWatcher{
// Those shouldn't panic.
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(name string) (err error) { return nil },
OnClose: func() (err error) { return nil },
}, p)
})
})
t.Run("nil_watcher", func(t *testing.T) {
require.Panics(t, func() {
_, _ = NewHostsContainer(0, testFS, nil, p)
})
})
t.Run("err_watcher", func(t *testing.T) {
const errOnAdd errors.Error = "error"
errWatcher := &aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) { panic("not implemented") },
OnAdd: func(name string) (err error) { return errOnAdd },
OnClose: func() (err error) { panic("not implemented") },
}
hc, err := NewHostsContainer(0, testFS, errWatcher, p)
require.ErrorIs(t, err, errOnAdd)
assert.Nil(t, hc)
})
}
func TestHostsContainer_refresh(t *testing.T) {
// TODO(e.burkov): Test the case with no actual updates.
ip := net.IP{127, 0, 0, 1}
ipStr := ip.String()
testFS := fstest.MapFS{"dir/file1": &fstest.MapFile{Data: []byte(ipStr + ` hostname` + nl)}}
// event is a convenient alias for an empty struct{} to emit test events.
type event = struct{}
eventsCh := make(chan event, 1)
t.Cleanup(func() { close(eventsCh) })
w := &aghtest.FSWatcher{
OnEvents: func() (e <-chan event) { return eventsCh },
OnAdd: func(name string) (err error) {
assert.Equal(t, "dir", name)
return nil
},
OnClose: func() (err error) { panic("not implemented") },
}
hc, err := NewHostsContainer(0, testFS, w, "dir")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
checkRefresh := func(t *testing.T, wantHosts *stringutil.Set) {
upd, ok := <-hc.Upd()
require.True(t, ok)
require.NotNil(t, upd)
assert.Equal(t, 1, upd.Len())
v, ok := upd.Get(ip)
require.True(t, ok)
var set *stringutil.Set
set, ok = v.(*stringutil.Set)
require.True(t, ok)
assert.True(t, set.Equal(wantHosts))
}
t.Run("initial_refresh", func(t *testing.T) {
checkRefresh(t, stringutil.NewSet("hostname"))
})
t.Run("second_refresh", func(t *testing.T) {
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
eventsCh <- event{}
checkRefresh(t, stringutil.NewSet("hostname", "alias"))
})
t.Run("double_refresh", func(t *testing.T) {
// Make a change once.
testFS["dir/file1"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
eventsCh <- event{}
// Require the changes are written.
require.Eventually(t, func() bool {
res, ok := hc.MatchRequest(urlfilter.DNSRequest{
Hostname: "hostname",
DNSType: dns.TypeA,
})
return !ok && res.DNSRewrites() == nil
}, 5*time.Second, time.Second/2)
// Make a change again.
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` hostname` + nl)}
eventsCh <- event{}
// Require the changes are written.
require.Eventually(t, func() bool {
res, ok := hc.MatchRequest(urlfilter.DNSRequest{
Hostname: "hostname",
DNSType: dns.TypeA,
})
return !ok && res.DNSRewrites() != nil
}, 5*time.Second, time.Second/2)
assert.Len(t, hc.Upd(), 1)
})
}
func TestHostsContainer_PathsToPatterns(t *testing.T) {
gsfs := fstest.MapFS{
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
}
testCases := []struct {
name string
paths []string
want []string
}{{
name: "no_paths",
paths: nil,
want: nil,
}, {
name: "single_file",
paths: []string{"dir_0/file_1"},
want: []string{"dir_0/file_1"},
}, {
name: "several_files",
paths: []string{"dir_0/file_1", "dir_0/file_2"},
want: []string{"dir_0/file_1", "dir_0/file_2"},
}, {
name: "whole_dir",
paths: []string{"dir_0"},
want: []string{"dir_0/*"},
}, {
name: "file_and_dir",
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
}, {
name: "non-existing",
paths: []string{path.Join("dir_0", "file_3")},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
patterns, err := pathsToPatterns(gsfs, tc.paths)
require.NoError(t, err)
assert.Equal(t, tc.want, patterns)
})
}
t.Run("bad_file", func(t *testing.T) {
const errStat errors.Error = "bad file"
badFS := &aghtest.StatFS{
OnStat: func(name string) (fs.FileInfo, error) {
return nil, errStat
},
}
_, err := pathsToPatterns(badFS, []string{""})
assert.ErrorIs(t, err, errStat)
})
}
func TestHostsContainer_Translate(t *testing.T) {
testdata := os.DirFS("./testdata")
stubWatcher := aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(name string) (err error) { return nil },
OnClose: func() (err error) { panic("not implemented") },
}
hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
testCases := []struct {
name string
rule string
wantTrans []string
}{{
name: "simplehost",
rule: "|simplehost^$dnsrewrite=NOERROR;A;1.0.0.1",
wantTrans: []string{"1.0.0.1", "simplehost"},
}, {
name: "hello",
rule: "|hello^$dnsrewrite=NOERROR;A;1.0.0.0",
wantTrans: []string{"1.0.0.0", "hello", "hello.world"},
}, {
name: "hello-alias",
rule: "|hello.world.again^$dnsrewrite=NOERROR;A;1.0.0.0",
wantTrans: []string{"1.0.0.0", "hello.world.again"},
}, {
name: "simplehost_v6",
rule: "|simplehost^$dnsrewrite=NOERROR;AAAA;::1",
wantTrans: []string{"::1", "simplehost"},
}, {
name: "hello_v6",
rule: "|hello^$dnsrewrite=NOERROR;AAAA;::",
wantTrans: []string{"::", "hello", "hello.world"},
}, {
name: "hello_v6-alias",
rule: "|hello.world.again^$dnsrewrite=NOERROR;AAAA;::",
wantTrans: []string{"::", "hello.world.again"},
}, {
name: "simplehost_ptr",
rule: "|1.0.0.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;simplehost.",
wantTrans: []string{"1.0.0.1", "simplehost"},
}, {
name: "hello_ptr",
rule: "|0.0.0.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;hello.",
wantTrans: []string{"1.0.0.0", "hello", "hello.world"},
}, {
name: "hello_ptr-alias",
rule: "|0.0.0.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;hello.world.again.",
wantTrans: []string{"1.0.0.0", "hello.world.again"},
}, {
name: "simplehost_ptr_v6",
rule: "|1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa" +
"^$dnsrewrite=NOERROR;PTR;simplehost.",
wantTrans: []string{"::1", "simplehost"},
}, {
name: "hello_ptr_v6",
rule: "|0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa" +
"^$dnsrewrite=NOERROR;PTR;hello.",
wantTrans: []string{"::", "hello", "hello.world"},
}, {
name: "hello_ptr_v6-alias",
rule: "|0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa" +
"^$dnsrewrite=NOERROR;PTR;hello.world.again.",
wantTrans: []string{"::", "hello.world.again"},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := stringutil.NewSet(strings.Fields(hc.Translate(tc.rule))...)
assert.True(t, stringutil.NewSet(tc.wantTrans...).Equal(got))
})
}
}
func TestHostsContainer(t *testing.T) {
const listID = 1234
testdata := os.DirFS("./testdata")
testCases := []struct {
want []*rules.DNSRewrite
name string
req urlfilter.DNSRequest
}{{
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 1),
RRType: dns.TypeA,
}, {
RCode: dns.RcodeSuccess,
Value: net.ParseIP("::1"),
RRType: dns.TypeAAAA,
}},
name: "simple",
req: urlfilter.DNSRequest{
Hostname: "simplehost",
DNSType: dns.TypeA,
},
}, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 0),
RRType: dns.TypeA,
}, {
RCode: dns.RcodeSuccess,
Value: net.ParseIP("::"),
RRType: dns.TypeAAAA,
}},
name: "hello_alias",
req: urlfilter.DNSRequest{
Hostname: "hello.world",
DNSType: dns.TypeA,
},
}, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 0),
RRType: dns.TypeA,
}, {
RCode: dns.RcodeSuccess,
Value: net.ParseIP("::"),
RRType: dns.TypeAAAA,
}},
name: "other_line_alias",
req: urlfilter.DNSRequest{
Hostname: "hello.world.again",
DNSType: dns.TypeA,
},
}, {
want: []*rules.DNSRewrite{},
name: "hello_subdomain",
req: urlfilter.DNSRequest{
Hostname: "say.hello",
DNSType: dns.TypeA,
},
}, {
want: []*rules.DNSRewrite{},
name: "hello_alias_subdomain",
req: urlfilter.DNSRequest{
Hostname: "say.hello.world",
DNSType: dns.TypeA,
},
}, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeA,
Value: net.IPv4(1, 0, 0, 2),
}, {
RCode: dns.RcodeSuccess,
RRType: dns.TypeAAAA,
Value: net.ParseIP("::2"),
}},
name: "lots_of_aliases",
req: urlfilter.DNSRequest{
Hostname: "for.testing",
DNSType: dns.TypeA,
},
}, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypePTR,
Value: "simplehost.",
}},
name: "reverse",
req: urlfilter.DNSRequest{
Hostname: "1.0.0.1.in-addr.arpa",
DNSType: dns.TypePTR,
},
}, {
want: []*rules.DNSRewrite{},
name: "non-existing",
req: urlfilter.DNSRequest{
Hostname: "nonexisting",
DNSType: dns.TypeA,
},
}, {
want: nil,
name: "bad_type",
req: urlfilter.DNSRequest{
Hostname: "1.0.0.1.in-addr.arpa",
DNSType: dns.TypeSRV,
},
}, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeA,
Value: net.IPv4(4, 2, 1, 6),
}, {
RCode: dns.RcodeSuccess,
RRType: dns.TypeAAAA,
Value: net.ParseIP("::42"),
}},
name: "issue_4216_4_6",
req: urlfilter.DNSRequest{
Hostname: "domain",
DNSType: dns.TypeA,
},
}, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeA,
Value: net.IPv4(7, 5, 3, 1),
}, {
RCode: dns.RcodeSuccess,
RRType: dns.TypeA,
Value: net.IPv4(1, 3, 5, 7),
}},
name: "issue_4216_4",
req: urlfilter.DNSRequest{
Hostname: "domain4",
DNSType: dns.TypeA,
},
}, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeAAAA,
Value: net.ParseIP("::13"),
}, {
RCode: dns.RcodeSuccess,
RRType: dns.TypeAAAA,
Value: net.ParseIP("::31"),
}},
name: "issue_4216_6",
req: urlfilter.DNSRequest{
Hostname: "domain6",
DNSType: dns.TypeAAAA,
},
}}
stubWatcher := aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(name string) (err error) { return nil },
OnClose: func() (err error) { panic("not implemented") },
}
hc, err := NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res, ok := hc.MatchRequest(tc.req)
require.False(t, ok)
if tc.want == nil {
assert.Nil(t, res)
return
}
require.NotNil(t, res)
rewrites := res.DNSRewrites()
require.Len(t, rewrites, len(tc.want))
for i, rewrite := range rewrites {
require.Equal(t, listID, rewrite.FilterListID)
rw := rewrite.DNSRewrite
require.NotNil(t, rw)
assert.Equal(t, tc.want[i], rw)
}
})
}
}
func TestUniqueRules_ParseLine(t *testing.T) {
ip := net.IP{127, 0, 0, 1}
ipStr := ip.String()
testCases := []struct {
name string
line string
wantIP net.IP
wantHosts []string
}{{
name: "simple",
line: ipStr + ` hostname`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "aliases",
line: ipStr + ` hostname alias`,
wantIP: ip,
wantHosts: []string{"hostname", "alias"},
}, {
name: "invalid_line",
line: ipStr,
wantIP: nil,
wantHosts: nil,
}, {
name: "invalid_line_hostname",
line: ipStr + ` # hostname`,
wantIP: ip,
wantHosts: nil,
}, {
name: "commented_aliases",
line: ipStr + ` hostname # alias`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "whole_comment",
line: `# ` + ipStr + ` hostname`,
wantIP: nil,
wantHosts: nil,
}, {
name: "partial_comment",
line: ipStr + ` host#name`,
wantIP: ip,
wantHosts: []string{"host"},
}, {
name: "empty",
line: ``,
wantIP: nil,
wantHosts: nil,
}}
for _, tc := range testCases {
hp := hostsParser{}
t.Run(tc.name, func(t *testing.T) {
got, hosts := hp.parseLine(tc.line)
assert.True(t, tc.wantIP.Equal(got))
assert.Equal(t, tc.wantHosts, hosts)
})
}
}

View File

@@ -0,0 +1,33 @@
//go:build windows
// +build windows
package aghnet
import (
"os"
"path"
"path/filepath"
"strings"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/sys/windows"
)
func defaultHostsPaths() (paths []string) {
sysDir, err := windows.GetSystemDirectory()
if err != nil {
log.Error("getting system directory: %s", err)
return []string{}
}
// Split all the elements of the path to join them afterwards. This is
// needed to make the Windows-specific path string returned by
// windows.GetSystemDirectory to be compatible with fs.FS.
pathElems := strings.Split(sysDir, string(os.PathSeparator))
if len(pathElems) > 0 && pathElems[0] == filepath.VolumeName(sysDir) {
pathElems = pathElems[1:]
}
return []string{path.Join(append(pathElems, "drivers/etc/hosts")...)}
}

View File

@@ -8,8 +8,8 @@ import (
"github.com/AdguardTeam/golibs/log"
)
// IPVersion is a documentational alias for int. Use it when the integer means
// IP version.
// IPVersion is a alias for int for documentation purposes. Use it when the
// integer means IP version.
type IPVersion = int
// IP version constants.
@@ -25,6 +25,13 @@ type NetIface interface {
// IfaceIPAddrs returns the interface's IP addresses.
func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) {
switch ipv {
case IPVersion4, IPVersion6:
// Go on.
default:
return nil, fmt.Errorf("invalid ip version %d", ipv)
}
addrs, err := iface.Addrs()
if err != nil {
return nil, err
@@ -41,20 +48,16 @@ func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) {
continue
}
// Assume that net.(*Interface).Addrs can only return valid IPv4
// and IPv6 addresses. Thus, if it isn't an IPv4 address, it
// must be an IPv6 one.
switch ipv {
case IPVersion4:
if ip4 := ip.To4(); ip4 != nil {
// Assume that net.(*Interface).Addrs can only return valid IPv4 and
// IPv6 addresses. Thus, if it isn't an IPv4 address, it must be an
// IPv6 one.
ip4 := ip.To4()
if ipv == IPVersion4 {
if ip4 != nil {
ips = append(ips, ip4)
}
case IPVersion6:
if ip6 := ip.To4(); ip6 == nil {
ips = append(ips, ip)
}
default:
return nil, fmt.Errorf("invalid ip version %d", ipv)
} else if ip4 == nil {
ips = append(ips, ip)
}
}
@@ -67,7 +70,7 @@ func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) {
//
// It makes up to maxAttempts attempts to get the addresses if there are none,
// each time using the provided backoff. Sometimes an interface needs a few
// seconds to really ititialize.
// seconds to really initialize.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2304.
func IfaceDNSIPAddrs(
@@ -92,18 +95,20 @@ func IfaceDNSIPAddrs(
time.Sleep(backoff)
}
n--
switch len(addrs) {
case 0:
// Don't return errors in case the users want to try and enable
// the DHCP server later.
// Don't return errors in case the users want to try and enable the DHCP
// server later.
t := time.Duration(n) * backoff
log.Error("dhcpv%d: no ip for iface after %d attempts and %s", ipv, n, t)
return nil, nil
case 1:
// Some Android devices use 8.8.8.8 if there is not a secondary
// DNS server. Fix that by setting the secondary DNS address to
// the same address.
// Some Android devices use 8.8.8.8 if there is not a secondary DNS
// server. Fix that by setting the secondary DNS address to the same
// address.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/1708.
log.Debug("dhcpv%d: setting secondary dns ip to itself", ipv)
@@ -116,3 +121,11 @@ func IfaceDNSIPAddrs(
return addrs, nil
}
// interfaceName is a string containing network interface's name. The name is
// used in file walking methods.
type interfaceName string
// Use interfaceName in the OS-independent code since it's actually only used in
// several OS-dependent implementations which causes linting issues.
var _ = interfaceName("")

View File

@@ -5,13 +5,15 @@ import (
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// fakeIface is a stub implementation of aghnet.NetIface to simplify testing.
type fakeIface struct {
addrs []net.Addr
err error
addrs []net.Addr
}
// Addrs implements the NetIface interface for *fakeIface.
@@ -33,61 +35,86 @@ func TestIfaceIPAddrs(t *testing.T) {
addr6 := &net.IPNet{IP: ip6}
testCases := []struct {
name string
iface NetIface
ipv IPVersion
want []net.IP
wantErr error
iface NetIface
name string
wantErrMsg string
want []net.IP
ipv IPVersion
}{{
name: "ipv4_success",
iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil},
ipv: IPVersion4,
want: []net.IP{ip4},
wantErr: nil,
iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil},
name: "ipv4_success",
wantErrMsg: "",
want: []net.IP{ip4},
ipv: IPVersion4,
}, {
name: "ipv4_success_with_ipv6",
iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil},
ipv: IPVersion4,
want: []net.IP{ip4},
wantErr: nil,
iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil},
name: "ipv4_success_with_ipv6",
wantErrMsg: "",
want: []net.IP{ip4},
ipv: IPVersion4,
}, {
name: "ipv4_error",
iface: &fakeIface{addrs: []net.Addr{addr4}, err: errTest},
ipv: IPVersion4,
want: nil,
wantErr: errTest,
iface: &fakeIface{addrs: []net.Addr{addr4}, err: errTest},
name: "ipv4_error",
wantErrMsg: errTest.Error(),
want: nil,
ipv: IPVersion4,
}, {
name: "ipv6_success",
iface: &fakeIface{addrs: []net.Addr{addr6}, err: nil},
ipv: IPVersion6,
want: []net.IP{ip6},
wantErr: nil,
iface: &fakeIface{addrs: []net.Addr{addr6}, err: nil},
name: "ipv6_success",
wantErrMsg: "",
want: []net.IP{ip6},
ipv: IPVersion6,
}, {
name: "ipv6_success_with_ipv4",
iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil},
ipv: IPVersion6,
want: []net.IP{ip6},
wantErr: nil,
iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil},
name: "ipv6_success_with_ipv4",
wantErrMsg: "",
want: []net.IP{ip6},
ipv: IPVersion6,
}, {
name: "ipv6_error",
iface: &fakeIface{addrs: []net.Addr{addr6}, err: errTest},
ipv: IPVersion6,
want: nil,
wantErr: errTest,
iface: &fakeIface{addrs: []net.Addr{addr6}, err: errTest},
name: "ipv6_error",
wantErrMsg: errTest.Error(),
want: nil,
ipv: IPVersion6,
}, {
iface: &fakeIface{addrs: nil, err: nil},
name: "bad_proto",
wantErrMsg: "invalid ip version 10",
want: nil,
ipv: IPVersion6 + IPVersion4,
}, {
iface: &fakeIface{addrs: []net.Addr{&net.IPAddr{IP: ip4}}, err: nil},
name: "ipaddr_v4",
wantErrMsg: "",
want: []net.IP{ip4},
ipv: IPVersion4,
}, {
iface: &fakeIface{addrs: []net.Addr{&net.IPAddr{IP: ip6, Zone: ""}}, err: nil},
name: "ipaddr_v6",
wantErrMsg: "",
want: []net.IP{ip6},
ipv: IPVersion6,
}, {
iface: &fakeIface{addrs: []net.Addr{&net.UnixAddr{}}, err: nil},
name: "non-ipv4",
wantErrMsg: "",
want: nil,
ipv: IPVersion4,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, gotErr := IfaceIPAddrs(tc.iface, tc.ipv)
require.True(t, errors.Is(gotErr, tc.wantErr))
got, err := IfaceIPAddrs(tc.iface, tc.ipv)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
}
type waitingFakeIface struct {
addrs []net.Addr
err error
addrs []net.Addr
n int
}
@@ -116,11 +143,11 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
addr6 := &net.IPNet{IP: ip6}
testCases := []struct {
name string
iface NetIface
ipv IPVersion
want []net.IP
wantErr error
name string
want []net.IP
ipv IPVersion
}{{
name: "ipv4_success",
iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil},
@@ -169,12 +196,25 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
ipv: IPVersion6,
want: []net.IP{ip6, ip6},
wantErr: nil,
}, {
name: "empty",
iface: &fakeIface{addrs: nil, err: nil},
ipv: IPVersion4,
want: nil,
wantErr: nil,
}, {
name: "many",
iface: &fakeIface{addrs: []net.Addr{addr4, addr4}},
ipv: IPVersion4,
want: []net.IP{ip4, ip4},
wantErr: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, gotErr := IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0)
require.True(t, errors.Is(gotErr, tc.wantErr))
got, err := IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0)
require.ErrorIs(t, err, tc.wantErr)
assert.Equal(t, tc.want, got)
})
}

43
internal/aghnet/ipmut.go Normal file
View File

@@ -0,0 +1,43 @@
package aghnet
import (
"net"
"sync/atomic"
)
// IPMutFunc is the signature of a function which modifies the IP address
// instance. It should be safe for concurrent use.
type IPMutFunc func(ip net.IP)
// nopIPMutFunc is the IPMutFunc that does nothing.
func nopIPMutFunc(net.IP) {}
// IPMut is a type-safe wrapper of atomic.Value to store the IPMutFunc.
type IPMut struct {
f atomic.Value
}
// NewIPMut returns the new properly initialized *IPMut. The m is guaranteed to
// always store non-nil IPMutFunc which is safe to call.
func NewIPMut(f IPMutFunc) (m *IPMut) {
m = &IPMut{
f: atomic.Value{},
}
m.Store(f)
return m
}
// Store sets the IPMutFunc to return from Func. It's safe for concurrent use.
// If f is nil, the stored function is the no-op one.
func (m *IPMut) Store(f IPMutFunc) {
if f == nil {
f = nopIPMutFunc
}
m.f.Store(f)
}
// Load returns the previously stored IPMutFunc.
func (m *IPMut) Load() (f IPMutFunc) {
return m.f.Load().(IPMutFunc)
}

View File

@@ -0,0 +1,44 @@
package aghnet
import (
"net"
"testing"
"github.com/AdguardTeam/golibs/netutil"
"github.com/stretchr/testify/assert"
)
func TestIPMut(t *testing.T) {
testIPs := []net.IP{{
127, 0, 0, 1,
}, {
192, 168, 0, 1,
}, {
8, 8, 8, 8,
}}
t.Run("nil_no_mut", func(t *testing.T) {
ipmut := NewIPMut(nil)
ips := netutil.CloneIPs(testIPs)
for i := range ips {
ipmut.Load()(ips[i])
assert.True(t, ips[i].Equal(testIPs[i]))
}
})
t.Run("not_nil_mut", func(t *testing.T) {
ipmut := NewIPMut(func(ip net.IP) {
for i := range ip {
ip[i] = 0
}
})
want := netutil.IPv4Zero()
ips := netutil.CloneIPs(testIPs)
for i := range ips {
ipmut.Load()(ips[i])
assert.True(t, ips[i].Equal(want))
}
})
}

View File

@@ -20,7 +20,12 @@ type IpsetManager interface {
//
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
//
// The error is of type *aghos.UnsupportedError if the OS is not supported.
// If ipsetConf is empty, msg and err are nil. The error is of type
// *aghos.UnsupportedError if the OS is not supported.
func NewIpsetManager(ipsetConf []string) (mgr IpsetManager, err error) {
if len(ipsetConf) == 0 {
return nil, nil
}
return newIpsetMgr(ipsetConf)
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/digineo/go-ipset/v2"
"github.com/mdlayher/netlink"
"github.com/ti-mo/netfilter"
"golang.org/x/sys/unix"
)
// How to test on a real Linux machine:
@@ -42,11 +43,17 @@ import (
// newIpsetMgr returns a new Linux ipset manager.
func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) {
dial := func(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error) {
return ipset.Dial(pf, conf)
return newIpsetMgrWithDialer(ipsetConf, defaultDial)
}
// defaultDial is the default netfilter dialing function.
func defaultDial(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error) {
conn, err = ipset.Dial(pf, conf)
if err != nil {
return nil, err
}
return newIpsetMgrWithDialer(ipsetConf, dial)
return conn, nil
}
// ipsetConn is the ipset conn interface.
@@ -103,8 +110,8 @@ func (m *ipsetMgr) dialNetfilter(conf *netlink.Config) (err error) {
// The kernel API does not actually require two sockets but package
// github.com/digineo/go-ipset does.
//
// TODO(a.garipov): Perhaps we can ditch package ipset altogether and
// just use packages netfilter and netlink.
// TODO(a.garipov): Perhaps we can ditch package ipset altogether and just
// use packages netfilter and netlink.
m.ipv4Conn, err = m.dial(netfilter.ProtoIPv4, conf)
if err != nil {
return fmt.Errorf("dialing v4: %w", err)
@@ -214,6 +221,14 @@ func newIpsetMgrWithDialer(ipsetConf []string, dial ipsetDialer) (mgr IpsetManag
err = m.dialNetfilter(&netlink.Config{})
if err != nil {
if errors.Is(err, unix.EPROTONOSUPPORT) {
// The implementation doesn't support this protocol version. Just
// issue a warning.
log.Info("ipset: dialing netfilter: warning: %s", err)
return nil, nil
}
return nil, fmt.Errorf("dialing netfilter: %w", err)
}

View File

@@ -4,13 +4,11 @@ package aghnet
import (
"encoding/json"
"fmt"
"io"
"net"
"os"
"os/exec"
"runtime"
"strings"
"syscall"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -44,8 +42,7 @@ func GatewayIP(ifaceName string) net.IP {
fields := strings.Fields(string(d))
// The meaningful "ip route" command output should contain the word
// "default" at first field and default gateway IP address at third
// field.
// "default" at first field and default gateway IP address at third field.
if len(fields) < 3 || fields[0] != "default" {
return nil
}
@@ -189,79 +186,35 @@ func GetSubnet(ifaceName string) *net.IPNet {
return nil
}
// CheckPortAvailable - check if TCP port is available
func CheckPortAvailable(host net.IP, port int) error {
ln, err := net.Listen("tcp", netutil.JoinHostPort(host.String(), port))
// CheckPort checks if the port is available for binding. network is expected
// to be one of "udp" and "tcp".
func CheckPort(network string, ip net.IP, port int) (err error) {
var c io.Closer
addr := netutil.IPPort{IP: ip, Port: port}.String()
switch network {
case "tcp":
c, err = net.Listen(network, addr)
case "udp":
c, err = net.ListenPacket(network, addr)
default:
return nil
}
if err != nil {
return err
}
_ = ln.Close()
// It seems that net.Listener.Close() doesn't close file descriptors right away.
// We wait for some time and hope that this fd will be closed.
time.Sleep(100 * time.Millisecond)
return nil
return closePortChecker(c)
}
// CheckPacketPortAvailable - check if UDP port is available
func CheckPacketPortAvailable(host net.IP, port int) error {
ln, err := net.ListenPacket("udp", netutil.JoinHostPort(host.String(), port))
if err != nil {
return err
}
_ = ln.Close()
// It seems that net.Listener.Close() doesn't close file descriptors right away.
// We wait for some time and hope that this fd will be closed.
time.Sleep(100 * time.Millisecond)
return err
}
// ErrorIsAddrInUse - check if error is "address already in use"
func ErrorIsAddrInUse(err error) bool {
errOpError, ok := err.(*net.OpError)
if !ok {
// IsAddrInUse checks if err is about unsuccessful address binding.
func IsAddrInUse(err error) (ok bool) {
var sysErr syscall.Errno
if !errors.As(err, &sysErr) {
return false
}
errSyscallError, ok := errOpError.Err.(*os.SyscallError)
if !ok {
return false
}
errErrno, ok := errSyscallError.Err.(syscall.Errno)
if !ok {
return false
}
if runtime.GOOS == "windows" {
const WSAEADDRINUSE = 10048
return errErrno == WSAEADDRINUSE
}
return errErrno == syscall.EADDRINUSE
}
// SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport
// does not necessarily contain a port.
func SplitHost(hostport string) (host string, err error) {
host, _, err = net.SplitHostPort(hostport)
if err != nil {
// Check for the missing port error. If it is that error, just
// use the host as is.
//
// See the source code for net.SplitHostPort.
const missingPort = "missing port in address"
addrErr := &net.AddrError{}
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
return "", err
}
host = hostport
}
return host, nil
return isAddrInUse(sysErr)
}
// CollectAllIfacesAddrs returns the slice of all network interfaces IP

View File

@@ -18,9 +18,11 @@ func canBindPrivilegedPorts() (can bool, err error) {
}
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
const filename = "/etc/rc.conf"
const rcConfFilename = "etc/rc.conf"
return aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig).Walk(filename)
walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig)
return walker.Walk(aghos.RootDirFS(), rcConfFilename)
}
// rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to

View File

@@ -85,17 +85,17 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
iface := interfaceName(ifaceName)
for _, pair := range []struct {
for _, pair := range [...]struct {
aghos.FileWalker
filename string
}{{
FileWalker: iface.dhcpcdStaticConfig,
filename: "/etc/dhcpcd.conf",
filename: "etc/dhcpcd.conf",
}, {
FileWalker: iface.ifacesStaticConfig,
filename: "/etc/network/interfaces",
filename: "etc/network/interfaces",
}} {
has, err = pair.Walk(pair.filename)
has, err = pair.Walk(aghos.RootDirFS(), pair.filename)
if err != nil {
return false, err
}

View File

@@ -12,8 +12,6 @@ import (
"github.com/stretchr/testify/require"
)
const nl = "\n"
func TestDHCPCDStaticConfig(t *testing.T) {
const iface interfaceName = `wlan0`

View File

@@ -18,9 +18,9 @@ func canBindPrivilegedPorts() (can bool, err error) {
}
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
filename := fmt.Sprintf("/etc/hostname.%s", ifaceName)
filename := fmt.Sprintf("etc/hostname.%s", ifaceName)
return aghos.FileWalker(hostnameIfStaticConfig).Walk(filename)
return aghos.FileWalker(hostnameIfStaticConfig).Walk(aghos.RootDirFS(), filename)
}
// hostnameIfStaticConfig checks if the interface is configured by

View File

@@ -1,20 +0,0 @@
//go:build !(linux || darwin || freebsd || openbsd)
// +build !linux,!darwin,!freebsd,!openbsd
package aghnet
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(string) (ok bool, err error) {
return false, aghos.Unsupported("checking static ip")
}
func ifaceSetStaticIP(string) (err error) {
return aghos.Unsupported("setting static ip")
}

View File

@@ -4,16 +4,31 @@ import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetValidNetInterfacesForWeb(t *testing.T) {
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
func TestGetInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb()
require.NoErrorf(t, err, "cannot get net interfaces: %s", err)
require.NotEmpty(t, ifaces, "no net interfaces found")
require.NoError(t, err)
require.NotEmpty(t, ifaces)
for _, iface := range ifaces {
require.NotEmptyf(t, iface.Addresses, "no addresses found for %s", iface.Name)
t.Run(iface.Name, func(t *testing.T) {
require.NotEmpty(t, iface.Addresses)
for _, ip := range iface.Addresses {
ifaceName := GetInterfaceByIP(ip)
require.Equal(t, iface.Name, ifaceName)
}
})
}
}
@@ -64,3 +79,49 @@ func TestBroadcastFromIPNet(t *testing.T) {
})
}
}
func TestCheckPort(t *testing.T) {
t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
ipp := netutil.IPPortFromAddr(l.Addr())
require.NotNil(t, ipp)
require.NotNil(t, ipp.IP)
require.NotZero(t, ipp.Port)
err = CheckPort("tcp", ipp.IP, ipp.Port)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("udp_bound", func(t *testing.T) {
conn, err := net.ListenPacket("udp", "127.0.0.1:")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, conn.Close)
ipp := netutil.IPPortFromAddr(conn.LocalAddr())
require.NotNil(t, ipp)
require.NotNil(t, ipp.IP)
require.NotZero(t, ipp.Port)
err = CheckPort("udp", ipp.IP, ipp.Port)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("bad_network", func(t *testing.T) {
err := CheckPort("bad_network", nil, 0)
assert.NoError(t, err)
})
t.Run("can_bind", func(t *testing.T) {
err := CheckPort("udp", net.IP{0, 0, 0, 0}, 0)
assert.NoError(t, err)
})
}

View File

@@ -1,8 +1,20 @@
//go:build openbsd || freebsd || linux
// +build openbsd freebsd linux
//go:build openbsd || freebsd || linux || darwin
// +build openbsd freebsd linux darwin
package aghnet
// interfaceName is a string containing network interface's name. The name is
// used in file walking methods.
type interfaceName string
import (
"io"
"syscall"
"github.com/AdguardTeam/golibs/errors"
)
// closePortChecker closes c. c must be non-nil.
func closePortChecker(c io.Closer) (err error) {
return c.Close()
}
func isAddrInUse(err syscall.Errno) (ok bool) {
return errors.Is(err, syscall.EADDRINUSE)
}

View File

@@ -0,0 +1,46 @@
//go:build !(linux || darwin || freebsd || openbsd)
// +build !linux,!darwin,!freebsd,!openbsd
package aghnet
import (
"io"
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/sys/windows"
)
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(string) (ok bool, err error) {
return false, aghos.Unsupported("checking static ip")
}
func ifaceSetStaticIP(string) (err error) {
return aghos.Unsupported("setting static ip")
}
// closePortChecker closes c. c must be non-nil.
func closePortChecker(c io.Closer) (err error) {
if err = c.Close(); err != nil {
return err
}
// It seems that net.Listener.Close() doesn't close file descriptors right
// away. We wait for some time and hope that this fd will be closed.
//
// TODO(e.burkov): Investigate the purpose of the line and perhaps use more
// reliable approach.
time.Sleep(100 * time.Millisecond)
return nil
}
func isAddrInUse(err syscall.Errno) (ok bool) {
return errors.Is(err, windows.WSAEADDRINUSE)
}

View File

@@ -6,16 +6,18 @@ import (
// SubnetDetector describes IP address properties.
type SubnetDetector struct {
// spNets is the slice of special-purpose address registries as defined
// by RFC-6890 (https://tools.ietf.org/html/rfc6890).
// spNets is the collection of special-purpose address registries as defined
// by RFC 6890.
spNets []*net.IPNet
// locServedNets is the slice of locally-served networks as defined by
// RFC-6303 (https://tools.ietf.org/html/rfc6303).
// locServedNets is the collection of locally-served networks as defined by
// RFC 6303.
locServedNets []*net.IPNet
}
// NewSubnetDetector returns a new IP detector.
//
// TODO(a.garipov): Decide whether an error is actually needed.
func NewSubnetDetector() (snd *SubnetDetector, err error) {
spNets := []string{
// "This" network.

View File

@@ -79,8 +79,8 @@ func TestSystemResolvers_DialFunc(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conn, err := imp.dialFunc(context.Background(), "", tc.address)
require.Nil(t, conn)
assert.ErrorIs(t, err, tc.want)
})
}

38
internal/aghnet/testdata/etc_hosts vendored Normal file
View File

@@ -0,0 +1,38 @@
#
# Test /etc/hosts file
#
1.0.0.1 simplehost
1.0.0.0 hello hello.world
# See https://github.com/AdguardTeam/AdGuardHome/issues/3846.
1.0.0.2 a.whole lot.of aliases for.testing
# See https://github.com/AdguardTeam/AdGuardHome/issues/3946.
1.0.0.3 *
1.0.0.4 *.com
# See https://github.com/AdguardTeam/AdGuardHome/issues/4079.
1.0.0.0 hello.world.again
# Duplicates of a main host and an alias.
1.0.0.1 simplehost
1.0.0.0 hello.world
# Same for IPv6.
::1 simplehost
:: hello hello.world
::2 a.whole lot.of aliases for.testing
::3 *
::4 *.com
:: hello.world.again
::1 simplehost
:: hello.world
# See https://github.com/AdguardTeam/AdGuardHome/issues/4216.
4.2.1.6 domain domain.alias
::42 domain.alias domain
1.3.5.7 domain4 domain4.alias
7.5.3.1 domain4.alias domain4
::13 domain6 domain6.alias
::31 domain6.alias domain6

View File

@@ -3,10 +3,8 @@ package aghos
import (
"fmt"
"io"
"os"
"path/filepath"
"io/fs"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/stringutil"
)
@@ -14,27 +12,28 @@ import (
// FileWalker is the signature of a function called for files in the file tree.
// As opposed to filepath.Walk it only walk the files (not directories) matching
// the provided pattern and those returned by function itself. All patterns
// should be valid for filepath.Glob. If cont is false, the walking terminates.
// Each opened file is also limited for reading to MaxWalkedFileSize.
// should be valid for fs.Glob. If FileWalker returns false for cont then
// walking terminates. Prefer using bufio.Scanner to read the r since the input
// is not limited.
//
// TODO(e.burkov): Consider moving to the separate package like pathutil.
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
//
// TODO(e.burkov): Think about passing filename or any additional data.
type FileWalker func(r io.Reader) (patterns []string, cont bool, err error)
// MaxWalkedFileSize is the maximum length of the file that FileWalker can
// check.
const MaxWalkedFileSize = 1024 * 1024
// checkFile tries to open and process a single file located on sourcePath.
func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, err error) {
var f *os.File
f, err = os.Open(sourcePath)
// checkFile tries to open and process a single file located on sourcePath in
// the specified fsys. The path is skipped if it's a directory.
func checkFile(
fsys fs.FS,
c FileWalker,
sourcePath string,
) (patterns []string, cont bool, err error) {
var f fs.File
f, err = fsys.Open(sourcePath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
// Ignore non-existing files since this may only happen
// when the file was removed after filepath.Glob matched
// it.
if errors.Is(err, fs.ErrNotExist) {
// Ignore non-existing files since this may only happen when the
// file was removed after filepath.Glob matched it.
return nil, true, nil
}
@@ -42,23 +41,28 @@ func checkFile(c FileWalker, sourcePath string) (patterns []string, cont bool, e
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()
var r io.Reader
// Ignore the error since LimitReader function returns error only if
// passed limit value is less than zero, but the constant used.
//
// TODO(e.burkov): Make variable.
r, _ = aghio.LimitReader(f, MaxWalkedFileSize)
var fi fs.FileInfo
if fi, err = f.Stat(); err != nil {
return nil, true, err
} else if fi.IsDir() {
// Skip the directories.
return nil, true, nil
}
return c(r)
return c(f)
}
// handlePatterns parses the patterns and ignores duplicates using srcSet.
// srcSet must be non-nil.
func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, err error) {
// handlePatterns parses the patterns in fsys and ignores duplicates using
// srcSet. srcSet must be non-nil.
func handlePatterns(
fsys fs.FS,
srcSet *stringutil.Set,
patterns ...string,
) (sub []string, err error) {
sub = make([]string, 0, len(patterns))
for _, p := range patterns {
var matches []string
matches, err = filepath.Glob(p)
matches, err = fs.Glob(fsys, p)
if err != nil {
// Enrich error with the pattern because filepath.Glob
// doesn't do it.
@@ -78,14 +82,14 @@ func handlePatterns(srcSet *stringutil.Set, patterns ...string) (sub []string, e
return sub, nil
}
// Walk starts walking the files defined by initPattern. It only returns true
// if c signed to stop walking.
func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
// The slice of sources keeps the order in which the files are walked
// since srcSet.Values() returns strings in undefined order.
// Walk starts walking the files in fsys defined by patterns from initial.
// It only returns true if fw signed to stop walking.
func (fw FileWalker) Walk(fsys fs.FS, initial ...string) (ok bool, err error) {
// The slice of sources keeps the order in which the files are walked since
// srcSet.Values() returns strings in undefined order.
srcSet := stringutil.NewSet()
var src []string
src, err = handlePatterns(srcSet, initPattern)
src, err = handlePatterns(fsys, srcSet, initial...)
if err != nil {
return false, err
}
@@ -97,7 +101,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
var patterns []string
var cont bool
filename = src[i]
patterns, cont, err = checkFile(c, src[i])
patterns, cont, err = checkFile(fsys, fw, src[i])
if err != nil {
return false, err
}
@@ -107,7 +111,7 @@ func (c FileWalker) Walk(initPattern string) (ok bool, err error) {
}
var subsrc []string
subsrc, err = handlePatterns(srcSet, patterns...)
subsrc, err = handlePatterns(fsys, srcSet, patterns...)
if err != nil {
return false, err
}

View File

@@ -4,56 +4,19 @@ import (
"bufio"
"io"
"io/fs"
"os"
"path/filepath"
"path"
"testing"
"testing/fstest"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testFSDir maps entries' names to entries which should either be a testFSDir
// or byte slice.
type testFSDir map[string]interface{}
// testFSGen is used to generate a temporary filesystem consisting of
// directories and plain text files from itself.
type testFSGen testFSDir
// gen returns the name of top directory of the generated filesystem.
func (g testFSGen) gen(t *testing.T) (dirName string) {
t.Helper()
dirName = t.TempDir()
g.rangeThrough(t, dirName)
return dirName
}
func (g testFSGen) rangeThrough(t *testing.T, dirName string) {
const perm fs.FileMode = 0o777
for k, e := range g {
switch e := e.(type) {
case []byte:
require.NoError(t, os.WriteFile(filepath.Join(dirName, k), e, perm))
case testFSDir:
newDir := filepath.Join(dirName, k)
require.NoError(t, os.Mkdir(newDir, perm))
testFSGen(e).rangeThrough(t, newDir)
default:
t.Fatalf("unexpected entry type %T", e)
}
}
}
func TestFileWalker_Walk(t *testing.T) {
const attribute = `000`
makeFileWalker := func(dirName string) (fw FileWalker) {
makeFileWalker := func(_ string) (fw FileWalker) {
return func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
@@ -63,7 +26,7 @@ func TestFileWalker_Walk(t *testing.T) {
}
if len(line) != 0 {
patterns = append(patterns, filepath.Join(dirName, line))
patterns = append(patterns, path.Join(".", line))
}
}
@@ -74,136 +37,150 @@ func TestFileWalker_Walk(t *testing.T) {
const nl = "\n"
testCases := []struct {
name string
testFS testFSGen
testFS fstest.MapFS
want assert.BoolAssertionFunc
initPattern string
want bool
name string
}{{
name: "simple",
testFS: testFSGen{
"simple_0001.txt": []byte(attribute + nl),
testFS: fstest.MapFS{
"simple_0001.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
},
initPattern: "simple_0001.txt",
want: true,
want: assert.True,
}, {
name: "chain",
testFS: testFSGen{
"chain_0001.txt": []byte(`chain_0002.txt` + nl),
"chain_0002.txt": []byte(`chain_0003.txt` + nl),
"chain_0003.txt": []byte(attribute + nl),
testFS: fstest.MapFS{
"chain_0001.txt": &fstest.MapFile{Data: []byte(`chain_0002.txt` + nl)},
"chain_0002.txt": &fstest.MapFile{Data: []byte(`chain_0003.txt` + nl)},
"chain_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
},
initPattern: "chain_0001.txt",
want: true,
want: assert.True,
}, {
name: "several",
testFS: testFSGen{
"several_0001.txt": []byte(`several_*` + nl),
"several_0002.txt": []byte(`several_0001.txt` + nl),
"several_0003.txt": []byte(attribute + nl),
testFS: fstest.MapFS{
"several_0001.txt": &fstest.MapFile{Data: []byte(`several_*` + nl)},
"several_0002.txt": &fstest.MapFile{Data: []byte(`several_0001.txt` + nl)},
"several_0003.txt": &fstest.MapFile{Data: []byte(attribute + nl)},
},
initPattern: "several_0001.txt",
want: true,
want: assert.True,
}, {
name: "no",
testFS: testFSGen{
"no_0001.txt": []byte(nl),
"no_0002.txt": []byte(nl),
"no_0003.txt": []byte(nl),
testFS: fstest.MapFS{
"no_0001.txt": &fstest.MapFile{Data: []byte(nl)},
"no_0002.txt": &fstest.MapFile{Data: []byte(nl)},
"no_0003.txt": &fstest.MapFile{Data: []byte(nl)},
},
initPattern: "no_*",
want: false,
want: assert.False,
}, {
name: "subdirectory",
testFS: testFSGen{
"dir": testFSDir{
"subdir_0002.txt": []byte(attribute + nl),
testFS: fstest.MapFS{
path.Join("dir", "subdir_0002.txt"): &fstest.MapFile{
Data: []byte(attribute + nl),
},
"subdir_0001.txt": []byte(`dir/*`),
"subdir_0001.txt": &fstest.MapFile{Data: []byte(`dir/*`)},
},
initPattern: "subdir_0001.txt",
want: true,
want: assert.True,
}}
for _, tc := range testCases {
testDir := tc.testFS.gen(t)
fw := makeFileWalker(testDir)
fw := makeFileWalker("")
t.Run(tc.name, func(t *testing.T) {
ok, err := fw.Walk(filepath.Join(testDir, tc.initPattern))
ok, err := fw.Walk(tc.testFS, tc.initPattern)
require.NoError(t, err)
assert.Equal(t, tc.want, ok)
tc.want(t, ok)
})
}
t.Run("pattern_malformed", func(t *testing.T) {
ok, err := makeFileWalker("").Walk("[]")
f := fstest.MapFS{}
ok, err := makeFileWalker("").Walk(f, "[]")
require.Error(t, err)
assert.False(t, ok)
assert.ErrorIs(t, err, filepath.ErrBadPattern)
assert.ErrorIs(t, err, path.ErrBadPattern)
})
t.Run("bad_filename", func(t *testing.T) {
dir := testFSGen{
"bad_filename.txt": []byte("[]"),
}.gen(t)
fw := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
const filename = "bad_filename.txt"
f := fstest.MapFS{
filename: &fstest.MapFile{Data: []byte("[]")},
}
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
patterns = append(patterns, s.Text())
}
return patterns, true, s.Err()
})
ok, err := fw.Walk(filepath.Join(dir, "bad_filename.txt"))
}).Walk(f, filename)
require.Error(t, err)
assert.False(t, ok)
assert.ErrorIs(t, err, filepath.ErrBadPattern)
assert.ErrorIs(t, err, path.ErrBadPattern)
})
t.Run("itself_error", func(t *testing.T) {
const rerr errors.Error = "returned error"
dir := testFSGen{
"mockfile.txt": []byte(`mockdata`),
}.gen(t)
f := fstest.MapFS{
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
}
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
return nil, true, rerr
}).Walk(filepath.Join(dir, "*"))
require.Error(t, err)
require.False(t, ok)
}).Walk(f, "*")
require.ErrorIs(t, err, rerr)
assert.ErrorIs(t, err, rerr)
assert.False(t, ok)
})
}
type errFS struct {
fs.GlobFS
}
const errErrFSOpen errors.Error = "this error is always returned"
func (efs *errFS) Open(name string) (fs.File, error) {
return nil, errErrFSOpen
}
func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}
t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(nil, "lol")
_, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err)
assert.True(t, ok)
})
t.Run("invalid_argument", func(t *testing.T) {
const badPath = "\x00"
_, ok, err := checkFile(nil, badPath)
require.Error(t, err)
_, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errErrFSOpen)
assert.False(t, ok)
// TODO(e.burkov): Use assert.ErrorsIs within the error from
// less platform-dependent package instead of syscall.EINVAL.
//
// See https://github.com/golang/go/issues/46849 and
// https://github.com/golang/go/issues/30322.
pathErr := &os.PathError{}
require.ErrorAs(t, err, &pathErr)
assert.Equal(t, "open", pathErr.Op)
assert.Equal(t, badPath, pathErr.Path)
})
t.Run("ignore_dirs", func(t *testing.T) {
const dirName = "dir"
testFS := fstest.MapFS{
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
}
patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)
assert.Empty(t, patterns)
assert.True(t, ok)
})
}

133
internal/aghos/fswatcher.go Normal file
View File

@@ -0,0 +1,133 @@
package aghos
import (
"fmt"
"io"
"io/fs"
"path/filepath"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/fsnotify/fsnotify"
)
// event is a convenient alias for an empty struct to signal that watching
// event happened.
type event = struct{}
// FSWatcher tracks all the fyle system events and notifies about those.
//
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
type FSWatcher interface {
io.Closer
// Events should return a read-only channel which notifies about events.
Events() (e <-chan event)
// Add should check if the file named name is accessible and starts tracking
// it.
Add(name string) (err error)
}
// osWatcher tracks the file system provided by the OS.
type osWatcher struct {
// w is the actual notifier that is handled by osWatcher.
w *fsnotify.Watcher
// events is the channel to notify.
events chan event
}
const (
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
// methods.
osWatcherPref = "os watcher"
)
// NewOSWritesWatcher creates FSWatcher that tracks the real file system of the
// OS and notifies only about writing events.
func NewOSWritesWatcher() (w FSWatcher, err error) {
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
var watcher *fsnotify.Watcher
watcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("creating watcher: %w", err)
}
fsw := &osWatcher{
w: watcher,
events: make(chan event, 1),
}
go fsw.handleErrors()
go fsw.handleEvents()
return fsw, nil
}
// handleErrors handles accompanying errors. It used to be called in a separate
// goroutine.
func (w *osWatcher) handleErrors() {
defer log.OnPanic(fmt.Sprintf("%s: handling errors", osWatcherPref))
for err := range w.w.Errors {
log.Error("%s: %s", osWatcherPref, err)
}
}
// Events implements the FSWatcher interface for *osWatcher.
func (w *osWatcher) Events() (e <-chan event) {
return w.events
}
// Add implements the FSWatcher interface for *osWatcher.
//
// TODO(e.burkov): Make it accept non-existing files to detect it's creating.
func (w *osWatcher) Add(name string) (err error) {
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
if _, err = fs.Stat(RootDirFS(), name); err != nil {
return fmt.Errorf("checking file %q: %w", name, err)
}
return w.w.Add(filepath.Join("/", name))
}
// Close implements the FSWatcher interface for *osWatcher.
func (w *osWatcher) Close() (err error) {
return w.w.Close()
}
// handleEvents notifies about the received file system's event if needed. It
// used to be called in a separate goroutine.
func (w *osWatcher) handleEvents() {
defer log.OnPanic(fmt.Sprintf("%s: handling events", osWatcherPref))
defer close(w.events)
ch := w.w.Events
for e := range ch {
if e.Op&fsnotify.Write == 0 {
continue
}
// Skip the following events assuming that sometimes the same event
// occurrs several times.
for ok := true; ok; {
select {
case _, ok = <-ch:
// Go on.
default:
ok = false
}
}
select {
case w.events <- event{}:
// Go on.
default:
log.Debug("%s: events buffer is full", osWatcherPref)
}
}
}

View File

@@ -7,6 +7,8 @@ import (
"bufio"
"fmt"
"io"
"io/fs"
"os"
"os/exec"
"path"
"runtime"
@@ -89,7 +91,7 @@ func PIDByCommand(command string, except ...int) (pid int, err error) {
}
var instNum int
pid, instNum, err = parsePSOutput(stdout, command, except...)
pid, instNum, err = parsePSOutput(stdout, command, except)
if err != nil {
return 0, err
}
@@ -116,16 +118,15 @@ func PIDByCommand(command string, except ...int) (pid int, err error) {
}
// parsePSOutput scans the output of ps searching the largest PID of the process
// associated with cmdName ignoring PIDs from ignore. Valid r's line shoud be
// like:
// associated with cmdName ignoring PIDs from ignore. A valid line from
// r should look like these:
//
// 123 ./example-cmd
// 1230 some/base/path/example-cmd
// 3210 example-cmd
//
func parsePSOutput(r io.Reader, cmdName string, ignore ...int) (largest, instNum int, err error) {
func parsePSOutput(r io.Reader, cmdName string, ignore []int) (largest, instNum int, err error) {
s := bufio.NewScanner(r)
ScanLoop:
for s.Scan() {
fields := strings.Fields(s.Text())
if len(fields) != 2 || path.Base(fields[1]) != cmdName {
@@ -133,16 +134,10 @@ ScanLoop:
}
cur, aerr := strconv.Atoi(fields[0])
if aerr != nil || cur < 0 {
if aerr != nil || cur < 0 || intIn(cur, ignore) {
continue
}
for _, pid := range ignore {
if cur == pid {
continue ScanLoop
}
}
instNum++
if cur > largest {
largest = cur
@@ -155,7 +150,25 @@ ScanLoop:
return largest, instNum, nil
}
// intIn returns true if nums contains n.
func intIn(n int, nums []int) (ok bool) {
for _, nn := range nums {
if n == nn {
return true
}
}
return false
}
// IsOpenWrt returns true if host OS is OpenWrt.
func IsOpenWrt() (ok bool) {
return isOpenWrt()
}
// RootDirFS returns the fs.FS rooted at the operating system's root.
func RootDirFS() (fsys fs.FS) {
// Use empty string since os.DirFS implicitly prepends a slash to it. This
// behavior is undocumented but it currently works.
return os.DirFS("")
}

View File

@@ -26,6 +26,8 @@ func haveAdminRights() (bool, error) {
}
func isOpenWrt() (ok bool) {
const etcReleasePattern = "etc/*release*"
var err error
ok, err = FileWalker(func(r io.Reader) (_ []string, cont bool, err error) {
const osNameData = "openwrt"
@@ -39,7 +41,7 @@ func isOpenWrt() (ok bool) {
}
return nil, !stringutil.ContainsFold(string(data), osNameData), nil
}).Walk("/etc/*release*")
}).Walk(RootDirFS(), etcReleasePattern)
return err == nil && ok
}

View File

@@ -63,7 +63,7 @@ func TestLargestLabeled(t *testing.T) {
r := bytes.NewReader(tc.data)
t.Run(tc.name, func(t *testing.T) {
pid, instNum, err := parsePSOutput(r, comm)
pid, instNum, err := parsePSOutput(r, comm, nil)
require.NoError(t, err)
assert.Equal(t, tc.wantPID, pid)
@@ -76,7 +76,7 @@ func TestLargestLabeled(t *testing.T) {
require.NoError(t, err)
target := &aghio.LimitReachedError{}
_, _, err = parsePSOutput(lr, "")
_, _, err = parsePSOutput(lr, "", nil)
require.ErrorAs(t, err, &target)
assert.EqualValues(t, 0, target.Limit)
@@ -89,7 +89,7 @@ func TestLargestLabeled(t *testing.T) {
`3` + comm + nl,
))
pid, instances, err := parsePSOutput(r, comm, 1, 3)
pid, instances, err := parsePSOutput(r, comm, []int{1, 3})
require.NoError(t, err)
assert.Equal(t, 2, pid)

View File

@@ -20,17 +20,19 @@ func DiscardLogOutput(m *testing.M) {
// ReplaceLogWriter moves logger output to w and uses Cleanup method of t to
// revert changes.
func ReplaceLogWriter(t *testing.T, w io.Writer) {
stdWriter := log.Writer()
t.Cleanup(func() {
log.SetOutput(stdWriter)
})
func ReplaceLogWriter(t testing.TB, w io.Writer) {
t.Helper()
prev := log.Writer()
t.Cleanup(func() { log.SetOutput(prev) })
log.SetOutput(w)
}
// ReplaceLogLevel sets logging level to l and uses Cleanup method of t to
// revert changes.
func ReplaceLogLevel(t *testing.T, l log.Level) {
func ReplaceLogLevel(t testing.TB, l log.Level) {
t.Helper()
switch l {
case log.INFO, log.DEBUG, log.ERROR:
// Go on.
@@ -38,9 +40,7 @@ func ReplaceLogLevel(t *testing.T, l log.Level) {
t.Fatalf("wrong l value (must be one of %v, %v, %v)", log.INFO, log.DEBUG, log.ERROR)
}
stdLevel := log.GetLevel()
t.Cleanup(func() {
log.SetLevel(stdLevel)
})
prev := log.GetLevel()
t.Cleanup(func() { log.SetLevel(prev) })
log.SetLevel(l)
}

View File

@@ -0,0 +1,23 @@
package aghtest
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
type FSWatcher struct {
OnEvents func() (e <-chan struct{})
OnAdd func(name string) (err error)
OnClose func() (err error)
}
// Events implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
return w.OnEvents()
}
// Add implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Close implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}

View File

@@ -0,0 +1,46 @@
package aghtest
import "io/fs"
// type check
var _ fs.FS = &FS{}
// FS is a mock fs.FS implementation to use in tests.
type FS struct {
OnOpen func(name string) (fs.File, error)
}
// Open implements the fs.FS interface for *FS.
func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}
// type check
var _ fs.StatFS = &StatFS{}
// StatFS is a mock fs.StatFS implementation to use in tests.
type StatFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnStat func(name string) (fs.FileInfo, error)
}
// Stat implements the fs.StatFS interface for *StatFS.
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}
// type check
var _ fs.GlobFS = &GlobFS{}
// GlobFS is a mock fs.GlobFS implementation to use in tests.
type GlobFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnGlob func(pattern string) ([]string, error)
}
// Glob implements the fs.GlobFS interface for *GlobFS.
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}

View File

@@ -11,10 +11,10 @@ import (
"github.com/miekg/dns"
)
// TestUpstream is a mock of real upstream.
type TestUpstream struct {
// Upstream is a mock implementation of upstream.Upstream.
type Upstream struct {
// CName is a map of hostname to canonical name.
CName map[string]string
CName map[string][]string
// IPv4 is a map of hostname to IPv4.
IPv4 map[string][]net.IP
// IPv6 is a map of hostname to IPv6.
@@ -25,78 +25,45 @@ type TestUpstream struct {
Addr string
}
// Exchange implements upstream.Upstream interface for *TestUpstream.
// Exchange implements the upstream.Upstream interface for *Upstream.
//
// TODO(a.garipov): Split further into handlers.
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = &dns.Msg{}
resp.SetReply(m)
func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = new(dns.Msg).SetReply(m)
if len(m.Question) == 0 {
return nil, fmt.Errorf("question should not be empty")
}
name := m.Question[0].Name
if cname, ok := u.CName[name]; ok {
ans := &dns.CNAME{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeCNAME,
},
q := m.Question[0]
name := q.Name
for _, cname := range u.CName[name] {
resp.Answer = append(resp.Answer, &dns.CNAME{
Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME},
Target: cname,
}
resp.Answer = append(resp.Answer, ans)
})
}
rrType := m.Question[0].Qtype
qtype := q.Qtype
hdr := dns.RR_Header{
Name: name,
Rrtype: rrType,
Rrtype: qtype,
}
var names []string
var ips []net.IP
switch m.Question[0].Qtype {
switch qtype {
case dns.TypeA:
ips = u.IPv4[name]
for _, ip := range u.IPv4[name] {
resp.Answer = append(resp.Answer, &dns.A{Hdr: hdr, A: ip})
}
case dns.TypeAAAA:
ips = u.IPv6[name]
for _, ip := range u.IPv6[name] {
resp.Answer = append(resp.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
}
case dns.TypePTR:
names = u.Reverse[name]
}
for _, ip := range ips {
var ans dns.RR
if rrType == dns.TypeA {
ans = &dns.A{
Hdr: hdr,
A: ip,
}
resp.Answer = append(resp.Answer, ans)
continue
for _, name := range u.Reverse[name] {
resp.Answer = append(resp.Answer, &dns.PTR{Hdr: hdr, Ptr: name})
}
ans = &dns.AAAA{
Hdr: hdr,
AAAA: ip,
}
resp.Answer = append(resp.Answer, ans)
}
for _, n := range names {
ans := &dns.PTR{
Hdr: hdr,
Ptr: n,
}
resp.Answer = append(resp.Answer, ans)
}
if len(resp.Answer) == 0 {
resp.SetRcode(m, dns.RcodeNameError)
}
@@ -104,8 +71,8 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
return resp, nil
}
// Address implements upstream.Upstream interface for *TestUpstream.
func (u *TestUpstream) Address() string {
// Address implements upstream.Upstream interface for *Upstream.
func (u *Upstream) Address() string {
return u.Addr
}

30
internal/aghtls/aghtls.go Normal file
View File

@@ -0,0 +1,30 @@
// Package aghtls contains utilities for work with TLS.
package aghtls
import "crypto/tls"
// SaferCipherSuites returns a set of default cipher suites with vulnerable and
// weak cipher suites removed.
func SaferCipherSuites() (safe []uint16) {
for _, s := range tls.CipherSuites() {
switch s.ID {
case
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
// Less safe 3DES and CBC suites, go on.
default:
safe = append(safe, s.ID)
}
}
return safe
}

View File

@@ -5,9 +5,9 @@ package dhcpd
import (
"net"
"strings"
"testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/insomniacslk/dhcp/dhcpv4"
@@ -45,7 +45,7 @@ func TestDHCPConn_WriteTo_common(t *testing.T) {
n, err := conn.WriteTo(nil, &unexpectedAddrType{})
require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "peer is of unexpected type"))
testutil.AssertErrorMsg(t, "peer is of unexpected type *dhcpd.unexpectedAddrType", err)
assert.Zero(t, n)
})
}

View File

@@ -119,23 +119,28 @@ func (l *Lease) UnmarshalJSON(data []byte) (err error) {
return nil
}
// ServerConfig - DHCP server configuration
// field ordering is important -- yaml fields will mirror ordering from here
// ServerConfig is the configuration for the DHCP server. The order of YAML
// fields is important, since the YAML configuration file follows it.
type ServerConfig struct {
Enabled bool `yaml:"enabled"`
InterfaceName string `yaml:"interface_name"`
Conf4 V4ServerConf `yaml:"dhcpv4"`
Conf6 V6ServerConf `yaml:"dhcpv6"`
WorkDir string `yaml:"-"`
DBFilePath string `yaml:"-"` // path to DB file
// Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"`
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
Enabled bool `yaml:"enabled"`
InterfaceName string `yaml:"interface_name"`
// LocalDomainName is the domain name used for DHCP hosts. For example,
// a DHCP client with the hostname "myhost" can be addressed as "myhost.lan"
// when LocalDomainName is "lan".
LocalDomainName string `yaml:"local_domain_name"`
Conf4 V4ServerConf `yaml:"dhcpv4"`
Conf6 V6ServerConf `yaml:"dhcpv6"`
WorkDir string `yaml:"-"`
DBFilePath string `yaml:"-"`
}
// OnLeaseChangedT is a callback for lease changes.
@@ -156,7 +161,9 @@ type Server struct {
srv4 DHCPServer
srv6 DHCPServer
conf ServerConfig
// TODO(a.garipov): Either create a separate type for the internal config or
// just put the config values into Server.
conf *ServerConfig
// Called when the leases DB is modified
onLeaseChanged []OnLeaseChangedT
@@ -181,14 +188,21 @@ type ServerInterface interface {
}
// Create - create object
func Create(conf ServerConfig) (s *Server, err error) {
s = &Server{}
func Create(conf *ServerConfig) (s *Server, err error) {
s = &Server{
conf: &ServerConfig{
ConfigModified: conf.ConfigModified,
s.conf.Enabled = conf.Enabled
s.conf.InterfaceName = conf.InterfaceName
s.conf.HTTPRegister = conf.HTTPRegister
s.conf.ConfigModified = conf.ConfigModified
s.conf.DBFilePath = filepath.Join(conf.WorkDir, dbFilename)
HTTPRegister: conf.HTTPRegister,
Enabled: conf.Enabled,
InterfaceName: conf.InterfaceName,
LocalDomainName: conf.LocalDomainName,
DBFilePath: filepath.Join(conf.WorkDir, dbFilename),
},
}
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
if runtime.GOOS == "windows" {
@@ -305,6 +319,7 @@ func (s *Server) notify(flags int) {
func (s *Server) WriteDiskConfig(c *ServerConfig) {
c.Enabled = s.conf.Enabled
c.InterfaceName = s.conf.InterfaceName
c.LocalDomainName = s.conf.LocalDomainName
s.srv4.WriteDiskConfig4(&c.Conf4)
s.srv6.WriteDiskConfig6(&c.Conf6)
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -26,7 +27,7 @@ func testNotify(flags uint32) {
func TestDB(t *testing.T) {
var err error
s := Server{
conf: ServerConfig{
conf: &ServerConfig{
DBFilePath: dbFilename,
},
}
@@ -67,9 +68,7 @@ func TestDB(t *testing.T) {
err = s.dbStore()
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, os.Remove(dbFilename))
})
testutil.CleanupAndRequireSuccess(t, func() (err error) { return os.Remove(dbFilename) })
err = s.srv4.ResetLeases(nil)
require.NoError(t, err)
@@ -138,6 +137,49 @@ func TestNormalizeLeases(t *testing.T) {
assert.Equal(t, leases[2].HWAddr, dynLeases[1].HWAddr)
}
func TestV4Server_badRange(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
gatewayIP net.IP
subnetMask net.IP
}{{
name: "gateway_in_range",
wantErrMsg: "dhcpv4: gateway ip 192.168.10.120 in the ip range: " +
"192.168.10.20-192.168.10.200",
gatewayIP: net.IP{192, 168, 10, 120},
subnetMask: net.IP{255, 255, 255, 0},
}, {
name: "outside_range_start",
wantErrMsg: "dhcpv4: range start 192.168.10.20 is outside network " +
"192.168.10.1/28",
gatewayIP: net.IP{192, 168, 10, 1},
subnetMask: net.IP{255, 255, 255, 240},
}, {
name: "outside_range_end",
wantErrMsg: "dhcpv4: range end 192.168.10.200 is outside network " +
"192.168.10.1/27",
gatewayIP: net.IP{192, 168, 10, 1},
subnetMask: net.IP{255, 255, 255, 224},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 20},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: tc.gatewayIP,
SubnetMask: tc.subnetMask,
notify: testNotify,
}
_, err := v4Create(conf)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
// cloneUDPAddr returns a deep copy of a.
func cloneUDPAddr(a *net.UDPAddr) (clone *net.UDPAddr) {
return &net.UDPAddr{

View File

@@ -8,18 +8,15 @@ import (
"net/http"
"os"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
)
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("DHCP: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
type v4ServerConfJSON struct {
GatewayIP net.IP `json:"gateway_ip"`
SubnetMask net.IP `json:"subnet_mask"`
@@ -85,8 +82,13 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(status)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal DHCP status json: %s", err)
return
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to marshal DHCP status json: %s",
err,
)
}
}
@@ -209,36 +211,34 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
err := json.NewDecoder(r.Body).Decode(conf)
if err != nil {
httpError(r, w, http.StatusBadRequest,
"failed to parse new dhcp config json: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse new dhcp config json: %s", err)
return
}
srv4, v4Enabled, err := s.handleDHCPSetConfigV4(conf)
if err != nil {
httpError(r, w, http.StatusBadRequest, "bad dhcpv4 configuration: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "bad dhcpv4 configuration: %s", err)
return
}
srv6, v6Enabled, err := s.handleDHCPSetConfigV6(conf)
if err != nil {
httpError(r, w, http.StatusBadRequest, "bad dhcpv6 configuration: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "bad dhcpv6 configuration: %s", err)
return
}
if conf.Enabled == nbTrue && !v4Enabled && !v6Enabled {
httpError(r, w, http.StatusBadRequest,
"dhcpv4 or dhcpv6 configuration must be complete")
aghhttp.Error(r, w, http.StatusBadRequest, "dhcpv4 or dhcpv6 configuration must be complete")
return
}
err = s.Stop()
if err != nil {
httpError(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
return
}
@@ -263,7 +263,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
err = s.dbLoad()
if err != nil {
httpError(r, w, http.StatusInternalServerError, "loading leases db: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "loading leases db: %s", err)
return
}
@@ -272,9 +272,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
var code int
code, err = s.enableDHCP(conf.InterfaceName)
if err != nil {
httpError(r, w, code, "enabling dhcp: %s", err)
return
aghhttp.Error(r, w, code, "enabling dhcp: %s", err)
}
}
}
@@ -293,7 +291,8 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
ifaces, err := net.Interfaces()
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return
}
@@ -310,7 +309,15 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
var addrs []net.Addr
addrs, err = iface.Addrs()
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err)
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to get addresses for interface %s: %s",
iface.Name,
err,
)
return
}
@@ -327,7 +334,13 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
ipnet, ok := addr.(*net.IPNet)
if !ok {
// not an IPNet, should not happen
httpError(r, w, http.StatusInternalServerError, "SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"got iface.Addrs() element %[1]s that is not net.IPNet, it is %[1]T",
addr)
return
}
// ignore link-local
@@ -348,8 +361,13 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
err = json.NewEncoder(w).Encode(response)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Failed to marshal json with available interfaces: %s", err)
return
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to marshal json with available interfaces: %s",
err,
)
}
}
@@ -453,9 +471,13 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(result)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Failed to marshal DHCP found json: %s", err)
return
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to marshal DHCP found json: %s",
err,
)
}
}
@@ -463,13 +485,13 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
l := &Lease{}
err := json.NewDecoder(r.Body).Decode(l)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
if l.IP == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
return
}
@@ -481,7 +503,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
err = s.srv6.AddStaticLease(l)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
}
return
@@ -490,7 +512,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
l.IP = ip4
err = s.srv4.AddStaticLease(l)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
@@ -500,13 +522,13 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
l := &Lease{}
err := json.NewDecoder(r.Body).Decode(l)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
if l.IP == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
return
}
@@ -518,7 +540,7 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
err = s.srv6.RemoveStaticLease(l)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
}
return
@@ -527,16 +549,23 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
l.IP = ip4
err = s.srv4.RemoveStaticLease(l)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
const (
// DefaultDHCPLeaseTTL is the default time-to-live for leases.
DefaultDHCPLeaseTTL = uint32(timeutil.Day / time.Second)
// DefaultDHCPTimeoutICMP is the default timeout for waiting ICMP responses.
DefaultDHCPTimeoutICMP = 1000
)
func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
err := s.Stop()
if err != nil {
httpError(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
return
}
@@ -546,20 +575,28 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
log.Error("dhcp: removing db: %s", err)
}
oldconf := s.conf
s.conf = ServerConfig{}
s.conf.WorkDir = oldconf.WorkDir
s.conf.HTTPRegister = oldconf.HTTPRegister
s.conf.ConfigModified = oldconf.ConfigModified
s.conf.DBFilePath = oldconf.DBFilePath
s.conf = &ServerConfig{
ConfigModified: s.conf.ConfigModified,
v4conf := V4ServerConf{}
v4conf.ICMPTimeout = 1000
v4conf.notify = s.onNotify
HTTPRegister: s.conf.HTTPRegister,
LocalDomainName: s.conf.LocalDomainName,
WorkDir: s.conf.WorkDir,
DBFilePath: s.conf.DBFilePath,
}
v4conf := V4ServerConf{
LeaseDuration: DefaultDHCPLeaseTTL,
ICMPTimeout: DefaultDHCPTimeoutICMP,
notify: s.onNotify,
}
s.srv4, _ = v4Create(v4conf)
v6conf := V6ServerConf{}
v6conf.notify = s.onNotify
v6conf := V6ServerConf{
LeaseDuration: DefaultDHCPLeaseTTL,
notify: s.onNotify,
}
s.srv6, _ = v6Create(v6conf)
s.conf.ConfigModified()
@@ -569,7 +606,7 @@ func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
err := s.resetLeases()
if err != nil {
msg := "resetting leases: %s"
httpError(r, w, http.StatusInternalServerError, msg, err)
aghhttp.Error(r, w, http.StatusInternalServerError, msg, err)
return
}

View File

@@ -15,7 +15,7 @@ func TestServer_notImplemented(t *testing.T) {
w := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, "/unsupported", nil)
require.Nil(t, err)
require.NoError(t, err)
h(w, r)
assert.Equal(t, http.StatusNotImplemented, w.Code)

View File

@@ -14,8 +14,8 @@ import (
//
// It is safe for concurrent use.
//
// TODO(a.garipov): Perhaps create an optimised version with uint32 for
// IPv4 ranges? Or use one of uint128 packages?
// TODO(a.garipov): Perhaps create an optimized version with uint32 for IPv4
// ranges? Or use one of uint128 packages?
type ipRange struct {
start *big.Int
end *big.Int

View File

@@ -4,6 +4,7 @@ import (
"net"
"testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -64,14 +65,8 @@ func TestNewIPRange(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r, err := newIPRange(tc.start, tc.end)
if tc.wantErrMsg == "" {
assert.NoError(t, err)
assert.NotNil(t, r)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
_, err := newIPRange(tc.start, tc.end)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -11,33 +12,33 @@ import (
func TestNullBool_UnmarshalJSON(t *testing.T) {
testCases := []struct {
name string
data []byte
wantErrMsg string
data []byte
want nullBool
}{{
name: "empty",
data: []byte{},
wantErrMsg: "",
data: []byte{},
want: nbNull,
}, {
name: "null",
data: []byte("null"),
wantErrMsg: "",
data: []byte("null"),
want: nbNull,
}, {
name: "true",
data: []byte("true"),
wantErrMsg: "",
data: []byte("true"),
want: nbTrue,
}, {
name: "false",
data: []byte("false"),
wantErrMsg: "",
data: []byte("false"),
want: nbFalse,
}, {
name: "invalid",
data: []byte("flase"),
wantErrMsg: `invalid nullBool value "flase"`,
wantErrMsg: `invalid nullBool value "invalid"`,
data: []byte("invalid"),
want: nbNull,
}}
@@ -45,13 +46,7 @@ func TestNullBool_UnmarshalJSON(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
var got nullBool
err := got.UnmarshalJSON(tc.data)
if tc.wantErrMsg == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})

View File

@@ -130,21 +130,19 @@ func parseDHCPOption(s string) (opt dhcpv4.Option, err error) {
// prepareOptions builds the set of DHCP options according to host requirements
// document and values from conf.
func prepareOptions(conf V4ServerConf) (opts dhcpv4.Options) {
// Set default values for host configuration parameters listed in Appendix
// A of RFC-2131. Those parameters, if requested by client, should be
// returned with values defined by Host Requirements Document.
//
// See https://datatracker.ietf.org/doc/html/rfc2131#appendix-A.
//
// See also https://datatracker.ietf.org/doc/html/rfc1122,
// https://datatracker.ietf.org/doc/html/rfc1123, and
// https://datatracker.ietf.org/doc/html/rfc2132.
opts = dhcpv4.Options{
// Set default values for host configuration parameters listed
// in Appendix A of RFC-2131. Those parameters, if requested by
// client, should be returned with values defined by Host
// Requirements Document.
//
// See https://datatracker.ietf.org/doc/html/rfc2131#appendix-A.
//
// See also https://datatracker.ietf.org/doc/html/rfc1122,
// https://datatracker.ietf.org/doc/html/rfc1123, and
// https://datatracker.ietf.org/doc/html/rfc2132.
// IP-Layer Per Host
dhcpv4.OptionNonLocalSourceRouting.Code(): []byte{0},
// Set the current recommended default time to live for the
// Internet Protocol which is 64, see
// https://datatracker.ietf.org/doc/html/rfc1700.

View File

@@ -95,6 +95,7 @@ func TestParseOpt(t *testing.T) {
opt, err := parseDHCPOption(tc.in)
if tc.wantErrMsg != "" {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
return

View File

@@ -27,7 +27,6 @@ type DHCPServer interface {
Start() (err error)
// Stop - stop server
Stop() (err error)
getLeasesRef() []*Lease
}

View File

@@ -293,6 +293,8 @@ func (s *v4Server) addLease(l *Lease) (err error) {
offset, inOffset := r.offset(l.IP)
if l.IsStatic() {
// TODO(a.garipov, d.seregin): Subnet can be nil when dhcp server is
// disabled.
if sn := s.conf.subnet; !sn.Contains(l.IP) {
return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP)
}
@@ -900,9 +902,10 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
resp.UpdateOption(dhcpv4.OptGeneric(code, configured.Get(code)))
}
}
// Update the value of Domain Name Server option separately from others
// since its value is set after server's creating.
if requested.Has(dhcpv4.OptionDomainNameServer) {
// Update the value of Domain Name Server option separately from others if
// not assigned yet since its value is set after server's creating.
if requested.Has(dhcpv4.OptionDomainNameServer) &&
!resp.Options.Has(dhcpv4.OptionDomainNameServer) {
resp.UpdateOption(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
}
@@ -966,11 +969,10 @@ func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DH
Port: dhcpv4.ServerPort,
}
if mtype == dhcpv4.MessageTypeNak {
// Set the broadcast bit in the DHCPNAK, so that the
// relay agent broadcasted it to the client, because the
// client may not have a correct network address or
// subnet mask, and the client may not be answering ARP
// requests.
// Set the broadcast bit in the DHCPNAK, so that the relay agent
// broadcasts it to the client, because the client may not have
// a correct network address or subnet mask, and the client may not
// be answering ARP requests.
resp.SetBroadcast()
}
case mtype == dhcpv4.MessageTypeNak:
@@ -1053,8 +1055,6 @@ func (s *v4Server) Start() (err error) {
go func() {
if serr := s.srv.Serve(); errors.Is(serr, net.ErrClosed) {
log.Info("dhcpv4: server is closed")
return
} else if serr != nil {
log.Error("dhcpv4: srv.Serve: %s", serr)
}
@@ -1124,6 +1124,29 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) {
return s, fmt.Errorf("dhcpv4: %w", err)
}
if s.conf.ipRange.contains(routerIP) {
return s, fmt.Errorf("dhcpv4: gateway ip %v in the ip range: %v-%v",
routerIP,
conf.RangeStart,
conf.RangeEnd,
)
}
if !s.conf.subnet.Contains(conf.RangeStart) {
return s, fmt.Errorf("dhcpv4: range start %v is outside network %v",
conf.RangeStart,
s.conf.subnet,
)
}
if !s.conf.subnet.Contains(conf.RangeEnd) {
return s, fmt.Errorf("dhcpv4: range end %v is outside network %v",
conf.RangeEnd,
s.conf.subnet,
)
}
// TODO(a.garipov, d.seregin): Check that every lease is inside the IPRange.
s.leasedOffsets = newBitSet()
if conf.LeaseDuration == 0 {

View File

@@ -5,8 +5,11 @@ package dhcpd
import (
"net"
"strings"
"testing"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/mdlayher/raw"
"github.com/stretchr/testify/assert"
@@ -16,17 +19,34 @@ import (
func notify4(flags uint32) {
}
func TestV4_AddRemove_static(t *testing.T) {
s, err := v4Create(V4ServerConf{
// defaultV4ServerConf returns the default configuration for *v4Server to use in
// tests.
func defaultV4ServerConf() (conf V4ServerConf) {
return V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
})
}
}
// defaultSrv prepares the default DHCPServer to use in tests. The underlying
// type of s is *v4Server.
func defaultSrv(t *testing.T) (s DHCPServer) {
t.Helper()
var err error
s, err = v4Create(defaultV4ServerConf())
require.NoError(t, err)
return s
}
func TestV4_AddRemove_static(t *testing.T) {
s := defaultSrv(t)
ls := s.GetLeases(LeasesStatic)
assert.Empty(t, ls)
@@ -37,7 +57,7 @@ func TestV4_AddRemove_static(t *testing.T) {
IP: net.IP{192, 168, 10, 150},
}
err = s.AddStaticLease(l)
err := s.AddStaticLease(l)
require.NoError(t, err)
err = s.AddStaticLease(l)
@@ -65,15 +85,7 @@ func TestV4_AddRemove_static(t *testing.T) {
}
func TestV4_AddReplace(t *testing.T) {
sIface, err := v4Create(V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
})
require.NoError(t, err)
sIface := defaultSrv(t)
s, ok := sIface.(*v4Server)
require.True(t, ok)
@@ -89,7 +101,7 @@ func TestV4_AddReplace(t *testing.T) {
}}
for i := range dynLeases {
err = s.addLease(&dynLeases[i])
err := s.addLease(&dynLeases[i])
require.NoError(t, err)
}
@@ -104,7 +116,7 @@ func TestV4_AddReplace(t *testing.T) {
}}
for _, l := range stLeases {
err = s.AddStaticLease(l)
err := s.AddStaticLease(l)
require.NoError(t, err)
}
@@ -118,17 +130,80 @@ func TestV4_AddReplace(t *testing.T) {
}
}
func TestV4StaticLease_Get(t *testing.T) {
var err error
sIface, err := v4Create(V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
func TestV4Server_Process_optionsPriority(t *testing.T) {
defaultIP := net.IP{192, 168, 1, 1}
knownIP := net.IP{1, 2, 3, 4}
// prepareSrv creates a *v4Server and sets the opt6IPs in the initial
// configuration of the server as the value for DHCP option 6.
prepareSrv := func(t *testing.T, opt6IPs []net.IP) (s *v4Server) {
t.Helper()
conf := defaultV4ServerConf()
if len(opt6IPs) > 0 {
b := &strings.Builder{}
stringutil.WriteToBuilder(b, "6 ips ", opt6IPs[0].String())
for _, ip := range opt6IPs[1:] {
stringutil.WriteToBuilder(b, ",", ip.String())
}
conf.Options = []string{b.String()}
}
ss, err := v4Create(conf)
require.NoError(t, err)
var ok bool
s, ok = ss.(*v4Server)
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{defaultIP}
return s
}
// checkResp creates a discovery message with DHCP option 6 requested amd
// asserts the response to contain wantIPs in this option.
checkResp := func(t *testing.T, s *v4Server, wantIPs []net.IP) {
t.Helper()
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
req, err := dhcpv4.NewDiscovery(mac, dhcpv4.WithRequestedOptions(
dhcpv4.OptionDomainNameServer,
))
require.NoError(t, err)
var resp *dhcpv4.DHCPv4
resp, err = dhcpv4.NewReplyFromRequest(req)
require.NoError(t, err)
res := s.process(req, resp)
require.Equal(t, 1, res)
o := resp.GetOneOption(dhcpv4.OptionDomainNameServer)
require.NotEmpty(t, o)
wantData := []byte{}
for _, ip := range wantIPs {
wantData = append(wantData, ip...)
}
assert.Equal(t, o, wantData)
}
t.Run("default", func(t *testing.T) {
s := prepareSrv(t, nil)
checkResp(t, s, []net.IP{defaultIP})
})
require.NoError(t, err)
t.Run("explicitly_configured", func(t *testing.T) {
s := prepareSrv(t, []net.IP{knownIP, knownIP})
checkResp(t, s, []net.IP{knownIP, knownIP})
})
}
func TestV4StaticLease_Get(t *testing.T) {
sIface := defaultSrv(t)
s, ok := sIface.(*v4Server)
require.True(t, ok)
@@ -140,7 +215,7 @@ func TestV4StaticLease_Get(t *testing.T) {
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
}
err = s.AddStaticLease(l)
err := s.AddStaticLease(l)
require.NoError(t, err)
var req, resp *dhcpv4.DHCPv4
@@ -208,19 +283,14 @@ func TestV4StaticLease_Get(t *testing.T) {
}
func TestV4DynamicLease_Get(t *testing.T) {
conf := defaultV4ServerConf()
conf.Options = []string{
"81 hex 303132",
"82 ip 1.2.3.4",
}
var err error
sIface, err := v4Create(V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
Options: []string{
"81 hex 303132",
"82 ip 1.2.3.4",
},
})
sIface, err := v4Create(conf)
require.NoError(t, err)
s, ok := sIface.(*v4Server)
@@ -361,14 +431,7 @@ func TestNormalizeHostname(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := normalizeHostname(tc.hostname)
if tc.wantErrMsg == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}

View File

@@ -662,8 +662,11 @@ func (s *v6Server) Start() (err error) {
}
go func() {
err = s.srv.Serve()
log.Error("dhcpv6: srv.Serve: %s", err)
if serr := s.srv.Serve(); errors.Is(serr, net.ErrClosed) {
log.Info("dhcpv6: server is closed")
} else if serr != nil {
log.Error("dhcpv6: srv.Serve: %s", serr)
}
}()
return nil

View File

@@ -7,7 +7,8 @@ import (
"net/http"
"strings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
@@ -118,8 +119,8 @@ func (a *accessCtx) allowlistMode() (ok bool) {
func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
allowlistMode := a.allowlistMode()
if id == "" {
// In allowlist mode, consider requests without client IDs
// blocked by default.
// In allowlist mode, consider requests without ClientIDs blocked by
// default.
return allowlistMode
}
@@ -187,78 +188,62 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(j)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err)
return
}
}
func isUniq(slice []string) (ok bool, uniqueMap map[string]unit) {
exists := make(map[string]unit)
for _, key := range slice {
if _, has := exists[key]; has {
return false, nil
}
exists[key] = unit{}
}
return true, exists
}
func intersect(mapA, mapB map[string]unit) bool {
for key := range mapA {
if _, has := mapB[key]; has {
return true
}
}
return false
}
// validateAccessSet checks the internal accessListJSON lists. To search for
// duplicates, we cannot compare the new stringutil.Set and []string, because
// creating a set for a large array can be an unnecessary algorithmic complexity
func validateAccessSet(list accessListJSON) (err error) {
const (
errAllowedDup errors.Error = "duplicates in allowed clients"
errDisallowedDup errors.Error = "duplicates in disallowed clients"
errBlockedDup errors.Error = "duplicates in blocked hosts"
errIntersect errors.Error = "some items in allowed and " +
"disallowed lists at the same time"
)
ok, allowedClients := isUniq(list.AllowedClients)
if !ok {
return errAllowedDup
func validateAccessSet(list *accessListJSON) (err error) {
allowed, err := validateStrUniq(list.AllowedClients)
if err != nil {
return fmt.Errorf("validating allowed clients: %w", err)
}
ok, disallowedClients := isUniq(list.DisallowedClients)
if !ok {
return errDisallowedDup
disallowed, err := validateStrUniq(list.DisallowedClients)
if err != nil {
return fmt.Errorf("validating disallowed clients: %w", err)
}
ok, _ = isUniq(list.BlockedHosts)
if !ok {
return errBlockedDup
_, err = validateStrUniq(list.BlockedHosts)
if err != nil {
return fmt.Errorf("validating blocked hosts: %w", err)
}
if intersect(allowedClients, disallowedClients) {
return errIntersect
merged := allowed.Merge(disallowed)
err = merged.Validate(aghalg.StringIsBefore)
if err != nil {
return fmt.Errorf("items in allowed and disallowed clients intersect: %w", err)
}
return nil
}
// validateStrUniq returns an informative error if clients are not unique.
func validateStrUniq(clients []string) (uc aghalg.UniqChecker, err error) {
uc = make(aghalg.UniqChecker, len(clients))
for _, c := range clients {
uc.Add(c)
}
return uc, uc.Validate(aghalg.StringIsBefore)
}
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
list := accessListJSON{}
list := &accessListJSON{}
err := json.NewDecoder(r.Body).Decode(&list)
if err != nil {
httpError(r, w, http.StatusBadRequest, "decoding request: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
return
}
err = validateAccessSet(list)
if err != nil {
httpError(r, w, http.StatusBadRequest, err.Error())
aghhttp.Error(r, w, http.StatusBadRequest, err.Error())
return
}
@@ -266,7 +251,7 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
var a *accessCtx
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
if err != nil {
httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
return
}

View File

@@ -2,7 +2,6 @@ package dnsforward
import (
"crypto/tls"
"encoding/binary"
"fmt"
"path"
"strings"
@@ -13,12 +12,12 @@ import (
"github.com/lucas-clemente/quic-go"
)
// ValidateClientID returns an error if clientID is not a valid client ID.
func ValidateClientID(clientID string) (err error) {
err = netutil.ValidateDomainNameLabel(clientID)
// ValidateClientID returns an error if id is not a valid ClientID.
func ValidateClientID(id string) (err error) {
err = netutil.ValidateDomainNameLabel(id)
if err != nil {
// Replace the domain name label wrapper with our own.
return fmt.Errorf("invalid client id %q: %w", clientID, errors.Unwrap(err))
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
}
return nil
@@ -34,7 +33,7 @@ func hasLabelSuffix(s, suffix string) (ok bool) {
return strings.HasSuffix(s, suffix) && s[len(s)-len(suffix)-1] == '.'
}
// clientIDFromClientServerName extracts and validates a client ID. hostSrvName
// clientIDFromClientServerName extracts and validates a ClientID. hostSrvName
// is the server name of the host. cliSrvName is the server name as sent by the
// client. When strict is true, and client and host server name don't match,
// clientIDFromClientServerName will return an error.
@@ -87,22 +86,22 @@ func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err e
}
if len(parts) == 0 || parts[0] != "dns-query" {
return "", fmt.Errorf("client id check: invalid path %q", origPath)
return "", fmt.Errorf("clientid check: invalid path %q", origPath)
}
switch len(parts) {
case 1:
// Just /dns-query, no client ID.
// Just /dns-query, no ClientID.
return "", nil
case 2:
clientID = parts[1]
default:
return "", fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
return "", fmt.Errorf("clientid check: invalid path %q: extra parts", origPath)
}
err = ValidateClientID(clientID)
if err != nil {
return "", fmt.Errorf("client id check: %w", err)
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
@@ -167,24 +166,8 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
s.conf.StrictSNICheck,
)
if err != nil {
return "", fmt.Errorf("client id check: %w", err)
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
}
// processClientID puts the clientID into the DNS context, if there is one.
func (s *Server) processClientID(dctx *dnsContext) (rc resultCode) {
pctx := dctx.proxyCtx
var key [8]byte
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
clientIDData := s.clientIDCache.Get(key[:])
if clientIDData == nil {
return resultCodeSuccess
}
dctx.clientID = string(clientIDData)
return resultCodeSuccess
}

View File

@@ -8,9 +8,9 @@ import (
"testing"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/testutil"
"github.com/lucas-clemente/quic-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testTLSConn is a tlsConn for tests.
@@ -31,8 +31,8 @@ func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) {
// testQUICSession is a quicSession for tests.
type testQUICSession struct {
// Session is embedded here simply to make testQUICSession
// a quic.Session without acctually implementing all methods.
// Session is embedded here simply to make testQUICSession a quic.Session
// without actually implementing all methods.
quic.Session
serverName string
@@ -65,7 +65,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
wantErrMsg: "",
strictSNI: false,
}, {
name: "tls_no_client_id",
name: "tls_no_clientid",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "example.com",
@@ -78,7 +78,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
hostSrvName: "example.com",
cliSrvName: "",
wantClientID: "",
wantErrMsg: `client id check: client server name "" ` +
wantErrMsg: `clientid check: client server name "" ` +
`doesn't match host server name "example.com"`,
strictSNI: true,
}, {
@@ -90,7 +90,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
wantErrMsg: "",
strictSNI: false,
}, {
name: "tls_client_id",
name: "tls_clientid",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "cli.example.com",
@@ -98,36 +98,36 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
wantErrMsg: "",
strictSNI: true,
}, {
name: "tls_client_id_hostname_error",
name: "tls_clientid_hostname_error",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "cli.example.net",
wantClientID: "",
wantErrMsg: `client id check: client server name "cli.example.net" ` +
wantErrMsg: `clientid check: client server name "cli.example.net" ` +
`doesn't match host server name "example.com"`,
strictSNI: true,
}, {
name: "tls_invalid_client_id",
name: "tls_invalid_clientid",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "!!!.example.com",
wantClientID: "",
wantErrMsg: `client id check: invalid client id "!!!": ` +
wantErrMsg: `clientid check: invalid clientid "!!!": ` +
`bad domain name label rune '!'`,
strictSNI: true,
}, {
name: "tls_client_id_too_long",
name: "tls_clientid_too_long",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: `abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmno` +
`pqrstuvwxyz0123456789.example.com`,
wantClientID: "",
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
wantErrMsg: `clientid check: invalid clientid "abcdefghijklmno` +
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
`domain name label is too long: got 72, max 63`,
strictSNI: true,
}, {
name: "quic_client_id",
name: "quic_clientid",
proto: proxy.ProtoQUIC,
hostSrvName: "example.com",
cliSrvName: "cli.example.com",
@@ -135,12 +135,12 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
wantErrMsg: "",
strictSNI: true,
}, {
name: "tls_client_id_issue3437",
name: "tls_clientid_issue3437",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "cli.myexample.com",
wantClientID: "",
wantErrMsg: `client id check: client server name "cli.myexample.com" ` +
wantErrMsg: `clientid check: client server name "cli.myexample.com" ` +
`doesn't match host server name "example.com"`,
strictSNI: true,
}}
@@ -179,13 +179,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
clientID, err := srv.clientIDFromDNSContext(pctx)
assert.Equal(t, tc.wantClientID, clientID)
if tc.wantErrMsg == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
@@ -197,22 +191,22 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
wantClientID string
wantErrMsg string
}{{
name: "no_client_id",
name: "no_clientid",
path: "/dns-query",
wantClientID: "",
wantErrMsg: "",
}, {
name: "no_client_id_slash",
name: "no_clientid_slash",
path: "/dns-query/",
wantClientID: "",
wantErrMsg: "",
}, {
name: "client_id",
name: "clientid",
path: "/dns-query/cli",
wantClientID: "cli",
wantErrMsg: "",
}, {
name: "client_id_slash",
name: "clientid_slash",
path: "/dns-query/cli/",
wantClientID: "cli",
wantErrMsg: "",
@@ -220,18 +214,17 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
name: "bad_url",
path: "/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/foo"`,
wantErrMsg: `clientid check: invalid path "/foo"`,
}, {
name: "extra",
path: "/dns-query/cli/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
wantErrMsg: `clientid check: invalid path "/dns-query/cli/foo": extra parts`,
}, {
name: "invalid_client_id",
name: "invalid_clientid",
path: "/dns-query/!!!",
wantClientID: "",
wantErrMsg: `client id check: invalid client id "!!!": ` +
`bad domain name label rune '!'`,
wantErrMsg: `clientid check: invalid clientid "!!!": bad domain name label rune '!'`,
}}
for _, tc := range testCases {
@@ -250,13 +243,7 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
clientID, err := clientIDFromDNSContextHTTPS(pctx)
assert.Equal(t, tc.wantClientID, clientID)
if tc.wantErrMsg == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}

View File

@@ -11,6 +11,7 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
@@ -98,10 +99,10 @@ type FilteringConfig struct {
AllowedClients []string `yaml:"allowed_clients"` // IP addresses of whitelist clients
DisallowedClients []string `yaml:"disallowed_clients"` // IP addresses of clients that should be blocked
BlockedHosts []string `yaml:"blocked_hosts"` // hosts that should be blocked
// TrustedProxies is the list of IP addresses and CIDR networks to
// detect proxy servers addresses the DoH requests from which should be
// handled. The value of nil or an empty slice for this field makes
// Proxy not trust any address.
// TrustedProxies is the list of IP addresses and CIDR networks to detect
// proxy servers addresses the DoH requests from which should be handled.
// The value of nil or an empty slice for this field makes Proxy not trust
// any address.
TrustedProxies []string `yaml:"trusted_proxies"`
// DNS cache settings
@@ -118,7 +119,7 @@ type FilteringConfig struct {
BogusNXDomain []string `yaml:"bogus_nxdomain"` // transform responses with these IP addresses to NXDOMAIN
AAAADisabled bool `yaml:"aaaa_disabled"` // Respond with an empty answer to all AAAA requests
EnableDNSSEC bool `yaml:"enable_dnssec"` // Set DNSSEC flag in outcoming DNS request
EnableDNSSEC bool `yaml:"enable_dnssec"` // Set AD flag in outcoming DNS request
EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option
MaxGoroutines uint32 `yaml:"max_goroutines"` // Max. number of parallel goroutines for processing incoming requests
@@ -149,8 +150,8 @@ type TLSConfig struct {
CertificateChainData []byte `yaml:"-" json:"-"`
PrivateKeyData []byte `yaml:"-" json:"-"`
// ServerName is the hostname of the server. Currently, it is only
// being used for client ID checking.
// ServerName is the hostname of the server. Currently, it is only being
// used for ClientID checking.
ServerName string `yaml:"-" json:"-"`
cert tls.Certificate
@@ -243,15 +244,15 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
proxyConfig.FastestPingTimeout = s.conf.FastestTimeout.Duration
}
if len(s.conf.BogusNXDomain) > 0 {
for _, s := range s.conf.BogusNXDomain {
ip := net.ParseIP(s)
if ip == nil {
log.Error("Invalid bogus IP: %s", s)
} else {
proxyConfig.BogusNXDomain = append(proxyConfig.BogusNXDomain, ip)
}
for i, s := range s.conf.BogusNXDomain {
subnet, err := netutil.ParseSubnet(s)
if err != nil {
log.Error("subnet at index %d: %s", i, err)
continue
}
proxyConfig.BogusNXDomain = append(proxyConfig.BogusNXDomain, subnet)
}
// TLS settings
@@ -426,6 +427,7 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
proxyConfig.TLSConfig = &tls.Config{
GetCertificate: s.onGetCertificate,
CipherSuites: aghtls.SaferCipherSuites(),
MinVersion: tls.VersionTLS12,
}

View File

@@ -1,6 +1,7 @@
package dnsforward
import (
"encoding/binary"
"net"
"strings"
"time"
@@ -16,36 +17,45 @@ import (
// To transfer information between modules
type dnsContext struct {
// TODO(a.garipov): Remove this and rewrite processors to be methods of
// *Server instead.
srv *Server
proxyCtx *proxy.DNSContext
// setts are the filtering settings for the client.
setts *filtering.Settings
startTime time.Time
result *filtering.Result
setts *filtering.Settings
result *filtering.Result
// origResp is the response received from upstream. It is set when the
// response is modified by filters.
origResp *dns.Msg
// unreversedReqIP stores an IP address obtained from PTR request if it
// was successfully parsed.
// parsed successfully and belongs to one of locally-served IP ranges as per
// RFC 6303.
unreversedReqIP net.IP
// err is the error returned from a processing function.
err error
// clientID is the clientID from DoH, DoQ, or DoT, if provided.
// clientID is the ClientID from DoH, DoQ, or DoT, if provided.
clientID string
// origQuestion is the question received from the client. It is set
// when the request is modified by rewrites.
origQuestion dns.Question
// startTime is the time at which the processing of the request has started.
startTime time.Time
// protectionEnabled shows if the filtering is enabled, and if the
// server's DNS filter is ready.
protectionEnabled bool
// responseFromUpstream shows if the response is received from the
// upstream servers.
responseFromUpstream bool
// origReqDNSSEC shows if the DNSSEC flag in the original request from
// the client is set.
origReqDNSSEC bool
// responseAD shows if the response had the AD bit set.
responseAD bool
// isLocalClient shows if client's IP address is from locally-served
// network.
isLocalClient bool
@@ -69,7 +79,6 @@ const (
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
ctx := &dnsContext{
srv: s,
proxyCtx: d,
result: &filtering.Result{},
startTime: time.Now(),
@@ -84,19 +93,17 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
// appropriate handler.
mods := []modProcessFunc{
s.processRecursion,
processInitial,
s.processInitial,
s.processDetermineLocal,
s.processInternalHosts,
s.processRestrictLocal,
s.processInternalIPAddrs,
s.processClientID,
processFilteringBeforeRequest,
s.processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream,
processDNSSECAfterResponse,
processFilteringAfterResponse,
s.processFilteringAfterResponse,
s.ipset.process,
processQueryLogsAndStats,
s.processQueryLogsAndStats,
}
for _, process := range mods {
r := process(ctx)
@@ -134,9 +141,11 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
// Perform initial checks; process WHOIS & rDNS
func processInitial(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
// processInitial terminates the following processing for some requests if
// needed and enriches the ctx with some client-specific information.
//
// TODO(e.burkov): Decompose into less general processors.
func (s *Server) processInitial(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
_ = proxy.CheckDisabledAAAARequest(d, true)
@@ -155,6 +164,16 @@ func processInitial(ctx *dnsContext) (rc resultCode) {
return resultCodeFinish
}
// Get the client's ID if any. It should be performed before getting
// client-specific filtering settings.
var key [8]byte
binary.BigEndian.PutUint64(key[:], d.RequestID)
ctx.clientID = string(s.clientIDCache.Get(key[:]))
// Get the client-specific filtering settings.
ctx.protectionEnabled = s.conf.ProtectionEnabled
ctx.setts = s.getClientRequestFilteringSettings(ctx)
return resultCodeSuccess
}
@@ -196,9 +215,8 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
ipToHost = netutil.NewIPMap(len(ll))
for _, l := range ll {
// TODO(a.garipov): Remove this after we're finished
// with the client hostname validations in the DHCP
// server code.
// TODO(a.garipov): Remove this after we're finished with the client
// hostname validations in the DHCP server code.
err = netutil.ValidateDomainName(l.Hostname)
if err != nil {
log.Debug(
@@ -274,14 +292,16 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
req := dctx.proxyCtx.Req
q := req.Question[0]
// Go on processing the AAAA request despite the fact that we don't
// support it yet. The expected behavior here is to respond with an
// empty asnwer and not NXDOMAIN.
// Go on processing the AAAA request despite the fact that we don't support
// it yet. The expected behavior here is to respond with an empty answer
// and not NXDOMAIN.
if q.Qtype != dns.TypeA && q.Qtype != dns.TypeAAAA {
return resultCodeSuccess
}
reqHost := strings.ToLower(q.Name)
// TODO(a.garipov): Move everything related to DHCP local domain to the DHCP
// server.
host := strings.TrimSuffix(reqHost, s.localDomainSuffix)
if host == reqHost {
return resultCodeSuccess
@@ -298,8 +318,8 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
ip, ok := s.hostToIP(host)
if !ok {
// TODO(e.burkov): Inspect special cases when user want to apply
// some rules handled by other processors to the hosts with TLD.
// TODO(e.burkov): Inspect special cases when user want to apply some
// rules handled by other processors to the hosts with TLD.
d.Res = s.genNXDomain(req)
return resultCodeFinish
@@ -333,34 +353,51 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
ip, err := netutil.IPFromReversedAddr(q.Name)
if err != nil {
log.Debug("dns: reversed addr: %s", err)
log.Debug("dns: parsing reversed addr: %s", err)
return resultCodeError
// DNS-Based Service Discovery uses PTR records having not an ARPA
// format of the domain name in question. Those shouldn't be
// invalidated. See http://www.dns-sd.org/ServerStaticSetup.html and
// RFC 2782.
name := strings.TrimSuffix(q.Name, ".")
if err = netutil.ValidateSRVDomainName(name); err != nil {
log.Debug("dns: validating service domain: %s", err)
return resultCodeError
}
log.Debug("dns: request is for a service domain")
return resultCodeSuccess
}
// Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally-served or at
// least don't need to be inaccessible externally.
if s.subnetDetector.IsLocallyServedNetwork(ip) {
if !ctx.isLocalClient {
log.Debug("dns: %q requests for internal ip", d.Addr)
d.Res = s.genNXDomain(req)
// assume that all the DHCP leases we give are locally-served or at least
// don't need to be inaccessible externally.
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
log.Debug("dns: addr %s is not from locally-served network", ip)
// Do not even put into query log.
return resultCodeFinish
}
return resultCodeSuccess
}
if !ctx.isLocalClient {
log.Debug("dns: %q requests an internal ip", d.Addr)
d.Res = s.genNXDomain(req)
// Do not even put into query log.
return resultCodeFinish
}
// Do not perform unreversing ever again.
ctx.unreversedReqIP = ip
// Disable redundant filtering.
filterSetts := s.getClientRequestFilteringSettings(ctx)
filterSetts.ParentalEnabled = false
filterSetts.SafeBrowsingEnabled = false
filterSetts.SafeSearchEnabled = false
filterSetts.ServicesRules = nil
ctx.setts = filterSetts
// There is no need to filter request from external addresses since this
// code is only executed when the request is for locally-served ARPA
// hostname so disable redundant filters.
ctx.setts.ParentalEnabled = false
ctx.setts.SafeBrowsingEnabled = false
ctx.setts.SafeSearchEnabled = false
ctx.setts.ServicesRules = nil
// Nothing to restrict.
return resultCodeSuccess
@@ -468,29 +505,21 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
}
// Apply filtering logic
func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if d.Res != nil {
return resultCodeSuccess // response is already set - nothing to do
func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
if ctx.proxyCtx.Res != nil {
// Go on since the response is already set.
return resultCodeSuccess
}
s.serverLock.RLock()
defer s.serverLock.RUnlock()
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
if !ctx.protectionEnabled {
if s.dnsFilter == nil {
return resultCodeSuccess
}
if ctx.setts == nil {
ctx.setts = s.getClientRequestFilteringSettings(ctx)
}
var err error
ctx.result, err = s.filterDNSRequest(ctx)
if err != nil {
if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
ctx.err = err
return resultCodeError
@@ -509,147 +538,105 @@ func ipStringFromAddr(addr net.Addr) (ipStr string) {
}
// processUpstream passes request to upstream servers and handles the response.
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
if d.Res != nil {
return resultCodeSuccess // response is already set - nothing to do
func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
pctx := dctx.proxyCtx
if pctx.Res != nil {
// The response has already been set.
return resultCodeSuccess
}
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
// Use the clientID first, since it has a higher priority.
id := stringutil.Coalesce(ctx.clientID, ipStringFromAddr(d.Addr))
if pctx.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
// Use the ClientID first, since it has a higher priority.
id := stringutil.Coalesce(dctx.clientID, ipStringFromAddr(pctx.Addr))
upsConf, err := s.conf.GetCustomUpstreamByClient(id)
if err != nil {
log.Error("dns: getting custom upstreams for client %s: %s", id, err)
} else if upsConf != nil {
log.Debug("dns: using custom upstreams for client %s", id)
d.CustomUpstreamConfig = upsConf
pctx.CustomUpstreamConfig = upsConf
}
}
req := d.Req
req := pctx.Req
origReqAD := false
if s.conf.EnableDNSSEC {
opt := req.IsEdns0()
if opt == nil {
log.Debug("dns: adding OPT record with DNSSEC flag")
req.SetEdns0(4096, true)
} else if !opt.Do() {
opt.SetDo(true)
if req.AuthenticatedData {
origReqAD = true
} else {
ctx.origReqDNSSEC = true
req.AuthenticatedData = true
}
}
// Process the request further since it wasn't filtered.
prx := s.proxy()
if prx == nil {
ctx.err = srvClosedErr
dctx.err = srvClosedErr
return resultCodeError
}
if ctx.err = prx.Resolve(d); ctx.err != nil {
if dctx.err = prx.Resolve(pctx); dctx.err != nil {
return resultCodeError
}
ctx.responseFromUpstream = true
dctx.responseFromUpstream = true
dctx.responseAD = pctx.Res.AuthenticatedData
return resultCodeSuccess
}
// Process DNSSEC after response from upstream server
func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
if !ctx.responseFromUpstream || // don't process response if it's not from upstream servers
!ctx.srv.conf.EnableDNSSEC {
return resultCodeSuccess
}
if !ctx.origReqDNSSEC {
optResp := d.Res.IsEdns0()
if optResp != nil && !optResp.Do() {
return resultCodeSuccess
}
// Remove RRSIG records from response
// because there is no DO flag in the original request from client,
// but we have EnableDNSSEC set, so we have set DO flag ourselves,
// and now we have to clean up the DNS records our client didn't ask for.
answers := []dns.RR{}
for _, a := range d.Res.Answer {
switch a.(type) {
case *dns.RRSIG:
log.Debug("Removing RRSIG record from response: %v", a)
default:
answers = append(answers, a)
}
}
d.Res.Answer = answers
answers = []dns.RR{}
for _, a := range d.Res.Ns {
switch a.(type) {
case *dns.RRSIG:
log.Debug("Removing RRSIG record from response: %v", a)
default:
answers = append(answers, a)
}
}
d.Res.Ns = answers
if s.conf.EnableDNSSEC && !origReqAD {
pctx.Req.AuthenticatedData = false
pctx.Res.AuthenticatedData = false
}
return resultCodeSuccess
}
// Apply filtering logic after we have received response from upstream servers
func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
func (s *Server) processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
res := ctx.result
var err error
switch res.Reason {
case filtering.Rewritten,
switch res := ctx.result; res.Reason {
case filtering.NotFilteredAllowList:
// Go on.
case
filtering.Rewritten,
filtering.RewrittenRule:
if len(ctx.origQuestion.Name) == 0 {
// origQuestion is set in case we get only CNAME without IP from rewrites table
// origQuestion is set in case we get only CNAME without IP from
// rewrites table.
break
}
d.Req.Question[0] = ctx.origQuestion
d.Res.Question[0] = ctx.origQuestion
if len(d.Res.Answer) != 0 {
answer := []dns.RR{}
answer = append(answer, s.genAnswerCNAME(d.Req, res.CanonName))
answer = append(answer, d.Res.Answer...)
d.Req.Question[0], d.Res.Question[0] = ctx.origQuestion, ctx.origQuestion
if len(d.Res.Answer) > 0 {
answer := append([]dns.RR{s.genAnswerCNAME(d.Req, res.CanonName)}, d.Res.Answer...)
d.Res.Answer = answer
}
case filtering.NotFilteredAllowList:
// nothing
default:
if !ctx.protectionEnabled || // filters are disabled: there's nothing to check for
!ctx.responseFromUpstream { // only check response if it's from an upstream server
// Check the response only if it's from an upstream. Don't check the
// response if the protection is disabled since dnsrewrite rules aren't
// applied to it anyway.
if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil {
break
}
origResp2 := d.Res
ctx.result, err = s.filterDNSResponse(ctx)
origResp := d.Res
result, err := s.filterDNSResponse(ctx)
if err != nil {
ctx.err = err
return resultCodeError
}
if ctx.result != nil {
ctx.origResp = origResp2 // matched by response
} else {
ctx.result = &filtering.Result{}
if result != nil {
ctx.result = result
ctx.origResp = origResp
}
}
if ctx.result == nil {
ctx.result = &filtering.Result{}
}
return resultCodeSuccess
}

View File

@@ -261,7 +261,7 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
}
func TestServer_ProcessRestrictLocal(t *testing.T) {
ups := &aghtest.TestUpstream{
ups := &aghtest.Upstream{
Reverse: map[string][]string{
"251.252.253.254.in-addr.arpa.": {"host1.example.net."},
"1.1.168.192.in-addr.arpa.": {"some.local-client."},
@@ -339,7 +339,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
s := createTestServer(t, &filtering.Config{}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
}, &aghtest.TestUpstream{
}, &aghtest.Upstream{
Reverse: map[string][]string{
reqAddr: {locDomain},
},

View File

@@ -28,7 +28,7 @@ import (
// DefaultTimeout is the default upstream timeout
const DefaultTimeout = 10 * time.Second
// defaultClientIDCacheCount is the default count of items in the LRU client ID
// defaultClientIDCacheCount is the default count of items in the LRU ClientID
// cache. The assumption here is that there won't be more than this many
// requests between the BeforeRequestHandler stage and the actual processing.
const defaultClientIDCacheCount = 1024
@@ -79,14 +79,17 @@ type Server struct {
sysResolvers aghnet.SystemResolvers
recDetector *recursionDetector
// anonymizer masks the client's IP addresses if needed.
anonymizer *aghnet.IPMut
tableHostToIP hostToIPTable
tableHostToIPLock sync.Mutex
tableIPToHost *netutil.IPMap
tableIPToHostLock sync.Mutex
// clientIDCache is a temporary storage for clientIDs that were
// extracted during the BeforeRequestHandler stage.
// clientIDCache is a temporary storage for ClientIDs that were extracted
// during the BeforeRequestHandler stage.
clientIDCache cache.Cache
// DNS proxy instance for internal usage
@@ -113,6 +116,7 @@ type DNSCreateParams struct {
QueryLog querylog.QueryLog
DHCPServer dhcpd.ServerInterface
SubnetDetector *aghnet.SubnetDetector
Anonymizer *aghnet.IPMut
LocalDomain string
}
@@ -150,6 +154,9 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
localDomainSuffix = domainNameToSuffix(p.LocalDomain)
}
if p.Anonymizer == nil {
p.Anonymizer = aghnet.NewIPMut(nil)
}
s = &Server{
dnsFilter: p.DNSFilter,
stats: p.Stats,
@@ -161,6 +168,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
EnableLRU: true,
MaxCount: defaultClientIDCacheCount,
}),
anonymizer: p.Anonymizer,
}
// TODO(e.burkov): Enable the refresher after the actual implementation
@@ -435,7 +443,7 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
&upstream.Options{
Bootstrap: bootstraps,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's ceritificates?
// TODO(e.burkov): Should we verify server's certificates?
},
)
if err != nil {
@@ -551,7 +559,7 @@ func (s *Server) IsRunning() bool {
return s.isRunning
}
// srvClosedErr is returned when the method can't complete without unacessible
// srvClosedErr is returned when the method can't complete without inaccessible
// data from the closing server.
const srvClosedErr errors.Error = "server is closed"

View File

@@ -11,9 +11,10 @@ import (
"fmt"
"math/big"
"net"
"os"
"sync"
"sync/atomic"
"testing"
"testing/fstest"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
@@ -23,6 +24,7 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -44,10 +46,7 @@ func startDeferStop(t *testing.T, s *Server) {
err := s.Start()
require.NoErrorf(t, err, "failed to start server: %s", err)
t.Cleanup(func() {
serr := s.Stop()
require.NoErrorf(t, serr, "dns server failed to stop: %s", serr)
})
testutil.CleanupAndRequireSuccess(t, s.Stop)
}
func createTestServer(
@@ -90,7 +89,7 @@ func createTestServer(
defer s.serverLock.Unlock()
if localUps != nil {
s.localResolvers.Config.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
s.localResolvers.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
s.conf.UsePrivateRDNS = true
}
@@ -248,7 +247,7 @@ func TestServer(t *testing.T) {
TCPListenAddrs: []*net.TCPAddr{{}},
}, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
},
@@ -317,7 +316,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
TCPListenAddrs: []*net.TCPAddr{{}},
}, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
},
@@ -340,7 +339,7 @@ func TestDoTServer(t *testing.T) {
TLSListenAddrs: []*net.TCPAddr{{}},
})
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
},
@@ -370,7 +369,7 @@ func TestDoQServer(t *testing.T) {
QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}},
})
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
},
@@ -414,7 +413,7 @@ func TestServerRace(t *testing.T) {
}
s := createTestServer(t, filterConf, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
},
@@ -553,7 +552,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
ups := &aghtest.TestUpstream{
ups := &aghtest.Upstream{
IPv4: map[string][]net.IP{
"host.": {{192, 168, 0, 1}},
},
@@ -581,9 +580,9 @@ func TestServerCustomClientUpstream(t *testing.T) {
}
// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work.
var testCNAMEs = map[string]string{
"badhost.": "NULL.example.org.",
"whitelist.example.org.": "NULL.example.org.",
var testCNAMEs = map[string][]string{
"badhost.": {"NULL.example.org."},
"whitelist.example.org.": {"NULL.example.org."},
}
// testIPv4 is a map of names and IPv4s necessary for the TestUpstream work.
@@ -597,7 +596,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
}, nil)
testUpstm := &aghtest.TestUpstream{
testUpstm := &aghtest.Upstream{
CName: testCNAMEs,
IPv4: testIPv4,
IPv6: nil,
@@ -631,7 +630,7 @@ func TestBlockCNAME(t *testing.T) {
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
CName: testCNAMEs,
IPv4: testIPv4,
},
@@ -641,14 +640,17 @@ func TestBlockCNAME(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
testCases := []struct {
name string
host string
want bool
}{{
name: "block_request",
host: "badhost.",
// 'badhost' has a canonical name 'NULL.example.org' which is
// blocked by filters: response is blocked.
want: true,
}, {
name: "allowed",
host: "whitelist.example.org.",
// 'whitelist.example.org' has a canonical name
// 'NULL.example.org' which is blocked by filters
@@ -656,6 +658,7 @@ func TestBlockCNAME(t *testing.T) {
// response isn't blocked.
want: false,
}, {
name: "block_response",
host: "example.org.",
// 'example.org' has a canonical name 'cname1' with IP
// 127.0.0.255 which is blocked by filters: response is blocked.
@@ -663,9 +666,9 @@ func TestBlockCNAME(t *testing.T) {
}}
for _, tc := range testCases {
t.Run("block_cname_"+tc.host, func(t *testing.T) {
req := createTestMessage(tc.host)
req := createTestMessage(tc.host)
t.Run(tc.name, func(t *testing.T) {
reply, err := dns.Exchange(req, addr)
require.NoError(t, err)
@@ -675,7 +678,7 @@ func TestBlockCNAME(t *testing.T) {
ans := reply.Answer[0]
a, ok := ans.(*dns.A)
require.Truef(t, ok, "got %T", ans)
require.True(t, ok)
assert.True(t, a.A.IsUnspecified())
}
@@ -696,7 +699,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
CName: testCNAMEs,
IPv4: testIPv4,
},
@@ -893,7 +896,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
func TestRewrite(t *testing.T) {
c := &filtering.Config{
Rewrites: []filtering.RewriteEntry{{
Rewrites: []*filtering.LegacyRewrite{{
Domain: "test.com",
Answer: "1.2.3.4",
Type: dns.TypeA,
@@ -908,6 +911,7 @@ func TestRewrite(t *testing.T) {
}},
}
f := filtering.New(c, nil)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
@@ -931,9 +935,9 @@ func TestRewrite(t *testing.T) {
}))
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
CName: map[string]string{
"example.org": "somename",
&aghtest.Upstream{
CName: map[string][]string{
"example.org": {"somename"},
},
IPv4: map[string][]net.IP{
"example.org.": {{4, 3, 2, 1}},
@@ -944,45 +948,56 @@ func TestRewrite(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("test.com.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String())
require.NoError(t, err)
subTestFunc := func(t *testing.T) {
req := createTestMessageWithType("test.com.", dns.TypeA)
reply, eerr := dns.Exchange(req, addr.String())
require.NoError(t, eerr)
require.Len(t, reply.Answer, 1)
require.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A)
require.True(t, ok)
a, ok := reply.Answer[0].(*dns.A)
require.True(t, ok)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String())
require.NoError(t, err)
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
assert.Empty(t, reply.Answer)
assert.Empty(t, reply.Answer)
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String())
require.NoError(t, err)
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
require.Len(t, reply.Answer, 2)
require.Len(t, reply.Answer, 2)
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String())
require.NoError(t, err)
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, eerr)
// The original question is restored.
require.Len(t, reply.Question, 1)
// The original question is restored.
require.Len(t, reply.Question, 1)
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name)
require.Len(t, reply.Answer, 2)
require.Len(t, reply.Answer, 2)
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
}
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf.ProtectionEnabled = &val
s.setConfig(conf)
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
}
}
func publicKey(priv interface{}) interface{} {
@@ -1036,9 +1051,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
err = s.Start()
require.NoError(t, err)
t.Cleanup(func() {
s.Close()
})
t.Cleanup(s.Close)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR)
@@ -1057,23 +1070,40 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
}
func TestPTRResponseFromHosts(t *testing.T) {
c := filtering.Config{
EtcHosts: &aghnet.EtcHostsContainer{},
// Prepare test hosts file.
const hostsFilename = "hosts"
testFS := fstest.MapFS{
hostsFilename: &fstest.MapFile{Data: []byte(`
127.0.0.1 host # comment
::1 localhost#comment
`)},
}
// Prepare test hosts file.
hf, err := os.CreateTemp("", "")
var eventsCalledCounter uint32
hc, err := aghnet.NewHostsContainer(0, testFS, &aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) {
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
return nil
},
OnAdd: func(name string) (err error) {
assert.Equal(t, hostsFilename, name)
return nil
},
OnClose: func() (err error) { panic("not implemented") },
}, hostsFilename)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, hf.Close())
assert.NoError(t, os.Remove(hf.Name()))
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
})
_, _ = hf.WriteString(" 127.0.0.1 host # comment \n")
_, _ = hf.WriteString(" ::1 localhost#comment \n")
c.EtcHosts.Init(hf.Name())
t.Cleanup(c.EtcHosts.Close)
flt := filtering.New(&filtering.Config{
EtcHosts: hc,
}, nil)
flt.SetEnabled(true)
var snd *aghnet.SubnetDetector
snd, err = aghnet.NewSubnetDetector()
@@ -1083,7 +1113,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: filtering.New(&c, nil),
DNSFilter: flt,
SubnetDetector: snd,
})
require.NoError(t, err)
@@ -1091,32 +1121,39 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
err = s.Prepare(nil)
require.NoError(t, err)
err = s.Start()
require.NoError(t, err)
t.Cleanup(s.Close)
t.Cleanup(func() {
s.Close()
})
subTestFunc := func(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
resp, eerr := dns.Exchange(req, addr.String())
require.NoError(t, eerr)
resp, err := dns.Exchange(req, addr.String())
require.NoError(t, err)
require.Len(t, resp.Answer, 1)
require.Len(t, resp.Answer, 1)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok)
assert.Equal(t, "host.", ptr.Ptr)
}
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok)
assert.Equal(t, "host.", ptr.Ptr)
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf.ProtectionEnabled = &val
s.setConfig(conf)
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
}
}
func TestNewServer(t *testing.T) {
@@ -1154,23 +1191,18 @@ func TestNewServer(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := NewServer(tc.in)
if tc.wantErrMsg == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
func TestServer_Exchange(t *testing.T) {
extUpstream := &aghtest.TestUpstream{
extUpstream := &aghtest.Upstream{
Reverse: map[string][]string{
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
},
}
locUpstream := &aghtest.TestUpstream{
locUpstream := &aghtest.Upstream{
Reverse: map[string][]string{
"1.1.168.192.in-addr.arpa.": {"local.domain"},
"2.1.168.192.in-addr.arpa.": {},

View File

@@ -61,8 +61,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
require.NoError(t, err)
require.Nil(t, err)
assert.Equal(t, dns.RcodeNameError, d.Res.Rcode)
})
@@ -72,7 +72,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
assert.Empty(t, d.Res.Answer)
})
@@ -83,7 +84,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1)
@@ -96,7 +98,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1)
@@ -109,7 +112,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1)
@@ -122,7 +126,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1)
@@ -135,7 +140,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1)
@@ -152,7 +158,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1)
@@ -171,7 +178,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1)
@@ -190,7 +198,8 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
require.Len(t, d.Res.Answer, 1)

View File

@@ -52,6 +52,7 @@ func (s *Server) beforeRequestHandler(
// the client's IP address and ID, if any, from ctx.
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
setts := s.dnsFilter.GetConfig()
setts.ProtectionEnabled = ctx.protectionEnabled
if s.conf.FilterHandler != nil {
ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr)
s.conf.FilterHandler(ip, ctx.clientID, &setts)
@@ -65,42 +66,23 @@ func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.S
func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
d := ctx.proxyCtx
req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".")
res, err := s.dnsFilter.CheckHost(host, req.Question[0].Qtype, ctx.setts)
if err != nil {
// Return immediately if there's an error
return nil, fmt.Errorf("filtering failed to check host %q: %w", host, err)
} else if res.IsFiltered {
log.Tracef("Host %s is filtered, reason - %q, matched rule: %q", host, res.Reason, res.Rules[0].Text)
q := req.Question[0]
host := strings.TrimSuffix(q.Name, ".")
res, err := s.dnsFilter.CheckHost(host, q.Qtype, ctx.setts)
switch {
case err != nil:
return nil, fmt.Errorf("failed to check host %q: %w", host, err)
case res.IsFiltered:
log.Tracef("host %q is filtered, reason %q, rule: %q", host, res.Reason, res.Rules[0].Text)
d.Res = s.genDNSFilterMessage(d, &res)
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
res.CanonName != "" &&
len(res.IPList) == 0 {
// Resolve the new canonical name, not the original host
// name. The original question is readded in
// processFilteringAfterResponse.
ctx.origQuestion = req.Question[0]
len(res.IPList) == 0:
// Resolve the new canonical name, not the original host name. The
// original question is readded in processFilteringAfterResponse.
ctx.origQuestion = q
req.Question[0].Name = dns.Fqdn(res.CanonName)
} else if res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0 {
resp := s.makeResponse(req)
for _, h := range res.ReverseHosts {
hdr := dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
}
ptr := &dns.PTR{
Hdr: hdr,
Ptr: h,
}
resp.Answer = append(resp.Answer, ptr)
}
d.Res = resp
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts) {
case res.Reason == filtering.Rewritten:
resp := s.makeResponse(req)
name := host
@@ -110,11 +92,12 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
}
for _, ip := range res.IPList {
if req.Question[0].Qtype == dns.TypeA {
switch q.Qtype {
case dns.TypeA:
a := s.genAnswerA(req, ip.To4())
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
} else if req.Question[0].Qtype == dns.TypeAAAA {
case dns.TypeAAAA:
a := s.genAnswerAAAA(req, ip)
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
@@ -122,9 +105,8 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
}
d.Res = resp
} else if res.Reason == filtering.RewrittenRule {
err = s.filterDNSRewrite(req, res, d)
if err != nil {
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
if err = s.filterDNSRewrite(req, res, d); err != nil {
return nil, err
}
}
@@ -134,7 +116,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
// checkHostRules checks the host against filters. It is safe for concurrent
// use.
func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Settings) (
func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) (
r *filtering.Result,
err error,
) {
@@ -146,7 +128,7 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett
}
var res filtering.Result
res, err = s.dnsFilter.CheckHostRules(host, qtype, setts)
res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
if err != nil {
return nil, err
}
@@ -154,32 +136,36 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett
return &res, err
}
// If response contains CNAME, A or AAAA records, we apply filtering to each
// canonical host name or IP address. If this is a match, we set a new response
// in d.Res and return.
func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) {
// filterDNSResponse checks each resource record of the response's answer
// section from ctx and returns a non-nil res if at least one of canonnical
// names or IP addresses in it matches the filtering rules.
func (s *Server) filterDNSResponse(ctx *dnsContext) (res *filtering.Result, err error) {
d := ctx.proxyCtx
setts := ctx.setts
if !setts.FilteringEnabled {
return nil, nil
}
for _, a := range d.Res.Answer {
host := ""
switch v := a.(type) {
var rrtype uint16
switch a := a.(type) {
case *dns.CNAME:
log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name)
host = strings.TrimSuffix(v.Target, ".")
host = strings.TrimSuffix(a.Target, ".")
rrtype = dns.TypeCNAME
case *dns.A:
host = v.A.String()
log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name)
host = a.A.String()
rrtype = dns.TypeA
case *dns.AAAA:
host = v.AAAA.String()
log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name)
host = a.AAAA.String()
rrtype = dns.TypeAAAA
default:
continue
}
res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
log.Debug("dnsforward: checking %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
res, err = s.checkHostRules(host, rrtype, setts)
if err != nil {
return nil, err
} else if res == nil {

View File

@@ -0,0 +1,159 @@
package dnsforward
import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
rules := `
||blocked.domain^
@@||allowed.domain^
||cname.specific^$dnstype=~CNAME
||0.0.0.1^$dnstype=~A
||::1^$dnstype=~AAAA
`
forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
BlockingMode: BlockingModeDefault,
},
}
filters := []filtering.Filter{{
ID: 0, Data: []byte(rules),
}}
f := filtering.New(&filtering.Config{}, filters)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
})
require.NoError(t, err)
s.conf = forwardConf
err = s.Prepare(nil)
require.NoError(t, err)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.Upstream{
CName: map[string][]string{
"cname.exception.": {"cname.specific."},
"should.block.": {"blocked.domain."},
"allowed.first.": {"allowed.domain.", "blocked.domain."},
"blocked.first.": {"blocked.domain.", "allowed.domain."},
},
IPv4: map[string][]net.IP{
"a.exception.": {{0, 0, 0, 1}},
},
IPv6: map[string][]net.IP{
"aaaa.exception.": {net.ParseIP("::1")},
},
},
}
startDeferStop(t, s)
testCases := []struct {
req *dns.Msg
name string
wantAns []dns.RR
}{{
req: createTestMessage("cname.exception."),
name: "cname_exception",
wantAns: []dns.RR{&dns.CNAME{
Hdr: dns.RR_Header{
Name: "cname.exception.",
Rrtype: dns.TypeCNAME,
},
Target: "cname.specific.",
}},
}, {
req: createTestMessage("should.block."),
name: "blocked_by_cname",
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "should.block.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: netutil.IPv4Zero(),
}},
}, {
req: createTestMessage("a.exception."),
name: "a_exception",
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "a.exception.",
Rrtype: dns.TypeA,
},
A: net.IP{0, 0, 0, 1},
}},
}, {
req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA),
name: "aaaa_exception",
wantAns: []dns.RR{&dns.AAAA{
Hdr: dns.RR_Header{
Name: "aaaa.exception.",
Rrtype: dns.TypeAAAA,
},
AAAA: net.ParseIP("::1"),
}},
}, {
req: createTestMessage("allowed.first."),
name: "allowed_first",
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "allowed.first.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: netutil.IPv4Zero(),
}},
}, {
req: createTestMessage("blocked.first."),
name: "blocked_first",
wantAns: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "blocked.first.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: netutil.IPv4Zero(),
}},
}}
for _, tc := range testCases {
dctx := &proxy.DNSContext{
Proto: proxy.ProtoUDP,
Req: tc.req,
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
}
t.Run(tc.name, func(t *testing.T) {
err = s.handleDNSRequest(nil, dctx)
require.NoError(t, err)
require.NotNil(t, dctx.Res)
assert.Equal(t, tc.wantAns, dctx.Res.Answer)
})
}
}

View File

@@ -5,10 +5,12 @@ import (
"fmt"
"net"
"net/http"
"strconv"
"sort"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
@@ -18,12 +20,6 @@ import (
"github.com/miekg/dns"
)
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("dns: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
type dnsConfig struct {
Upstreams *[]string `json:"upstream_dns"`
UpstreamsFile *string `json:"upstream_dns_file"`
@@ -47,7 +43,7 @@ type dnsConfig struct {
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
}
func (s *Server) getDNSConfig() dnsConfig {
func (s *Server) getDNSConfig() (c *dnsConfig) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
@@ -76,7 +72,7 @@ func (s *Server) getDNSConfig() dnsConfig {
upstreamMode = "parallel"
}
return dnsConfig{
return &dnsConfig{
Upstreams: &upstreams,
UpstreamsFile: &upstreamFile,
Bootstraps: &bootstraps,
@@ -112,14 +108,15 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
// since there is no need to omit it while decoding from JSON.
DefautLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"`
}{
dnsConfig: s.getDNSConfig(),
dnsConfig: *s.getDNSConfig(),
DefautLocalPTRUpstreams: defLocalPTRUps,
}
w.Header().Set("Content-Type", "application/json")
if err = json.NewEncoder(w).Encode(resp); err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encoder: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encoder: %s", err)
return
}
}
@@ -143,39 +140,63 @@ func (req *dnsConfig) checkBlockingMode() bool {
}
func (req *dnsConfig) checkUpstreamsMode() bool {
if req.UpstreamMode == nil {
return true
}
valid := []string{"", "fastest_addr", "parallel"}
for _, valid := range []string{
"",
"fastest_addr",
"parallel",
} {
if *req.UpstreamMode == valid {
return true
}
}
return false
return req.UpstreamMode == nil || stringutil.InSlice(valid, *req.UpstreamMode)
}
func (req *dnsConfig) checkBootstrap() (string, error) {
func (req *dnsConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return "", nil
return nil
}
for _, boot := range *req.Bootstraps {
if boot == "" {
return boot, fmt.Errorf("invalid bootstrap server address: empty")
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: invalid address: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
if _, err := upstream.NewResolver(boot, nil); err != nil {
return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
if _, err = upstream.NewResolver(b, nil); err != nil {
return err
}
}
return "", nil
return nil
}
// validate returns an error if any field of req is invalid.
func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
if req.Upstreams != nil {
err = ValidateUpstreams(*req.Upstreams)
if err != nil {
return fmt.Errorf("validating upstream servers: %w", err)
}
}
if req.LocalPTRUpstreams != nil {
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, snd)
if err != nil {
return fmt.Errorf("validating private upstream servers: %w", err)
}
}
err = req.checkBootstrap()
if err != nil {
return err
}
switch {
case !req.checkBlockingMode():
return errors.Error("blocking_mode: incorrect value")
case !req.checkUpstreamsMode():
return errors.Error("upstream_mode: incorrect value")
case !req.checkCacheTTL():
return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max")
default:
return nil
}
}
func (req *dnsConfig) checkCacheTTL() bool {
@@ -195,37 +216,18 @@ func (req *dnsConfig) checkCacheTTL() bool {
}
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
req := dnsConfig{}
dec := json.NewDecoder(r.Body)
if err := dec.Decode(&req); err != nil {
httpError(r, w, http.StatusBadRequest, "json Encode: %s", err)
req := &dnsConfig{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
return
}
if req.Upstreams != nil {
if err := ValidateUpstreams(*req.Upstreams); err != nil {
httpError(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err)
return
}
}
err = req.validate(s.subnetDetector)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
if errBoot, err := req.checkBootstrap(); err != nil {
httpError(r, w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", errBoot, err)
return
}
if !req.checkBlockingMode() {
httpError(r, w, http.StatusBadRequest, "blocking_mode: incorrect value")
return
}
if !req.checkUpstreamsMode() {
httpError(r, w, http.StatusBadRequest, "upstream_mode: incorrect value")
return
}
if !req.checkCacheTTL() {
httpError(r, w, http.StatusBadRequest, "cache_ttl_min must be less or equal than cache_ttl_max")
return
}
@@ -233,14 +235,14 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
s.conf.ConfigModified()
if restart {
if err := s.Reconfigure(nil); err != nil {
httpError(r, w, http.StatusInternalServerError, "%s", err)
return
err = s.Reconfigure(nil)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
}
}
}
func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
func (s *Server) setConfigRestartable(dc *dnsConfig) (restart bool) {
if dc.Upstreams != nil {
s.conf.UpstreamDNS = *dc.Upstreams
restart = true
@@ -261,9 +263,9 @@ func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
restart = true
}
if dc.RateLimit != nil {
restart = restart || s.conf.Ratelimit != *dc.RateLimit
if dc.RateLimit != nil && s.conf.Ratelimit != *dc.RateLimit {
s.conf.Ratelimit = *dc.RateLimit
restart = true
}
if dc.EDNSCSEnabled != nil {
@@ -294,7 +296,7 @@ func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
return restart
}
func (s *Server) setConfig(dc dnsConfig) (restart bool) {
func (s *Server) setConfig(dc *dnsConfig) (restart bool) {
s.serverLock.Lock()
defer s.serverLock.Unlock()
@@ -341,103 +343,180 @@ type upstreamJSON struct {
PrivateUpstreams []string `json:"private_upstream"`
}
// IsCommentOrEmpty returns true of the string starts with a "#" character or is
// an empty string. This function is useful for filtering out non-upstream
// lines from upstream configs.
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream
// configs.
func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}
// LocalNetChecker is used to check if the IP address belongs to a local
// network.
type LocalNetChecker interface {
// IsLocallyServedNetwork returns true if ip is contained in any of address
// registries defined by RFC 6303.
IsLocallyServedNetwork(ip net.IP) (ok bool)
}
// type check
var _ LocalNetChecker = (*aghnet.SubnetDetector)(nil)
// newUpstreamConfig validates upstreams and returns an appropriate upstream
// configuration or nil if it can't be built.
//
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
// slice already so that this function may be considered useless.
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
// No need to validate comments and empty lines.
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
if len(upstreams) == 0 {
// Consider this case valid since it means the default server should be
// used.
return nil, nil
}
conf, err = proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{Bootstrap: []string{}, Timeout: DefaultTimeout},
)
if err != nil {
return nil, err
} else if len(conf.Upstreams) == 0 {
return nil, errors.Error("no default upstreams specified")
}
for _, u := range upstreams {
_, err = validateUpstream(u)
if err != nil {
return nil, err
}
}
return conf, nil
}
// ValidateUpstreams validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified.
//
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
func ValidateUpstreams(upstreams []string) (err error) {
// No need to validate comments
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
_, err = newUpstreamConfig(upstreams)
// Consider this case valid because defaultDNS will be used
if len(upstreams) == 0 {
return nil
return err
}
// stringKeysSorted returns the sorted slice of string keys of m.
//
// TODO(e.burkov): Use generics in Go 1.18. Move into golibs.
func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
sorted = make([]string, 0, len(m))
for s := range m {
sorted = append(sorted, s)
}
_, err = proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: []string{},
Timeout: DefaultTimeout,
},
)
sort.Strings(sorted)
return sorted
}
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. lnc must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err error) {
conf, err := newUpstreamConfig(upstreams)
if err != nil {
return err
}
var defaultUpstreamFound bool
for _, u := range upstreams {
var ok bool
ok, err = validateUpstream(u)
if conf == nil {
return nil
}
var errs []error
for _, domain := range stringKeysSorted(conf.DomainReservedUpstreams) {
var subnet *net.IPNet
subnet, err = netutil.SubnetFromReversedAddr(domain)
if err != nil {
return err
errs = append(errs, err)
continue
}
if !defaultUpstreamFound {
defaultUpstreamFound = ok
if !lnc.IsLocallyServedNetwork(subnet.IP) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
)
}
}
if !defaultUpstreamFound {
return fmt.Errorf("no default upstreams specified")
if len(errs) > 0 {
return errors.List("checking domain-specific upstreams", errs...)
}
return nil
}
var protocols = []string{"tls://", "https://", "tcp://", "sdns://", "quic://"}
var protocols = []string{"udp://", "tcp://", "tls://", "https://", "sdns://", "quic://"}
func validateUpstream(u string) (bool, error) {
func validateUpstream(u string) (useDefault bool, err error) {
// Check if the user tries to specify upstream for domain.
u, useDefault, err := separateUpstream(u)
var isDomainSpec bool
u, isDomainSpec, err = separateUpstream(u)
if err != nil {
return useDefault, err
return !isDomainSpec, err
}
// The special server address '#' means "use the default servers"
if u == "#" && !useDefault {
// The special server address '#' means that default server must be used.
if useDefault = !isDomainSpec; u == "#" && isDomainSpec {
return useDefault, nil
}
// Check if the upstream has a valid protocol prefix
// Check if the upstream has a valid protocol prefix.
//
// TODO(e.burkov): Validate the domain name.
for _, proto := range protocols {
if strings.HasPrefix(u, proto) {
return useDefault, nil
}
}
// Return error if the upstream contains '://' without any valid protocol
if strings.Contains(u, "://") {
return useDefault, fmt.Errorf("wrong protocol")
return useDefault, errors.Error("wrong protocol")
}
// Check if upstream is valid plain DNS
return useDefault, checkPlainDNS(u)
// Check if upstream is either an IP or IP with port.
if net.ParseIP(u) != nil {
return useDefault, nil
} else if _, err = netutil.ParseIPPort(u); err != nil {
return useDefault, err
}
return useDefault, nil
}
// separateUpstream returns the upstream without the specified domains.
// useDefault is true when a default upstream must be used.
func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) {
defer func() { err = errors.Annotate(err, "bad upstream for domain spec %q: %w", upstreamStr) }()
// isDomainSpec is true when the upstream is domains-specific.
func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, err error) {
if !strings.HasPrefix(upstreamStr, "[/") {
return upstreamStr, true, nil
return upstreamStr, false, nil
}
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
parts := strings.Split(upstreamStr[2:], "/]")
if len(parts) != 2 {
return "", false, errors.Error("duplicated separator")
switch len(parts) {
case 2:
// Go on.
case 1:
return "", false, errors.Error("missing separator")
default:
return "", true, errors.Error("duplicated separator")
}
domains := parts[0]
upstream = parts[1]
var domains string
domains, upstream = parts[0], parts[1]
for i, host := range strings.Split(domains, "/") {
if host == "" {
continue
@@ -445,36 +524,11 @@ func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err
err = netutil.ValidateDomainName(host)
if err != nil {
return "", false, fmt.Errorf("domain at index %d: %w", i, err)
return "", true, fmt.Errorf("domain at index %d: %w", i, err)
}
}
return upstream, false, nil
}
// checkPlainDNS checks if host is plain DNS
func checkPlainDNS(upstream string) error {
// Check if host is ip without port
if net.ParseIP(upstream) != nil {
return nil
}
// Check if host is ip with port
ip, port, err := net.SplitHostPort(upstream)
if err != nil {
return err
}
if net.ParseIP(ip) == nil {
return fmt.Errorf("%s is not a valid IP", ip)
}
_, err = strconv.ParseInt(port, 0, 64)
if err != nil {
return fmt.Errorf("%s is not a valid port: %w", port, err)
}
return nil
return upstream, true, nil
}
// excFunc is a signature of function to check if upstream exchanges correctly.
@@ -502,12 +556,8 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
if len(reply.Answer) != 1 {
return fmt.Errorf("wrong response")
}
if t, ok := reply.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
return fmt.Errorf("wrong response")
}
} else if a, ok := reply.Answer[0].(*dns.A); !ok || !a.A.Equal(net.IP{8, 8, 8, 8}) {
return fmt.Errorf("wrong response")
}
return nil
@@ -542,7 +592,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
// Separate upstream from domains list.
var useDefault bool
if input, useDefault, err = separateUpstream(input); err != nil {
if useDefault, err = validateUpstream(input); err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
}
@@ -551,7 +601,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
return nil
}
if _, err = validateUpstream(input); err != nil {
if input, _, err = separateUpstream(input); err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
}
@@ -559,7 +609,8 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
bootstrap = defaultBootstrap
}
log.Debug("checking if dns server %q works...", input)
log.Debug("checking if upstream %s works", input)
var u upstream.Upstream
u, err = upstream.AddressToUpstream(input, &upstream.Options{
Bootstrap: bootstrap,
@@ -573,7 +624,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
return fmt.Errorf("upstream %q fails to exchange: %w", input, err)
}
log.Debug("dns %s works OK", input)
log.Debug("upstream %s is ok", input)
return nil
}
@@ -582,7 +633,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
req := &upstreamJSON{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
httpError(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
return
}
@@ -607,9 +658,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
err = checkDNS(host, bootstraps, timeout, checkPrivateUpstreamExc)
if err != nil {
log.Info("%v", err)
// TODO(e.burkov): If passed upstream have already
// written an error above, we rewriting the error for
// it. These cases should be handled properly instead.
// TODO(e.burkov): If passed upstream have already written an error
// above, we rewriting the error for it. These cases should be
// handled properly instead.
result[host] = err.Error()
continue
@@ -620,7 +671,13 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(result)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to marshal status json: %s",
err,
)
return
}
@@ -628,9 +685,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
return
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
@@ -641,12 +696,12 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
// -> dnsforward.handleDNSRequest
func (s *Server) handleDoH(w http.ResponseWriter, r *http.Request) {
if !s.conf.TLSAllowUnencryptedDoH && r.TLS == nil {
httpError(r, w, http.StatusNotFound, "Not Found")
aghhttp.Error(r, w, http.StatusNotFound, "Not Found")
return
}
if !s.IsRunning() {
httpError(r, w, http.StatusInternalServerError, "dns server is not running")
aghhttp.Error(r, w, http.StatusInternalServerError, "dns server is not running")
return
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -38,9 +39,7 @@ func loadTestData(t *testing.T, casesFileName string, cases interface{}) {
var f *os.File
f, err := os.Open(filepath.Join("testdata", casesFileName))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, f.Close())
})
testutil.CleanupAndRequireSuccess(t, f.Close)
err = json.NewDecoder(f).Decode(cases)
require.NoError(t, err)
@@ -69,10 +68,8 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
s := createTestServer(t, filterConf, forwardConf, nil)
s.sysResolvers = &fakeSystemResolvers{}
require.Nil(t, s.Start())
t.Cleanup(func() {
require.Nil(t, s.Stop())
})
require.NoError(t, s.Start())
testutil.CleanupAndRequireSuccess(t, s.Stop)
defaultConf := s.conf
@@ -147,10 +144,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
defaultConf := s.conf
err := s.Start()
assert.Nil(t, err)
t.Cleanup(func() {
assert.Nil(t, s.Stop())
})
assert.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, s.Stop)
w := httptest.NewRecorder()
@@ -189,12 +184,11 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
wantSet: "",
}, {
name: "upstream_dns_bad",
wantSet: `wrong upstreams specification: address !!!: ` +
`missing port in address`,
wantSet: `validating upstream servers: bad ipport address "!!!": ` +
`address !!!: missing port in address`,
}, {
name: "bootstraps_bad",
wantSet: `a can not be used as bootstrap dns cause: ` +
`invalid bootstrap server address: ` +
wantSet: `checking bootstrap a: invalid address: ` +
`Resolver a is not eligible to be a bootstrap DNS server`,
}, {
name: "cache_bad_ttl",
@@ -205,6 +199,10 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
}, {
name: "local_ptr_upstreams_good",
wantSet: "",
}, {
name: "local_ptr_upstreams_bad",
wantSet: `validating private upstream servers: checking domain-specific upstreams: ` +
`bad arpa domain name "non.arpa": not a reversed ip network`,
}, {
name: "local_ptr_upstreams_null",
wantSet: "",
@@ -221,14 +219,12 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
require.True(t, ok)
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
s.conf = defaultConf
})
t.Cleanup(func() { s.conf = defaultConf })
rBody := io.NopCloser(bytes.NewReader(caseData.Req))
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "http://example.com", rBody)
require.Nil(t, err)
require.NoError(t, err)
s.handleSetConfig(w, r)
assert.Equal(t, tc.wantSet, strings.TrimSuffix(w.Body.String(), "\n"))
@@ -242,130 +238,145 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
}
func TestIsCommentOrEmpty(t *testing.T) {
assert.True(t, IsCommentOrEmpty(""))
assert.True(t, IsCommentOrEmpty("# comment"))
assert.False(t, IsCommentOrEmpty("1.2.3.4"))
for _, tc := range []struct {
want assert.BoolAssertionFunc
str string
}{{
want: assert.True,
str: "",
}, {
want: assert.True,
str: "# comment",
}, {
want: assert.False,
str: "1.2.3.4",
}} {
tc.want(t, IsCommentOrEmpty(tc.str))
}
}
// TODO(a.garipov): Rewrite to check the actual error messages.
func TestValidateUpstream(t *testing.T) {
testCases := []struct {
wantDef assert.BoolAssertionFunc
name string
upstream string
valid bool
wantDef bool
wantErr string
}{{
wantDef: assert.True,
name: "invalid",
upstream: "1.2.3.4.5",
valid: false,
wantDef: false,
wantErr: `bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "123.3.7m",
valid: false,
wantDef: false,
wantErr: `bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "htttps://google.com/dns-query",
valid: false,
wantDef: false,
wantErr: `wrong protocol`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "[/host.com]tls://dns.adguard.com",
valid: false,
wantDef: false,
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "[host.ru]#",
valid: false,
wantDef: false,
wantErr: `bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "1.1.1.1",
valid: true,
wantDef: true,
wantErr: ``,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "tls://1.1.1.1",
valid: true,
wantDef: true,
wantErr: ``,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "https://dns.adguard.com/dns-query",
valid: true,
wantDef: true,
wantErr: ``,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
valid: true,
wantDef: true,
wantErr: ``,
}, {
wantDef: assert.True,
name: "default_udp_host",
upstream: "udp://dns.google",
}, {
wantDef: assert.True,
name: "default_udp_ip",
upstream: "udp://8.8.8.8",
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/host.com/]1.1.1.1",
valid: true,
wantDef: false,
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[//]tls://1.1.1.1",
valid: true,
wantDef: false,
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/www.host.com/]#",
valid: true,
wantDef: false,
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/host.com/google.com/]8.8.8.8",
valid: true,
wantDef: false,
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
valid: true,
wantDef: false,
wantErr: ``,
}, {
wantDef: assert.False,
name: "idna",
upstream: "[/пример.рф/]8.8.8.8",
valid: true,
wantDef: false,
wantErr: ``,
}, {
wantDef: assert.False,
name: "bad_domain",
upstream: "[/!/]8.8.8.8",
valid: false,
wantDef: false,
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
defaultUpstream, err := validateUpstream(tc.upstream)
require.Equal(t, tc.valid, err == nil)
if tc.valid {
assert.Equal(t, tc.wantDef, defaultUpstream)
}
testutil.AssertErrorMsg(t, tc.wantErr, err)
tc.wantDef(t, defaultUpstream)
})
}
}
func TestValidateUpstreamsSet(t *testing.T) {
func TestValidateUpstreams(t *testing.T) {
testCases := []struct {
name string
msg string
wantErr string
set []string
wantNil bool
}{{
name: "empty",
msg: "empty upstreams array should be valid",
wantErr: ``,
set: nil,
wantNil: true,
}, {
name: "comment",
msg: "comments should not be validated",
wantErr: ``,
set: []string{"# comment"},
wantNil: true,
}, {
name: "valid_no_default",
msg: "there is no default upstream",
name: "valid_no_default",
wantErr: `no default upstreams specified`,
set: []string{
"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
@@ -373,10 +384,9 @@ func TestValidateUpstreamsSet(t *testing.T) {
"[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
},
wantNil: false,
}, {
name: "valid_with_default",
msg: "upstreams set is valid, but doesn't pass through validation cause: %s",
name: "valid_with_default",
wantErr: ``,
set: []string{
"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
@@ -385,19 +395,65 @@ func TestValidateUpstreamsSet(t *testing.T) {
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
"8.8.8.8",
},
wantNil: true,
}, {
name: "invalid",
msg: "there is an invalid upstream in set, but it pass through validation",
wantErr: `cannot prepare the upstream dhcp://fake.dns ([]): unsupported url scheme: dhcp`,
set: []string{"dhcp://fake.dns"},
wantNil: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateUpstreams(tc.set)
assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}
}
func TestValidateUpstreamsPrivate(t *testing.T) {
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
testCases := []struct {
name string
wantErr string
u string
}{{
name: "success_address",
wantErr: ``,
u: "[/1.0.0.127.in-addr.arpa/]#",
}, {
name: "success_subnet",
wantErr: ``,
u: "[/127.in-addr.arpa/]#",
}, {
name: "not_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`bad arpa domain name "hello.world": not a reversed ip network`,
u: "[/hello.world/]#",
}, {
name: "non-private_arpa_address",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`,
u: "[/1.2.3.4.in-addr.arpa/]#",
}, {
name: "non-private_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "128.in-addr.arpa." should point to a locally-served network`,
u: "[/128.in-addr.arpa/]#",
}, {
name: "several_bad",
wantErr: `checking domain-specific upstreams: 2 errors: ` +
`"arpa domain \"1.2.3.4.in-addr.arpa.\" should point to a locally-served network", ` +
`"bad arpa domain name \"non.arpa\": not a reversed ip network"`,
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
}}
for _, tc := range testCases {
set := []string{"192.168.0.1", tc.u}
t.Run(tc.name, func(t *testing.T) {
err = ValidateUpstreamsPrivate(set, snd)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}
}

View File

@@ -24,20 +24,19 @@ type ipsetCtx struct {
func (c *ipsetCtx) init(ipsetConf []string) (err error) {
c.ipsetMgr, err = aghnet.NewIpsetManager(ipsetConf)
if errors.Is(err, os.ErrInvalid) || errors.Is(err, os.ErrPermission) {
// ipset cannot currently be initialized if the server was
// installed from Snap or when the user or the binary doesn't
// have the required permissions, or when the kernel doesn't
// support netfilter.
// ipset cannot currently be initialized if the server was installed
// from Snap or when the user or the binary doesn't have the required
// permissions, or when the kernel doesn't support netfilter.
//
// Log and go on.
//
// TODO(a.garipov): The Snap problem can probably be solved if
// we add the netlink-connector interface plug.
log.Info("warning: cannot initialize ipset: %s", err)
// TODO(a.garipov): The Snap problem can probably be solved if we add
// the netlink-connector interface plug.
log.Info("ipset: warning: cannot initialize: %s", err)
return nil
} else if unsupErr := (&aghos.UnsupportedError{}); errors.As(err, &unsupErr) {
log.Info("warning: %s", err)
log.Info("ipset: warning: %s", err)
return nil
} else if err != nil {

View File

@@ -78,8 +78,8 @@ func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDet
// msgToSignature converts msg into it's signature represented in bytes.
func msgToSignature(msg dns.Msg) (sig []byte) {
sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen)
// The binary.BigEndian byte order is used everywhere except when the
// real machine's endianess is needed.
// The binary.BigEndian byte order is used everywhere except when the real
// machine's endianness is needed.
byteOrder := binary.BigEndian
byteOrder.PutUint16(sig[0:], msg.Id)
q := msg.Question[0]

View File

@@ -1,6 +1,7 @@
package dnsforward
import (
"net"
"strings"
"time"
@@ -8,15 +9,15 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// Write Stats data and logs
func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
elapsed := time.Since(ctx.startTime)
s := ctx.srv
pctx := ctx.proxyCtx
func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
elapsed := time.Since(dctx.startTime)
pctx := dctx.proxyCtx
shouldLog := true
msg := pctx.Req
@@ -26,54 +27,79 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
shouldLog = false
}
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
ip = netutil.CloneIP(ip)
s.serverLock.RLock()
defer s.serverLock.RUnlock()
// Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use.
// This can happen after proxy server has been stopped, but its workers haven't yet exited.
s.anonymizer.Load()(ip)
log.Debug("client ip: %s", ip)
// Synchronize access to s.queryLog and s.stats so they won't be suddenly
// uninitialized while in use. This can happen after proxy server has been
// stopped, but its workers haven't yet exited.
if shouldLog && s.queryLog != nil {
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
p := querylog.AddParams{
Question: msg,
Answer: pctx.Res,
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: ip,
ClientID: ctx.clientID,
}
switch pctx.Proto {
case proxy.ProtoHTTPS:
p.ClientProto = querylog.ClientProtoDoH
case proxy.ProtoQUIC:
p.ClientProto = querylog.ClientProtoDoQ
case proxy.ProtoTLS:
p.ClientProto = querylog.ClientProtoDoT
case proxy.ProtoDNSCrypt:
p.ClientProto = querylog.ClientProtoDNSCrypt
default:
// Consider this a plain DNS-over-UDP or DNS-over-TCP
// request.
}
if pctx.Upstream != nil {
p.Upstream = pctx.Upstream.Address()
}
s.queryLog.Add(p)
s.logQuery(dctx, pctx, elapsed, ip)
}
s.updateStats(ctx, elapsed, *ctx.result)
if s.stats != nil {
s.updateStats(dctx, elapsed, *dctx.result, ip)
}
return resultCodeSuccess
}
func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filtering.Result) {
if s.stats == nil {
return
// logQuery pushes the request details into the query log.
func (s *Server) logQuery(
dctx *dnsContext,
pctx *proxy.DNSContext,
elapsed time.Duration,
ip net.IP,
) {
p := &querylog.AddParams{
Question: pctx.Req,
ReqECS: pctx.ReqECS,
Answer: pctx.Res,
OrigAnswer: dctx.origResp,
Result: dctx.result,
Elapsed: elapsed,
ClientID: dctx.clientID,
ClientIP: ip,
AuthenticatedData: dctx.responseAD,
}
switch pctx.Proto {
case proxy.ProtoHTTPS:
p.ClientProto = querylog.ClientProtoDoH
case proxy.ProtoQUIC:
p.ClientProto = querylog.ClientProtoDoQ
case proxy.ProtoTLS:
p.ClientProto = querylog.ClientProtoDoT
case proxy.ProtoDNSCrypt:
p.ClientProto = querylog.ClientProtoDNSCrypt
default:
// Consider this a plain DNS-over-UDP or DNS-over-TCP request.
}
if pctx.Upstream != nil {
p.Upstream = pctx.Upstream.Address()
} else if cachedUps := pctx.CachedUpstreamAddr; cachedUps != "" {
p.Upstream = pctx.CachedUpstreamAddr
p.Cached = true
}
s.queryLog.Add(p)
}
// updatesStats writes the request into statistics.
func (s *Server) updateStats(
ctx *dnsContext,
elapsed time.Duration,
res filtering.Result,
clientIP net.IP,
) {
pctx := ctx.proxyCtx
e := stats.Entry{}
e.Domain = strings.ToLower(pctx.Req.Question[0].Name)
@@ -81,8 +107,8 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri
if clientID := ctx.clientID; clientID != "" {
e.Client = clientID
} else if ip, _ := netutil.IPAndPortFromAddr(pctx.Addr); ip != nil {
e.Client = ip.String()
} else if clientIP != nil {
e.Client = clientIP.String()
}
e.Time = uint32(elapsed / 1000)

View File

@@ -5,6 +5,7 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
@@ -21,11 +22,11 @@ type testQueryLog struct {
// a querylog.QueryLog without actually implementing all methods.
querylog.QueryLog
lastParams querylog.AddParams
lastParams *querylog.AddParams
}
// Add implements the querylog.QueryLog interface for *testQueryLog.
func (l *testQueryLog) Add(p querylog.AddParams) {
func (l *testQueryLog) Add(p *querylog.AddParams) {
l.lastParams = p
}
@@ -65,7 +66,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
reason: filtering.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_tls_client_id",
name: "success_tls_clientid",
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "cli42",
@@ -157,9 +158,16 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
}}
ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
require.Nil(t, err)
require.NoError(t, err)
for _, tc := range testCases {
ql := &testQueryLog{}
st := &testStats{}
srv := &Server{
queryLog: ql,
stats: st,
anonymizer: aghnet.NewIPMut(nil),
}
t.Run(tc.name, func(t *testing.T) {
req := &dns.Msg{
Question: []dns.Question{{
@@ -173,14 +181,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
Addr: tc.addr,
Upstream: ups,
}
ql := &testQueryLog{}
st := &testStats{}
dctx := &dnsContext{
srv: &Server{
queryLog: ql,
stats: st,
},
proxyCtx: pctx,
startTime: time.Now(),
result: &filtering.Result{
@@ -189,7 +190,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
clientID: tc.clientID,
}
code := processQueryLogsAndStats(dctx)
code := srv.processQueryLogsAndStats(dctx)
assert.Equal(t, tc.wantCode, code)
assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto)
assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)

View File

@@ -520,6 +520,43 @@
]
}
},
"local_ptr_upstreams_bad": {
"req": {
"local_ptr_upstreams": [
"123.123.123.123",
"[/non.arpa/]#"
]
},
"want": {
"upstream_dns": [
"8.8.8.8:53",
"8.8.4.4:53"
],
"upstream_dns_file": "",
"bootstrap_dns": [
"9.9.9.10",
"149.112.112.10",
"2620:fe::10",
"2620:fe::fe:10"
],
"protection_enabled": true,
"ratelimit": 0,
"blocking_mode": "",
"blocking_ipv4": "",
"blocking_ipv6": "",
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
"upstream_mode": "",
"cache_size": 0,
"cache_ttl_min": 0,
"cache_ttl_max": 0,
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": []
}
},
"local_ptr_upstreams_null": {
"req": {
"local_ptr_upstreams": null

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
)
@@ -239,7 +240,7 @@ func initBlockedServices() {
for _, s := range serviceRulesArray {
netRules := []*rules.NetworkRule{}
for _, text := range s.rules {
rule, err := rules.NewNetworkRule(text, 0)
rule, err := rules.NewNetworkRule(text, BlockedSvcsListID)
if err != nil {
log.Error("rules.NewNetworkRule: %s rule: %s", err, text)
continue
@@ -287,7 +288,8 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(list)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
return
}
}
@@ -296,7 +298,8 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
list := []string{}
err := json.NewDecoder(r.Body).Decode(&list)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}

View File

@@ -15,14 +15,10 @@ type DNSRewriteResult struct {
// the server returns.
type DNSRewriteResultResponse map[rules.RRType][]rules.RRValue
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns
// an empty result if dnsr is empty. Otherwise, the result will have
// either CanonName or DNSRewriteResult set.
// processDNSRewrites processes DNS rewrite rules in dnsr. It returns an empty
// result if dnsr is empty. Otherwise, the result will have either CanonName or
// DNSRewriteResult set. dnsr is expected to be non-empty.
func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
if len(dnsr) == 0 {
return Result{}
}
var rules []*ResultRule
dnsrr := &DNSRewriteResult{
Response: DNSRewriteResultResponse{},
@@ -31,8 +27,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
for _, nr := range dnsr {
dr := nr.DNSRewrite
if dr.NewCNAME != "" {
// NewCNAME rules have a higher priority than
// the other rules.
// NewCNAME rules have a higher priority than other rules.
rules = []*ResultRule{{
FilterListID: int64(nr.GetFilterListID()),
Text: nr.RuleText,
@@ -54,8 +49,8 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
Text: nr.RuleText,
})
default:
// RcodeRefused and other such codes have higher
// priority. Return immediately.
// RcodeRefused and other such codes have higher priority. Return
// immediately.
rules = []*ResultRule{{
FilterListID: int64(nr.GetFilterListID()),
Text: nr.RuleText,

View File

@@ -49,7 +49,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot.
`
f := newForTest(nil, []Filter{{ID: 0, Data: []byte(text)}})
f := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
setts := &Settings{
FilteringEnabled: true,
}

View File

@@ -4,6 +4,7 @@ package filtering
import (
"context"
"fmt"
"io/fs"
"net"
"net/http"
"os"
@@ -16,6 +17,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/urlfilter"
@@ -24,6 +26,18 @@ import (
"github.com/miekg/dns"
)
// The IDs of built-in filter lists.
//
// Keep in sync with client/src/helpers/constants.js.
const (
CustomListID = -iota
SysHostsListID
BlockedSvcsListID
ParentalListID
SafeBrowsingListID
SafeSearchListID
)
// ServiceEntry - blocked service array element
type ServiceEntry struct {
Name string
@@ -38,6 +52,7 @@ type Settings struct {
ServicesRules []ServiceEntry
ProtectionEnabled bool
FilteringEnabled bool
SafeSearchEnabled bool
SafeBrowsingEnabled bool
@@ -65,7 +80,7 @@ type Config struct {
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
Rewrites []RewriteEntry `yaml:"rewrites"`
Rewrites []*LegacyRewrite `yaml:"rewrites"`
// Names of services to block (globally).
// Per-client settings can override this configuration.
@@ -73,7 +88,7 @@ type Config struct {
// EtcHosts is a container of IP-hostname pairs taken from the operating
// system configuration files (e.g. /etc/hosts).
EtcHosts *aghnet.EtcHostsContainer `yaml:"-"`
EtcHosts *aghnet.HostsContainer `yaml:"-"`
// Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"`
@@ -124,6 +139,10 @@ type DNSFilter struct {
parentalUpstream upstream.Upstream
safeBrowsingUpstream upstream.Upstream
safebrowsingCache cache.Cache
parentalCache cache.Cache
safeSearchCache cache.Cache
Config // for direct access by library users, even a = assignment
// confLock protects Config.
confLock sync.RWMutex
@@ -142,9 +161,14 @@ type DNSFilter struct {
// Filter represents a filter list
type Filter struct {
ID int64 // auto-assigned when filter is added (see nextFilterID)
Data []byte `yaml:"-"` // List of rules divided by '\n'
FilePath string `yaml:"-"` // Path to a filtering rules file
// FilePath is the path to a filtering rules list file.
FilePath string `yaml:"-"`
// Data is the content of the file.
Data []byte `yaml:"-"`
// ID is automatically assigned when filter is added using nextFilterID.
ID int64
}
// Reason holds an enum detailing why it was filtered or not filtered
@@ -176,8 +200,8 @@ const (
// FilteredBlockedService - the host is blocked by "blocked services" settings
FilteredBlockedService
// Rewritten is returned when there was a rewrite by a legacy DNS
// rewrite rule.
// Rewritten is returned when there was a rewrite by a legacy DNS rewrite
// rule.
Rewritten
// RewrittenAutoHosts is returned when there was a rewrite by autohosts
@@ -186,8 +210,8 @@ const (
// RewrittenRule is returned when a $dnsrewrite filter rule was applied.
//
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging
// their functionality into RewrittenRule.
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging their
// functionality into RewrittenRule.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2499.
RewrittenRule
@@ -221,12 +245,13 @@ func (r Reason) String() string {
}
// In returns true if reasons include r.
func (r Reason) In(reasons ...Reason) bool {
func (r Reason) In(reasons ...Reason) (ok bool) {
for _, reason := range reasons {
if r == reason {
return true
}
}
return false
}
@@ -245,7 +270,7 @@ func (d *DNSFilter) GetConfig() (s Settings) {
defer d.confLock.RUnlock()
return Settings{
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) == 1,
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0,
SafeSearchEnabled: d.Config.SafeSearchEnabled,
SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled,
ParentalEnabled: d.Config.ParentalEnabled,
@@ -261,8 +286,14 @@ func (d *DNSFilter) WriteDiskConfig(c *Config) {
c.Rewrites = cloneRewrites(c.Rewrites)
}
func cloneRewrites(entries []RewriteEntry) (clone []RewriteEntry) {
return append([]RewriteEntry(nil), entries...)
// cloneRewrites returns a deep copy of entries.
func cloneRewrites(entries []*LegacyRewrite) (clone []*LegacyRewrite) {
clone = make([]*LegacyRewrite, len(entries))
for i, rw := range entries {
clone[i] = rw.clone()
}
return clone
}
// SetFilters - set new filters (synchronously or asynchronously)
@@ -338,14 +369,6 @@ func (d *DNSFilter) reset() {
}
}
type dnsFilterContext struct {
safebrowsingCache cache.Cache
parentalCache cache.Cache
safeSearchCache cache.Cache
}
var gctx dnsFilterContext
// ResultRule contains information about applied rules.
type ResultRule struct {
// Text is the text of the rule.
@@ -371,24 +394,19 @@ type Result struct {
// Reason is the reason for blocking or unblocking the request.
Reason Reason `json:",omitempty"`
// Rules are applied rules. If Rules are not empty, each rule
// is not nil.
// Rules are applied rules. If Rules are not empty, each rule is not nil.
Rules []*ResultRule `json:",omitempty"`
// ReverseHosts is the reverse lookup rewrite result. It is
// empty unless Reason is set to RewrittenAutoHosts.
ReverseHosts []string `json:",omitempty"`
// IPList is the lookup rewrite result. It is empty unless
// Reason is set to RewrittenAutoHosts or Rewritten.
// IPList is the lookup rewrite result. It is empty unless Reason is set to
// Rewritten.
IPList []net.IP `json:",omitempty"`
// CanonName is the CNAME value from the lookup rewrite result.
// It is empty unless Reason is set to Rewritten or RewrittenRule.
// CanonName is the CNAME value from the lookup rewrite result. It is empty
// unless Reason is set to Rewritten or RewrittenRule.
CanonName string `json:",omitempty"`
// ServiceName is the name of the blocked service. It is empty
// unless Reason is set to FilteredBlockedService.
// ServiceName is the name of the blocked service. It is empty unless
// Reason is set to FilteredBlockedService.
ServiceName string `json:",omitempty"`
// DNSRewriteResult is the $dnsrewrite filter rule result.
@@ -402,14 +420,8 @@ func (r Reason) Matched() bool {
}
// CheckHostRules tries to match the host against filtering rules only.
func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *Settings) (Result, error) {
if !setts.FilteringEnabled {
return Result{}, nil
}
host = strings.ToLower(host)
return d.matchHost(host, qtype, setts)
func (d *DNSFilter) CheckHostRules(host string, rrtype uint16, setts *Settings) (Result, error) {
return d.matchHost(strings.ToLower(host), rrtype, setts)
}
// CheckHost tries to match the host against filtering rules, then safebrowsing
@@ -422,14 +434,16 @@ func (d *DNSFilter) CheckHost(
// Sometimes clients try to resolve ".", which is a request to get root
// servers.
if host == "" {
return Result{Reason: NotFilteredNotFound}, nil
return Result{}, nil
}
host = strings.ToLower(host)
res = d.processRewrites(host, qtype)
if res.Reason == Rewritten {
return res, nil
if setts.FilteringEnabled {
res = d.processRewrites(host, qtype)
if res.Reason == Rewritten {
return res, nil
}
}
for _, hc := range d.hostCheckers {
@@ -446,100 +460,123 @@ func (d *DNSFilter) CheckHost(
return Result{}, nil
}
// checkEtcHosts compares the host against our /etc/hosts table. The err is
// always nil, it is only there to make this a valid hostChecker function.
func (d *DNSFilter) checkEtcHosts(
// matchSysHosts tries to match the host against the operating system's hosts
// database. err is always nil.
func (d *DNSFilter) matchSysHosts(
host string,
qtype uint16,
_ *Settings,
setts *Settings,
) (res Result, err error) {
if d.Config.EtcHosts == nil {
return Result{}, nil
}
ips := d.Config.EtcHosts.Process(host, qtype)
if ips != nil {
res = Result{
Reason: RewrittenAutoHosts,
IPList: ips,
}
if !setts.FilteringEnabled || d.EtcHosts == nil {
return res, nil
}
revHosts := d.Config.EtcHosts.ProcessReverse(host, qtype)
if len(revHosts) != 0 {
res = Result{
Reason: RewrittenAutoHosts,
}
// TODO(a.garipov): Optimize this with a buffer.
res.ReverseHosts = make([]string, len(revHosts))
for i := range revHosts {
res.ReverseHosts[i] = revHosts[i] + "."
}
dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{
Hostname: host,
SortedClientTags: setts.ClientTags,
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
ClientIP: setts.ClientIP.String(),
ClientName: setts.ClientName,
DNSType: qtype,
})
if dnsres == nil {
return res, nil
}
return Result{}, nil
dnsr := dnsres.DNSRewrites()
if len(dnsr) == 0 {
return res, nil
}
res = d.processDNSRewrites(dnsr)
res.Reason = RewrittenAutoHosts
for _, r := range res.Rules {
r.Text = stringutil.Coalesce(d.EtcHosts.Translate(r.Text), r.Text)
}
return res, nil
}
// Process rewrites table
// . Find CNAME for a domain name (exact match or by wildcard)
// . if found and CNAME equals to domain name - this is an exception; exit
// . if found, set domain name to canonical name
// . repeat for the new domain name (Note: we return only the last CNAME)
// . Find A or AAAA record for a domain name (exact match or by wildcard)
// . if found, set IP addresses (IPv4 or IPv6 depending on qtype) in Result.IPList array
// processRewrites performs filtering based on the legacy rewrite records.
//
// Firstly, it finds CNAME rewrites for host. If the CNAME is the same as host,
// this query isn't filtered. If it's different, repeat the process for the new
// CNAME, breaking loops in the process.
//
// Secondly, it finds A or AAAA rewrites for host and, if found, sets res.IPList
// accordingly. If the found rewrite has a special value of "A" or "AAAA", the
// result is an exception.
func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
d.confLock.RLock()
defer d.confLock.RUnlock()
rr := findRewrites(d.Rewrites, host, qtype)
if len(rr) != 0 {
res.Reason = Rewritten
rewrites, matched := findRewrites(d.Rewrites, host, qtype)
if !matched {
return Result{}
}
res.Reason = Rewritten
cnames := stringutil.NewSet()
origHost := host
for len(rr) != 0 && rr[0].Type == dns.TypeCNAME {
log.Debug("rewrite: CNAME for %s is %s", host, rr[0].Answer)
for matched && len(rewrites) > 0 && rewrites[0].Type == dns.TypeCNAME {
rw := rewrites[0]
rwPat := rw.Domain
rwAns := rw.Answer
if host == rr[0].Answer { // "host == CNAME" is an exception
res.Reason = NotFilteredNotFound
log.Debug("rewrite: cname for %s is %s", host, rwAns)
return res
if origHost == rwAns || rwPat == rwAns {
// Either a request for the hostname itself or a rewrite of
// a pattern onto itself, both of which are an exception rules.
// Return a not filtered result.
return Result{}
} else if host == rwAns && isWildcard(rwPat) {
// An "*.example.com → sub.example.com" rewrite matching in a loop.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/4016.
res.CanonName = host
break
}
host = rr[0].Answer
host = rwAns
if cnames.Has(host) {
log.Info("rewrite: breaking CNAME redirection loop: %s. Question: %s", host, origHost)
log.Info("rewrite: cname loop for %q on %q", origHost, host)
return res
}
cnames.Add(host)
res.CanonName = rr[0].Answer
rr = findRewrites(d.Rewrites, host, qtype)
res.CanonName = host
rewrites, matched = findRewrites(d.Rewrites, host, qtype)
}
for _, r := range rr {
if r.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
if r.IP == nil { // IP exception
res.Reason = NotFilteredNotFound
return res
}
res.IPList = append(res.IPList, r.IP)
log.Debug("rewrite: A/AAAA for %s is %s", host, r.IP)
}
}
setRewriteResult(&res, host, rewrites, qtype)
return res
}
// setRewriteResult sets the Reason or IPList of res if necessary. res must not
// be nil.
func setRewriteResult(res *Result, host string, rewrites []*LegacyRewrite, qtype uint16) {
for _, rw := range rewrites {
if rw.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
if rw.IP == nil {
// "A"/"AAAA" exception: allow getting from upstream.
res.Reason = NotFilteredNotFound
return
}
res.IPList = append(res.IPList, rw.IP)
log.Debug("rewrite: a/aaaa for %s is %s", host, rw.IP)
}
}
}
// matchBlockedServicesRules checks the host against the blocked services rules
// in settings, if any. The err is always nil, it is only there to make this
// a valid hostChecker function.
@@ -548,6 +585,10 @@ func matchBlockedServicesRules(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ProtectionEnabled {
return Result{}, nil
}
svcs := setts.ServicesRules
if len(svcs) == 0 {
return Result{}, nil
@@ -582,80 +623,82 @@ func matchBlockedServicesRules(
// Adding rule and matching against the rules
//
// fileExists returns true if file exists.
func fileExists(fn string) bool {
_, err := os.Stat(fn)
return err == nil
}
func createFilteringEngine(filters []Filter) (*filterlist.RuleStorage, *urlfilter.DNSEngine, error) {
listArray := []filterlist.RuleList{}
func newRuleStorage(filters []Filter) (rs *filterlist.RuleStorage, err error) {
lists := make([]filterlist.RuleList, 0, len(filters))
for _, f := range filters {
var list filterlist.RuleList
if f.ID == 0 {
list = &filterlist.StringRuleList{
ID: 0,
switch id := int(f.ID); {
case len(f.Data) != 0:
lists = append(lists, &filterlist.StringRuleList{
ID: id,
RulesText: string(f.Data),
IgnoreCosmetic: true,
}
} else if !fileExists(f.FilePath) {
list = &filterlist.StringRuleList{
ID: int(f.ID),
IgnoreCosmetic: true,
}
} else if runtime.GOOS == "windows" {
// On Windows we don't pass a file to urlfilter because
// it's difficult to update this file while it's being
// used.
data, err := os.ReadFile(f.FilePath)
if err != nil {
return nil, nil, fmt.Errorf("reading filter content: %w", err)
})
case f.FilePath == "":
continue
case runtime.GOOS == "windows":
// On Windows we don't pass a file to urlfilter because it's
// difficult to update this file while it's being used.
var data []byte
data, err = os.ReadFile(f.FilePath)
if errors.Is(err, fs.ErrNotExist) {
continue
} else if err != nil {
return nil, fmt.Errorf("reading filter content: %w", err)
}
list = &filterlist.StringRuleList{
ID: int(f.ID),
lists = append(lists, &filterlist.StringRuleList{
ID: id,
RulesText: string(data),
IgnoreCosmetic: true,
})
default:
var list *filterlist.FileRuleList
list, err = filterlist.NewFileRuleList(id, f.FilePath, true)
if errors.Is(err, fs.ErrNotExist) {
continue
} else if err != nil {
return nil, fmt.Errorf("creating file rule list with %q: %w", f.FilePath, err)
}
} else {
var err error
list, err = filterlist.NewFileRuleList(int(f.ID), f.FilePath, true)
if err != nil {
return nil, nil, fmt.Errorf("filterlist.NewFileRuleList(): %s: %w", f.FilePath, err)
}
lists = append(lists, list)
}
listArray = append(listArray, list)
}
rulesStorage, err := filterlist.NewRuleStorage(listArray)
rs, err = filterlist.NewRuleStorage(lists)
if err != nil {
return nil, nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err)
return nil, fmt.Errorf("creating rule storage: %w", err)
}
filteringEngine := urlfilter.NewDNSEngine(rulesStorage)
return rulesStorage, filteringEngine, nil
return rs, nil
}
// Initialize urlfilter objects.
func (d *DNSFilter) initFiltering(allowFilters, blockFilters []Filter) error {
rulesStorage, filteringEngine, err := createFilteringEngine(blockFilters)
if err != nil {
return err
}
rulesStorageAllow, filteringEngineAllow, err := createFilteringEngine(allowFilters)
rulesStorage, err := newRuleStorage(blockFilters)
if err != nil {
return err
}
d.engineLock.Lock()
d.reset()
d.rulesStorage = rulesStorage
d.filteringEngine = filteringEngine
d.rulesStorageAllow = rulesStorageAllow
d.filteringEngineAllow = filteringEngineAllow
d.engineLock.Unlock()
rulesStorageAllow, err := newRuleStorage(allowFilters)
if err != nil {
return err
}
// Make sure that the OS reclaims memory as soon as possible
filteringEngine := urlfilter.NewDNSEngine(rulesStorage)
filteringEngineAllow := urlfilter.NewDNSEngine(rulesStorageAllow)
func() {
d.engineLock.Lock()
defer d.engineLock.Unlock()
d.reset()
d.rulesStorage = rulesStorage
d.filteringEngine = filteringEngine
d.rulesStorageAllow = rulesStorageAllow
d.filteringEngineAllow = filteringEngineAllow
}()
// Make sure that the OS reclaims memory as soon as possible.
debug.FreeOSMemory()
log.Debug("initialized filtering engine")
@@ -677,11 +720,10 @@ func hostRulesToRules(netRules []*rules.HostRule) (res []rules.Rule) {
return res
}
// matchHostProcessAllowList processes the allowlist logic of host
// matching.
// matchHostProcessAllowList processes the allowlist logic of host matching.
func (d *DNSFilter) matchHostProcessAllowList(
host string,
dnsres urlfilter.DNSResult,
dnsres *urlfilter.DNSResult,
) (res Result, err error) {
var matchedRules []rules.Rule
if dnsres.NetworkRule != nil {
@@ -704,7 +746,7 @@ func (d *DNSFilter) matchHostProcessAllowList(
// matchHostProcessDNSResult processes the matched DNS filtering result.
func (d *DNSFilter) matchHostProcessDNSResult(
qtype uint16,
dnsres urlfilter.DNSResult,
dnsres *urlfilter.DNSResult,
) (res Result) {
if dnsres.NetworkRule != nil {
reason := FilteredBlockList
@@ -734,8 +776,8 @@ func (d *DNSFilter) matchHostProcessDNSResult(
}
if dnsres.HostRulesV4 != nil || dnsres.HostRulesV6 != nil {
// Question type doesn't match the host rules. Return the first
// matched host rule, but without an IP address.
// Question type doesn't match the host rules. Return the first matched
// host rule, but without an IP address.
var matchedRules []rules.Rule
if dnsres.HostRulesV4 != nil {
matchedRules = []rules.Rule{dnsres.HostRulesV4[0]}
@@ -749,32 +791,34 @@ func (d *DNSFilter) matchHostProcessDNSResult(
return Result{}
}
// matchHost is a low-level way to check only if hostname is filtered by rules,
// matchHost is a low-level way to check only if host is filtered by rules,
// skipping expensive safebrowsing and parental lookups.
func (d *DNSFilter) matchHost(
host string,
qtype uint16,
rrtype uint16,
setts *Settings,
) (res Result, err error) {
if !setts.FilteringEnabled {
return Result{}, nil
}
d.engineLock.RLock()
// Keep in mind that this lock must be held no just when calling Match()
// but also while using the rules returned by it.
defer d.engineLock.RUnlock()
ureq := urlfilter.DNSRequest{
Hostname: host,
SortedClientTags: setts.ClientTags,
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
ClientIP: setts.ClientIP.String(),
ClientName: setts.ClientName,
DNSType: qtype,
DNSType: rrtype,
}
if d.filteringEngineAllow != nil {
d.engineLock.RLock()
// Keep in mind that this lock must be held no just when calling Match() but
// also while using the rules returned by it.
//
// TODO(e.burkov): Inspect if the above is true.
defer d.engineLock.RUnlock()
if setts.ProtectionEnabled && d.filteringEngineAllow != nil {
dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq)
if ok {
return d.matchHostProcessAllowList(host, dnsres)
@@ -786,13 +830,12 @@ func (d *DNSFilter) matchHost(
}
dnsres, ok := d.filteringEngine.MatchRequest(ureq)
// Check DNS rewrites first, because the API there is a bit awkward.
if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 {
res = d.processDNSRewrites(dnsr)
if res.Reason == RewrittenRule && res.CanonName == host {
// A rewrite of a host to itself. Go on and try
// matching other things.
// A rewrite of a host to itself. Go on and try matching other
// things.
} else {
return res, nil
}
@@ -800,7 +843,12 @@ func (d *DNSFilter) matchHost(
return Result{}, nil
}
res = d.matchHostProcessDNSResult(qtype, dnsres)
if !setts.ProtectionEnabled {
// Don't check non-dnsrewrite filtering results.
return Result{}, nil
}
res = d.matchHostProcessDNSResult(rrtype, dnsres)
for _, r := range res.Rules {
log.Debug(
"filtering: found rule %q for host %q, filter list id: %d",
@@ -836,40 +884,33 @@ func InitModule() {
}
// New creates properly initialized DNS Filter that is ready to be used.
func New(c *Config, blockFilters []Filter) *DNSFilter {
var resolver Resolver = net.DefaultResolver
func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
d = &DNSFilter{
resolver: net.DefaultResolver,
}
if c != nil {
cacheConf := cache.Config{
d.safebrowsingCache = cache.New(cache.Config{
EnableLRU: true,
}
if gctx.safebrowsingCache == nil {
cacheConf.MaxSize = c.SafeBrowsingCacheSize
gctx.safebrowsingCache = cache.New(cacheConf)
}
if gctx.safeSearchCache == nil {
cacheConf.MaxSize = c.SafeSearchCacheSize
gctx.safeSearchCache = cache.New(cacheConf)
}
if gctx.parentalCache == nil {
cacheConf.MaxSize = c.ParentalCacheSize
gctx.parentalCache = cache.New(cacheConf)
}
MaxSize: c.SafeBrowsingCacheSize,
})
d.safeSearchCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.SafeSearchCacheSize,
})
d.parentalCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.ParentalCacheSize,
})
if c.CustomResolver != nil {
resolver = c.CustomResolver
d.resolver = c.CustomResolver
}
}
d := &DNSFilter{
resolver: resolver,
}
d.hostCheckers = []hostChecker{{
check: d.checkEtcHosts,
name: "etchosts",
check: d.matchSysHosts,
name: "hosts container",
}, {
check: d.matchHost,
name: "filtering",
@@ -890,12 +931,18 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
err := d.initSecurityServices()
if err != nil {
log.Error("filtering: initialize services: %s", err)
return nil
}
if c != nil {
d.Config = *c
d.prepareRewrites()
err = d.prepareRewrites()
if err != nil {
log.Error("rewrites: preparing: %s", err)
return nil
}
}
bsvcs := []string{}

View File

@@ -21,15 +21,17 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
var setts Settings
var setts = Settings{
ProtectionEnabled: true,
}
// Helpers.
func purgeCaches() {
func purgeCaches(d *DNSFilter) {
for _, c := range []cache.Cache{
gctx.safebrowsingCache,
gctx.parentalCache,
gctx.safeSearchCache,
d.safebrowsingCache,
d.parentalCache,
d.safeSearchCache,
} {
if c != nil {
c.Clear()
@@ -37,11 +39,11 @@ func purgeCaches() {
}
}
func newForTest(c *Config, filters []Filter) *DNSFilter {
func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
setts = Settings{
FilteringEnabled: true,
ProtectionEnabled: true,
FilteringEnabled: true,
}
setts.FilteringEnabled = true
if c != nil {
c.SafeBrowsingCacheSize = 10000
c.ParentalCacheSize = 10000
@@ -52,7 +54,8 @@ func newForTest(c *Config, filters []Filter) *DNSFilter {
setts.ParentalEnabled = c.ParentalEnabled
}
d := New(c, filters)
purgeCaches()
purgeCaches(d)
return d
}
@@ -103,7 +106,7 @@ func TestEtcHostsMatching(t *testing.T) {
filters := []Filter{{
ID: 0, Data: []byte(text),
}}
d := newForTest(nil, filters)
d := newForTest(t, nil, filters)
t.Cleanup(d.Close)
d.checkMatchIP(t, "google.com", addr, dns.TypeA)
@@ -168,7 +171,7 @@ func TestSafeBrowsing(t *testing.T) {
aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG)
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
@@ -191,7 +194,7 @@ func TestSafeBrowsing(t *testing.T) {
}
func TestParallelSB(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
@@ -215,7 +218,7 @@ func TestParallelSB(t *testing.T) {
// Safe Search.
func TestSafeSearch(t *testing.T) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
d := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
val, ok := d.SafeSearchDomain("www.google.com")
require.True(t, ok)
@@ -224,7 +227,9 @@ func TestSafeSearch(t *testing.T) {
}
func TestCheckHostSafeSearchYandex(t *testing.T) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
d := newForTest(t, &Config{
SafeSearchEnabled: true,
}, nil)
t.Cleanup(d.Close)
yandexIP := net.IPv4(213, 180, 193, 56)
@@ -247,13 +252,14 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
require.Len(t, res.Rules, 1)
assert.Equal(t, yandexIP, res.Rules[0].IP)
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
})
}
}
func TestCheckHostSafeSearchGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{}
d := newForTest(&Config{
d := newForTest(t, &Config{
SafeSearchEnabled: true,
CustomResolver: resolver,
}, nil)
@@ -280,12 +286,13 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
require.Len(t, res.Rules, 1)
assert.Equal(t, ip, res.Rules[0].IP)
assert.EqualValues(t, SafeSearchListID, res.Rules[0].FilterListID)
})
}
}
func TestSafeSearchCacheYandex(t *testing.T) {
d := newForTest(nil, nil)
d := newForTest(t, nil, nil)
t.Cleanup(d.Close)
const domain = "yandex.ru"
@@ -299,7 +306,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
yandexIP := net.IPv4(213, 180, 193, 56)
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
res, err = d.CheckHost(domain, dns.TypeA, &setts)
@@ -310,7 +317,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
assert.Equal(t, res.Rules[0].IP, yandexIP)
// Check cache.
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
@@ -319,7 +326,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
func TestSafeSearchCacheGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{}
d := newForTest(&Config{
d := newForTest(t, &Config{
CustomResolver: resolver,
}, nil)
t.Cleanup(d.Close)
@@ -332,7 +339,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
require.Empty(t, res.Rules)
d = newForTest(&Config{SafeSearchEnabled: true}, nil)
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
d.resolver = resolver
@@ -359,7 +366,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
assert.True(t, res.Rules[0].IP.Equal(ip))
// Check cache.
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, domain)
cachedValue, isFound := getCachedResult(d.safeSearchCache, domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
@@ -373,7 +380,7 @@ func TestParentalControl(t *testing.T) {
aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG)
d := newForTest(&Config{ParentalEnabled: true}, nil)
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "pornhub.com"
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
@@ -677,7 +684,7 @@ func TestMatching(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
d := newForTest(nil, filters)
d := newForTest(t, nil, filters)
t.Cleanup(d.Close)
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
@@ -703,7 +710,7 @@ func TestWhitelist(t *testing.T) {
whiteFilters := []Filter{{
ID: 0, Data: []byte(whiteRules),
}}
d := newForTest(nil, filters)
d := newForTest(t, nil, filters)
err := d.SetFilters(filters, whiteFilters, false)
require.NoError(t, err)
@@ -748,7 +755,7 @@ func applyClientSettings(setts *Settings) {
}
func TestClientSettings(t *testing.T) {
d := newForTest(
d := newForTest(t,
&Config{
ParentalEnabled: true,
SafeBrowsingEnabled: false,
@@ -797,7 +804,11 @@ func TestClientSettings(t *testing.T) {
makeTester := func(tc testCase, before bool) func(t *testing.T) {
return func(t *testing.T) {
r, _ := d.CheckHost(tc.host, dns.TypeA, &setts)
t.Helper()
r, err := d.CheckHost(tc.host, dns.TypeA, &setts)
require.NoError(t, err)
if before {
assert.True(t, r.IsFiltered)
assert.Equal(t, tc.wantReason, r.Reason)
@@ -808,7 +819,7 @@ func TestClientSettings(t *testing.T) {
}
// Check behaviour without any per-client settings, then apply per-client
// settings and check behaviour once again.
// settings and check behavior once again.
for _, tc := range testCases {
t.Run(tc.name, makeTester(tc, tc.before))
}
@@ -823,7 +834,7 @@ func TestClientSettings(t *testing.T) {
// Benchmarks.
func BenchmarkSafeBrowsing(b *testing.B) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
@@ -839,7 +850,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
}
func BenchmarkSafeBrowsingParallel(b *testing.B) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
@@ -857,7 +868,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
}
func BenchmarkSafeSearch(b *testing.B) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
b.Cleanup(d.Close)
for n := 0; n < b.N; n++ {
val, ok := d.SafeSearchDomain("www.google.com")
@@ -868,7 +879,7 @@ func BenchmarkSafeSearch(b *testing.B) {
}
func BenchmarkSafeSearchParallel(b *testing.B) {
d := newForTest(&Config{SafeSearchEnabled: true}, nil)
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
b.Cleanup(d.Close)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {

View File

@@ -4,93 +4,121 @@ package filtering
import (
"encoding/json"
"fmt"
"net"
"net/http"
"sort"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// RewriteEntry is a rewrite array element
type RewriteEntry struct {
// Domain is the domain for which this rewrite should work.
// LegacyRewrite is a single legacy DNS rewrite record.
//
// Instances of *LegacyRewrite must never be nil.
type LegacyRewrite struct {
// Domain is the domain pattern for which this rewrite should work.
Domain string `yaml:"domain"`
// Answer is the IP address, canonical name, or one of the special
// values: "A" or "AAAA".
Answer string `yaml:"answer"`
// IP is the IP address that should be used in the response if Type is
// A or AAAA.
// dns.TypeA or dns.TypeAAAA.
IP net.IP `yaml:"-"`
// Type is the DNS record type: A, AAAA, or CNAME.
Type uint16 `yaml:"-"`
}
// equal returns true if the entry is considered equal to the other.
func (e *RewriteEntry) equal(other RewriteEntry) (ok bool) {
return e.Domain == other.Domain && e.Answer == other.Answer
// clone returns a deep clone of rw.
func (rw *LegacyRewrite) clone() (cloneRW *LegacyRewrite) {
return &LegacyRewrite{
Domain: rw.Domain,
Answer: rw.Answer,
IP: netutil.CloneIP(rw.IP),
Type: rw.Type,
}
}
// matchesQType returns true if the entry matched qtype.
func (e *RewriteEntry) matchesQType(qtype uint16) (ok bool) {
// equal returns true if the rw is equal to the other.
func (rw *LegacyRewrite) equal(other *LegacyRewrite) (ok bool) {
return rw.Domain == other.Domain && rw.Answer == other.Answer
}
// matchesQType returns true if the entry matches the question type qt.
func (rw *LegacyRewrite) matchesQType(qt uint16) (ok bool) {
// Add CNAMEs, since they match for all types requests.
if e.Type == dns.TypeCNAME {
if rw.Type == dns.TypeCNAME {
return true
}
// Reject types other than A and AAAA.
if qtype != dns.TypeA && qtype != dns.TypeAAAA {
if qt != dns.TypeA && qt != dns.TypeAAAA {
return false
}
// If the types match or the entry is set to allow only the other type,
// include them.
return e.Type == qtype || e.IP == nil
return rw.Type == qt || rw.IP == nil
}
// normalize makes sure that the a new or decoded entry is normalized with
// regards to domain name case, IP length, and so on.
func (e *RewriteEntry) normalize() {
// TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix
// and use it in matchDomainWildcard instead of using strings.ToLower
//
// If rw is nil, it returns an errors.
func (rw *LegacyRewrite) normalize() (err error) {
if rw == nil {
return errors.Error("nil rewrite entry")
}
// TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix and
// use it in matchDomainWildcard instead of using strings.ToLower
// everywhere.
e.Domain = strings.ToLower(e.Domain)
rw.Domain = strings.ToLower(rw.Domain)
switch e.Answer {
switch rw.Answer {
case "AAAA":
e.IP = nil
e.Type = dns.TypeAAAA
rw.IP = nil
rw.Type = dns.TypeAAAA
return
return nil
case "A":
e.IP = nil
e.Type = dns.TypeA
rw.IP = nil
rw.Type = dns.TypeA
return
return nil
default:
// Go on.
}
ip := net.ParseIP(e.Answer)
ip := net.ParseIP(rw.Answer)
if ip == nil {
e.Type = dns.TypeCNAME
rw.Type = dns.TypeCNAME
return
return nil
}
ip4 := ip.To4()
if ip4 != nil {
e.IP = ip4
e.Type = dns.TypeA
rw.IP = ip4
rw.Type = dns.TypeA
} else {
e.IP = ip
e.Type = dns.TypeAAAA
rw.IP = ip
rw.Type = dns.TypeAAAA
}
return nil
}
func isWildcard(host string) bool {
return len(host) > 1 && host[0] == '*' && host[1] == '.'
// isWildcard returns true if pat is a wildcard domain pattern.
func isWildcard(pat string) bool {
return len(pat) > 1 && pat[0] == '*' && pat[1] == '.'
}
// matchDomainWildcard returns true if host matches the wildcard pattern.
@@ -106,16 +134,16 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
// wildcard > exact
// lower level wildcard > higher level wildcard
//
type rewritesSorted []RewriteEntry
type rewritesSorted []*LegacyRewrite
// Len implements the sort.Interface interface for legacyRewritesSorted.
func (a rewritesSorted) Len() int { return len(a) }
func (a rewritesSorted) Len() (l int) { return len(a) }
// Swap implements the sort.Interface interface for legacyRewritesSorted.
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// Less implements the sort.Interface interface for legacyRewritesSorted.
func (a rewritesSorted) Less(i, j int) bool {
func (a rewritesSorted) Less(i, j int) (less bool) {
if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME {
return true
} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME {
@@ -132,49 +160,62 @@ func (a rewritesSorted) Less(i, j int) bool {
}
}
// both are wildcards
// Both are wildcards.
return len(a[i].Domain) > len(a[j].Domain)
}
func (d *DNSFilter) prepareRewrites() {
for i := range d.Rewrites {
d.Rewrites[i].normalize()
// prepareRewrites normalizes and validates all legacy DNS rewrites.
func (d *DNSFilter) prepareRewrites() (err error) {
for i, r := range d.Rewrites {
err = r.normalize()
if err != nil {
return fmt.Errorf("at index %d: %w", i, err)
}
}
return nil
}
// findRewrites returns the list of matched rewrite entries. The priority is:
// CNAME, then A and AAAA; exact, then wildcard. If the host is matched
// exactly, wildcard entries aren't returned. If the host matched by wildcards,
// return the most specific for the question type.
func findRewrites(entries []RewriteEntry, host string, qtype uint16) (matched []RewriteEntry) {
rr := rewritesSorted{}
// findRewrites returns the list of matched rewrite entries. If rewrites are
// empty, but matched is true, the domain is found among the rewrite rules but
// not for this question type.
//
// The result priority is: CNAME, then A and AAAA; exact, then wildcard. If the
// host is matched exactly, wildcard entries aren't returned. If the host
// matched by wildcards, return the most specific for the question type.
func findRewrites(
entries []*LegacyRewrite,
host string,
qtype uint16,
) (rewrites []*LegacyRewrite, matched bool) {
for _, e := range entries {
if e.Domain != host && !matchDomainWildcard(host, e.Domain) {
continue
}
matched = true
if e.matchesQType(qtype) {
rr = append(rr, e)
rewrites = append(rewrites, e)
}
}
if len(rr) == 0 {
return nil
if len(rewrites) == 0 {
return nil, matched
}
sort.Sort(rr)
sort.Sort(rewritesSorted(rewrites))
for i, r := range rr {
for i, r := range rewrites {
if isWildcard(r.Domain) {
// Don't use rr[:0], because we need to return at least
// one item here.
rr = rr[:max(1, i)]
// Don't use rewrites[:0], because we need to return at least one
// item here.
rewrites = rewrites[:max(1, i)]
break
}
}
return rr
return rewrites, matched
}
func max(a, b int) int {
@@ -206,29 +247,39 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(arr)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
return
}
}
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
jsent := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&jsent)
rwJSON := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&rwJSON)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ent := RewriteEntry{
Domain: jsent.Domain,
Answer: jsent.Answer,
rw := &LegacyRewrite{
Domain: rwJSON.Domain,
Answer: rwJSON.Answer,
}
ent.normalize()
err = rw.normalize()
if err != nil {
// Shouldn't happen currently, since normalize only returns a non-nil
// error when a rewrite is nil, but be change-proof.
aghhttp.Error(r, w, http.StatusBadRequest, "normalizing: %s", err)
return
}
d.confLock.Lock()
d.Config.Rewrites = append(d.Config.Rewrites, ent)
d.Config.Rewrites = append(d.Config.Rewrites, rw)
d.confLock.Unlock()
log.Debug("Rewrites: added element: %s -> %s [%d]",
ent.Domain, ent.Answer, len(d.Config.Rewrites))
log.Debug("rewrite: added element: %s -> %s [%d]", rw.Domain, rw.Answer, len(d.Config.Rewrites))
d.Config.ConfigModified()
}
@@ -237,21 +288,25 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
jsent := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&jsent)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
entDel := RewriteEntry{
entDel := &LegacyRewrite{
Domain: jsent.Domain,
Answer: jsent.Answer,
}
arr := []RewriteEntry{}
arr := []*LegacyRewrite{}
d.confLock.Lock()
for _, ent := range d.Config.Rewrites {
if ent.equal(entDel) {
log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer)
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
continue
}
arr = append(arr, ent)
}
d.Config.Rewrites = arr

View File

@@ -12,10 +12,10 @@ import (
// TODO(e.burkov): All the tests in this file may and should me merged together.
func TestRewrites(t *testing.T) {
d := newForTest(nil, nil)
d := newForTest(t, nil, nil)
t.Cleanup(d.Close)
d.Rewrites = []RewriteEntry{{
d.Rewrites = []*LegacyRewrite{{
// This one and below are about CNAME, A and AAAA.
Domain: "somecname",
Answer: "somehost.com",
@@ -66,107 +66,132 @@ func TestRewrites(t *testing.T) {
}, {
Domain: "BIGHOST.COM",
Answer: "1.2.3.7",
}, {
Domain: "*.issue4016.com",
Answer: "sub.issue4016.com",
}}
d.prepareRewrites()
require.NoError(t, d.prepareRewrites())
testCases := []struct {
name string
host string
wantCName string
wantVals []net.IP
dtyp uint16
name string
host string
wantCName string
wantIPs []net.IP
wantReason Reason
dtyp uint16
}{{
name: "not_filtered_not_found",
host: "hoost.com",
wantCName: "",
wantVals: nil,
dtyp: dns.TypeA,
name: "not_filtered_not_found",
host: "hoost.com",
wantCName: "",
wantIPs: nil,
wantReason: NotFilteredNotFound,
dtyp: dns.TypeA,
}, {
name: "rewritten_a",
host: "www.host.com",
wantCName: "host.com",
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "rewritten_a",
host: "www.host.com",
wantCName: "host.com",
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "rewritten_aaaa",
host: "www.host.com",
wantCName: "host.com",
wantVals: []net.IP{net.ParseIP("1:2:3::4")},
dtyp: dns.TypeAAAA,
name: "rewritten_aaaa",
host: "www.host.com",
wantCName: "host.com",
wantIPs: []net.IP{net.ParseIP("1:2:3::4")},
wantReason: Rewritten,
dtyp: dns.TypeAAAA,
}, {
name: "wildcard_match",
host: "abc.host.com",
wantCName: "",
wantVals: []net.IP{{1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "wildcard_match",
host: "abc.host.com",
wantCName: "",
wantIPs: []net.IP{{1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "wildcard_override",
host: "a.host.com",
wantCName: "",
wantVals: []net.IP{{1, 2, 3, 4}},
dtyp: dns.TypeA,
name: "wildcard_override",
host: "a.host.com",
wantCName: "",
wantIPs: []net.IP{{1, 2, 3, 4}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "wildcard_cname_interaction",
host: "www.host2.com",
wantCName: "host.com",
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "wildcard_cname_interaction",
host: "www.host2.com",
wantCName: "host.com",
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "two_cnames",
host: "b.host.com",
wantCName: "somehost.com",
wantVals: []net.IP{{0, 0, 0, 0}},
dtyp: dns.TypeA,
name: "two_cnames",
host: "b.host.com",
wantCName: "somehost.com",
wantIPs: []net.IP{{0, 0, 0, 0}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "two_cnames_and_wildcard",
host: "b.host3.com",
wantCName: "x.host.com",
wantVals: []net.IP{{1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "two_cnames_and_wildcard",
host: "b.host3.com",
wantCName: "x.host.com",
wantIPs: []net.IP{{1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "issue3343",
host: "www.hostboth.com",
wantCName: "",
wantVals: []net.IP{net.ParseIP("1234::5678")},
dtyp: dns.TypeAAAA,
name: "issue3343",
host: "www.hostboth.com",
wantCName: "",
wantIPs: []net.IP{net.ParseIP("1234::5678")},
wantReason: Rewritten,
dtyp: dns.TypeAAAA,
}, {
name: "issue3351",
host: "bighost.com",
wantCName: "",
wantVals: []net.IP{{1, 2, 3, 7}},
dtyp: dns.TypeA,
name: "issue3351",
host: "bighost.com",
wantCName: "",
wantIPs: []net.IP{{1, 2, 3, 7}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "issue4008",
host: "somehost.com",
wantCName: "",
wantIPs: nil,
wantReason: Rewritten,
dtyp: dns.TypeHTTPS,
}, {
name: "issue4016",
host: "www.issue4016.com",
wantCName: "sub.issue4016.com",
wantIPs: nil,
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "issue4016_self",
host: "sub.issue4016.com",
wantCName: "",
wantIPs: nil,
wantReason: NotFilteredNotFound,
dtyp: dns.TypeA,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
valsNum := len(tc.wantVals)
r := d.processRewrites(tc.host, tc.dtyp)
if valsNum == 0 {
assert.Equal(t, NotFilteredNotFound, r.Reason)
return
}
require.Equalf(t, Rewritten, r.Reason, "got %s", r.Reason)
require.Equalf(t, tc.wantReason, r.Reason, "got %s", r.Reason)
if tc.wantCName != "" {
assert.Equal(t, tc.wantCName, r.CanonName)
}
require.Len(t, r.IPList, valsNum)
for i, ip := range tc.wantVals {
assert.Equal(t, ip, r.IPList[i])
}
assert.Equal(t, tc.wantIPs, r.IPList)
})
}
}
func TestRewritesLevels(t *testing.T) {
d := newForTest(nil, nil)
d := newForTest(t, nil, nil)
t.Cleanup(d.Close)
// Exact host, wildcard L2, wildcard L3.
d.Rewrites = []RewriteEntry{{
d.Rewrites = []*LegacyRewrite{{
Domain: "host.com",
Answer: "1.1.1.1",
Type: dns.TypeA,
@@ -179,7 +204,8 @@ func TestRewritesLevels(t *testing.T) {
Answer: "3.3.3.3",
Type: dns.TypeA,
}}
d.prepareRewrites()
require.NoError(t, d.prepareRewrites())
testCases := []struct {
name string
@@ -209,10 +235,10 @@ func TestRewritesLevels(t *testing.T) {
}
func TestRewritesExceptionCNAME(t *testing.T) {
d := newForTest(nil, nil)
d := newForTest(t, nil, nil)
t.Cleanup(d.Close)
// Wildcard and exception for a sub-domain.
d.Rewrites = []RewriteEntry{{
d.Rewrites = []*LegacyRewrite{{
Domain: "*.host.com",
Answer: "2.2.2.2",
}, {
@@ -222,29 +248,32 @@ func TestRewritesExceptionCNAME(t *testing.T) {
Domain: "*.sub.host.com",
Answer: "*.sub.host.com",
}}
d.prepareRewrites()
require.NoError(t, d.prepareRewrites())
testCases := []struct {
name string
host string
want net.IP
}{{
name: "match_sub-domain",
name: "match_subdomain",
host: "my.host.com",
want: net.IP{2, 2, 2, 2},
}, {
name: "exception_cname",
host: "sub.host.com",
want: nil,
}, {
name: "exception_wildcard",
host: "my.sub.host.com",
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := d.processRewrites(tc.host, dns.TypeA)
if tc.want == nil {
assert.Equal(t, NotFilteredNotFound, r.Reason)
assert.Equal(t, NotFilteredNotFound, r.Reason, "got %s", r.Reason)
return
}
@@ -257,10 +286,10 @@ func TestRewritesExceptionCNAME(t *testing.T) {
}
func TestRewritesExceptionIP(t *testing.T) {
d := newForTest(nil, nil)
d := newForTest(t, nil, nil)
t.Cleanup(d.Close)
// Exception for AAAA record.
d.Rewrites = []RewriteEntry{{
d.Rewrites = []*LegacyRewrite{{
Domain: "host.com",
Answer: "1.2.3.4",
Type: dns.TypeA,
@@ -281,7 +310,8 @@ func TestRewritesExceptionIP(t *testing.T) {
Answer: "A",
Type: dns.TypeA,
}}
d.prepareRewrites()
require.NoError(t, d.prepareRewrites())
testCases := []struct {
name string

View File

@@ -13,6 +13,7 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
@@ -306,7 +307,7 @@ func (d *DNSFilter) checkSafeBrowsing(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.SafeBrowsingEnabled {
if !setts.ProtectionEnabled || !setts.SafeBrowsingEnabled {
return Result{}, nil
}
@@ -318,7 +319,7 @@ func (d *DNSFilter) checkSafeBrowsing(
sctx := &sbCtx{
host: host,
svc: "SafeBrowsing",
cache: gctx.safebrowsingCache,
cache: d.safebrowsingCache,
cacheTime: d.Config.CacheTime,
}
@@ -326,7 +327,8 @@ func (d *DNSFilter) checkSafeBrowsing(
IsFiltered: true,
Reason: FilteredSafeBrowsing,
Rules: []*ResultRule{{
Text: "adguard-malware-shavar",
Text: "adguard-malware-shavar",
FilterListID: SafeBrowsingListID,
}},
}
@@ -339,7 +341,7 @@ func (d *DNSFilter) checkParental(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ParentalEnabled {
if !setts.ProtectionEnabled || !setts.ParentalEnabled {
return Result{}, nil
}
@@ -351,7 +353,7 @@ func (d *DNSFilter) checkParental(
sctx := &sbCtx{
host: host,
svc: "Parental",
cache: gctx.parentalCache,
cache: d.parentalCache,
cacheTime: d.Config.CacheTime,
}
@@ -359,19 +361,14 @@ func (d *DNSFilter) checkParental(
IsFiltered: true,
Reason: FilteredParental,
Rules: []*ResultRule{{
Text: "parental CATEGORY_BLACKLISTED",
Text: "parental CATEGORY_BLACKLISTED",
FilterListID: ParentalListID,
}},
}
return check(sctx, res, d.parentalUpstream)
}
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
d.Config.SafeBrowsingEnabled = true
d.Config.ConfigModified()
@@ -390,7 +387,8 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ
Enabled: d.Config.SafeBrowsingEnabled,
})
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
@@ -413,8 +411,7 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
Enabled: d.Config.ParentalEnabled,
})
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
}
}

View File

@@ -108,7 +108,7 @@ func TestSafeBrowsingCache(t *testing.T) {
}
func TestSBPC_checkErrorUpstream(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
ups := &aghtest.TestErrUpstream{}
@@ -117,6 +117,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
d.SetParentalUpstream(ups)
setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}
@@ -129,41 +130,42 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
}
func TestSBPC(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const hostname = "example.org"
setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}
testCases := []struct {
testCache cache.Cache
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
name string
block bool
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
testCache cache.Cache
}{{
testCache: d.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_no_block",
block: false,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, {
testCache: d.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_block",
block: true,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, {
testCache: d.parentalCache,
testFunc: d.checkParental,
name: "pc_no_block",
block: false,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}, {
testCache: d.parentalCache,
testFunc: d.checkParental,
name: "pc_block",
block: true,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}}
for _, tc := range testCases {
@@ -215,6 +217,6 @@ func TestSBPC(t *testing.T) {
assert.Equal(t, 1, ups.RequestsCount())
})
purgeCaches()
purgeCaches(d)
}
}

View File

@@ -11,6 +11,7 @@ import (
"net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
)
@@ -74,7 +75,7 @@ func (d *DNSFilter) checkSafeSearch(
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.SafeSearchEnabled {
if !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
return Result{}, nil
}
@@ -84,7 +85,7 @@ func (d *DNSFilter) checkSafeSearch(
}
// Check cache. Return cached result if it was found
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host)
cachedValue, isFound := getCachedResult(d.safeSearchCache, host)
if isFound {
// atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
log.Tracef("SafeSearch: found in cache: %s", host)
@@ -99,12 +100,14 @@ func (d *DNSFilter) checkSafeSearch(
res = Result{
IsFiltered: true,
Reason: FilteredSafeSearch,
Rules: []*ResultRule{{}},
Rules: []*ResultRule{{
FilterListID: SafeSearchListID,
}},
}
if ip := net.ParseIP(safeHost); ip != nil {
res.Rules[0].IP = ip
valLen := d.setCacheResult(gctx.safeSearchCache, host, res)
valLen := d.setCacheResult(d.safeSearchCache, host, res)
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen)
return res, nil
@@ -123,7 +126,7 @@ func (d *DNSFilter) checkSafeSearch(
res.Rules[0].IP = ip
l := d.setCacheResult(gctx.safeSearchCache, host, res)
l := d.setCacheResult(d.safeSearchCache, host, res)
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, l)
return res, nil
@@ -150,8 +153,13 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques
Enabled: d.Config.SafeSearchEnabled,
})
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to write response json: %s",
err,
)
}
}

View File

@@ -13,6 +13,7 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
@@ -403,8 +404,8 @@ func realIP(r *http.Request) (ip net.IP, err error) {
return ip, nil
}
// When everything else fails, just return the remote address as
// understood by the stdlib.
// When everything else fails, just return the remote address as understood
// by the stdlib.
ipStr, err := netutil.SplitHost(r.RemoteAddr)
if err != nil {
return nil, fmt.Errorf("getting ip from client addr: %w", err)
@@ -417,19 +418,20 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
req := loginJSON{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
var remoteAddr string
// The realIP couldn't be used here due to security issues.
// realIP cannot be used here without taking TrustedProxies into account due
// to security issues.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
//
// TODO(e.burkov): Use realIP when the issue will be fixed.
if remoteAddr, err = netutil.SplitHost(r.RemoteAddr); err != nil {
httpError(w, http.StatusBadRequest, "auth: getting remote address: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "auth: getting remote address: %s", err)
return
}
@@ -437,12 +439,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
if blocker := Context.auth.blocker; blocker != nil {
if left := blocker.check(remoteAddr); left > 0 {
w.Header().Set("Retry-After", strconv.Itoa(int(left.Seconds())))
httpError(
w,
http.StatusTooManyRequests,
"auth: blocked for %s",
left,
)
aghhttp.Error(r, w, http.StatusTooManyRequests, "auth: blocked for %s", left)
return
}
@@ -451,22 +448,23 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
var cookie string
cookie, err = Context.auth.httpCookie(req, remoteAddr)
if err != nil {
httpError(w, http.StatusBadRequest, "crypto rand reader: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "crypto rand reader: %s", err)
return
}
// Use realIP here, since this IP address is only used for logging.
ip, err := realIP(r)
if err != nil {
log.Error("auth: getting real ip from request: %s", err)
} else if ip == nil {
// Technically shouldn't happen.
log.Error("auth: unknown ip")
}
if len(cookie) == 0 {
var ip net.IP
ip, err = realIP(r)
if err != nil {
log.Info("auth: getting real ip from request: %s", err)
} else if ip == nil {
// Technically shouldn't happen.
log.Info("auth: failed to login user %q from unknown ip", req.Name)
} else {
log.Info("auth: failed to login user %q from ip %q", req.Name, ip)
}
log.Info("auth: failed to login user %q from ip %v", req.Name, ip)
time.Sleep(1 * time.Second)
http.Error(w, "invalid username or password", http.StatusBadRequest)
@@ -474,13 +472,15 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
return
}
w.Header().Set("Set-Cookie", cookie)
log.Info("auth: user %q successfully logged in from ip %v", req.Name, ip)
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, proxy-revalidate")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")
h := w.Header()
h.Set("Set-Cookie", cookie)
h.Set("Cache-Control", "no-store, no-cache, must-revalidate, proxy-revalidate")
h.Set("Pragma", "no-cache")
h.Set("Expires", "0")
returnOK(w)
aghhttp.OK(w)
}
func handleLogout(w http.ResponseWriter, r *http.Request) {
@@ -528,7 +528,7 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool)
cookie, err := r.Cookie(sessionCookieName)
if glProcessCookie(r) {
log.Debug("auth: authentification was handled by GL-Inet submodule")
log.Debug("auth: authentication was handled by GL-Inet submodule")
ok = true
} else if err == nil {
r := Context.auth.checkSession(cookie.Value)

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -24,19 +25,17 @@ func TestMain(m *testing.M) {
func TestNewSessionToken(t *testing.T) {
// Successful case.
token, err := newSessionToken()
require.Nil(t, err)
require.NoError(t, err)
assert.Len(t, token, sessionTokenSize)
// Break the rand.Reader.
prevReader := rand.Reader
t.Cleanup(func() {
rand.Reader = prevReader
})
t.Cleanup(func() { rand.Reader = prevReader })
rand.Reader = &bytes.Buffer{}
// Unsuccessful case.
token, err = newSessionToken()
require.NotNil(t, err)
require.Error(t, err)
assert.Empty(t, token)
}
@@ -58,7 +57,7 @@ func TestAuth(t *testing.T) {
a.RemoveSession("notfound")
sess, err := newSessionToken()
assert.Nil(t, err)
require.NoError(t, err)
sessStr := hex.EncodeToString(sess)
now := time.Now().UTC().Unix()
@@ -152,7 +151,7 @@ func TestAuthHTTP(t *testing.T) {
// perform login
cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}, "")
assert.Nil(t, err)
require.NoError(t, err)
assert.NotEmpty(t, cookie)
// get /
@@ -251,12 +250,7 @@ func TestRealIP(t *testing.T) {
ip, err := realIP(r)
assert.Equal(t, tc.wantIP, ip)
if tc.wantErrMsg == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}

View File

@@ -15,21 +15,19 @@ func TestAuthGL(t *testing.T) {
dir := t.TempDir()
GLMode = true
t.Cleanup(func() {
GLMode = false
})
t.Cleanup(func() { GLMode = false })
glFilePrefix = dir + "/gl_token_"
data := make([]byte, 4)
aghos.NativeEndian.PutUint32(data, 1)
require.Nil(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
require.NoError(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
assert.False(t, glCheckToken("test"))
data = make([]byte, 4)
aghos.NativeEndian.PutUint32(data, uint32(time.Now().UTC().Unix()+60))
require.Nil(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
require.NoError(t, os.WriteFile(glFilePrefix+"test", data, 0o644))
r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil)
r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"})
assert.True(t, glProcessCookie(r))

View File

@@ -95,7 +95,9 @@ type clientsContainer struct {
// dnsServer is used for checking clients IP status access list status
dnsServer *dnsforward.Server
etcHosts *aghnet.EtcHostsContainer // get entries from system hosts-files
// etcHosts contains list of rewrite rules taken from the operating system's
// hosts database.
etcHosts *aghnet.HostsContainer
testing bool // if TRUE, this object is used for internal tests
}
@@ -104,9 +106,9 @@ type clientsContainer struct {
// dhcpServer: optional
// Note: this function must be called only once
func (clients *clientsContainer) Init(
objects []clientObject,
objects []*clientObject,
dhcpServer *dhcpd.Server,
etcHosts *aghnet.EtcHostsContainer,
etcHosts *aghnet.HostsContainer,
) {
if clients.list != nil {
log.Fatal("clients.list != nil")
@@ -121,13 +123,22 @@ func (clients *clientsContainer) Init(
clients.etcHosts = etcHosts
clients.addFromConfig(objects)
if !clients.testing {
clients.updateFromDHCP(true)
if clients.dhcpServer != nil {
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
}
if clients.etcHosts != nil {
clients.etcHosts.SetOnChanged(clients.onHostsChanged)
if clients.testing {
return
}
clients.updateFromDHCP(true)
if clients.dhcpServer != nil {
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
}
go clients.handleHostsUpdates()
}
func (clients *clientsContainer) handleHostsUpdates() {
if clients.etcHosts != nil {
for upd := range clients.etcHosts.Upd() {
clients.addFromHostsFile(upd)
}
}
}
@@ -164,56 +175,65 @@ type clientObject struct {
UseGlobalBlockedServices bool `yaml:"use_global_blocked_services"`
}
func (clients *clientsContainer) tagKnown(tag string) (ok bool) {
return clients.allTags.Has(tag)
}
func (clients *clientsContainer) addFromConfig(objects []clientObject) {
for _, cy := range objects {
// addFromConfig initializes the clients container with objects from the
// configuration file.
func (clients *clientsContainer) addFromConfig(objects []*clientObject) {
for _, o := range objects {
cli := &Client{
Name: cy.Name,
IDs: cy.IDs,
UseOwnSettings: !cy.UseGlobalSettings,
FilteringEnabled: cy.FilteringEnabled,
ParentalEnabled: cy.ParentalEnabled,
SafeSearchEnabled: cy.SafeSearchEnabled,
SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
Name: o.Name,
UseOwnBlockedServices: !cy.UseGlobalBlockedServices,
IDs: o.IDs,
Upstreams: o.Upstreams,
Upstreams: cy.Upstreams,
UseOwnSettings: !o.UseGlobalSettings,
FilteringEnabled: o.FilteringEnabled,
ParentalEnabled: o.ParentalEnabled,
SafeSearchEnabled: o.SafeSearchEnabled,
SafeBrowsingEnabled: o.SafeBrowsingEnabled,
UseOwnBlockedServices: !o.UseGlobalBlockedServices,
}
for _, s := range cy.BlockedServices {
if !filtering.BlockedSvcKnown(s) {
log.Debug("clients: skipping unknown blocked-service %q", s)
continue
for _, s := range o.BlockedServices {
if filtering.BlockedSvcKnown(s) {
cli.BlockedServices = append(cli.BlockedServices, s)
} else {
log.Info("clients: skipping unknown blocked service %q", s)
}
cli.BlockedServices = append(cli.BlockedServices, s)
}
for _, t := range cy.Tags {
if !clients.tagKnown(t) {
log.Debug("clients: skipping unknown tag %q", t)
continue
for _, t := range o.Tags {
if clients.allTags.Has(t) {
cli.Tags = append(cli.Tags, t)
} else {
log.Info("clients: skipping unknown tag %q", t)
}
cli.Tags = append(cli.Tags, t)
}
sort.Strings(cli.Tags)
_, err := clients.Add(cli)
if err != nil {
log.Tracef("clientAdd: %s", err)
log.Error("clients: adding clients %s: %s", cli.Name, err)
}
}
}
// WriteDiskConfig - write configuration
func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
// forConfig returns all currently known persistent clients as objects for the
// configuration file.
func (clients *clientsContainer) forConfig() (objs []*clientObject) {
clients.lock.Lock()
defer clients.lock.Unlock()
objs = make([]*clientObject, 0, len(clients.list))
for _, cli := range clients.list {
cy := clientObject{
Name: cli.Name,
o := &clientObject{
Name: cli.Name,
Tags: stringutil.CloneSlice(cli.Tags),
IDs: stringutil.CloneSlice(cli.IDs),
BlockedServices: stringutil.CloneSlice(cli.BlockedServices),
Upstreams: stringutil.CloneSlice(cli.Upstreams),
UseGlobalSettings: !cli.UseOwnSettings,
FilteringEnabled: cli.FilteringEnabled,
ParentalEnabled: cli.ParentalEnabled,
@@ -222,14 +242,16 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
}
cy.Tags = stringutil.CloneSlice(cli.Tags)
cy.IDs = stringutil.CloneSlice(cli.IDs)
cy.BlockedServices = stringutil.CloneSlice(cli.BlockedServices)
cy.Upstreams = stringutil.CloneSlice(cli.Upstreams)
*objects = append(*objects, cy)
objs = append(objs, o)
}
clients.lock.Unlock()
// Maps aren't guaranteed to iterate in the same order each time, so the
// above loop can generate different orderings when writing to the config
// file: this produces lots of diffs in config files, so sort objects by
// name before writing.
sort.Slice(objs, func(i, j int) bool { return objs[i].Name < objs[j].Name })
return objs
}
func (clients *clientsContainer) periodicUpdate() {
@@ -250,10 +272,6 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
}
}
func (clients *clientsContainer) onHostsChanged() {
clients.addFromHostsFile()
}
// Exists checks if client with this IP address already exists.
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) {
clients.lock.Lock()
@@ -514,12 +532,12 @@ func (clients *clientsContainer) check(c *Client) (err error) {
} else if err = dnsforward.ValidateClientID(id); err == nil {
c.IDs[i] = id
} else {
return fmt.Errorf("invalid client id at index %d: %q", i, id)
return fmt.Errorf("invalid clientid at index %d: %q", i, id)
}
}
for _, t := range c.Tags {
if !clients.tagKnown(t) {
if !clients.allTags.Has(t) {
return fmt.Errorf("invalid tag: %q", t)
}
}
@@ -697,7 +715,7 @@ func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
}
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
// taken into account. ok is true if the pairing was added.
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
clients.lock.Lock()
@@ -757,13 +775,7 @@ func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
// addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files.
func (clients *clientsContainer) addFromHostsFile() {
if clients.etcHosts == nil {
return
}
hosts := clients.etcHosts.List()
func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -771,17 +783,20 @@ func (clients *clientsContainer) addFromHostsFile() {
n := 0
hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
names, ok := v.([]string)
hosts, ok := v.(*stringutil.Set)
if !ok {
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
return true
}
for _, name := range names {
ok = clients.addHostLocked(ip, name, ClientSourceHostsFile)
if ok {
hosts.Range(func(name string) (cont bool) {
if clients.addHostLocked(ip, name, ClientSourceHostsFile) {
n++
}
}
return true
})
return true
})

View File

@@ -3,10 +3,12 @@ package home
import (
"net"
"os"
"runtime"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -271,12 +273,18 @@ func TestClientsAddExisting(t *testing.T) {
})
t.Run("complicated", func(t *testing.T) {
// TODO(a.garipov): Properly decouple the DHCP server from the client
// storage.
if runtime.GOOS == "windows" {
t.Skip("skipping dhcp test on windows")
}
var err error
ip := net.IP{1, 2, 3, 4}
// First, init a DHCP server with a single static lease.
config := dhcpd.ServerConfig{
config := &dhcpd.ServerConfig{
Enabled: true,
DBFilePath: "leases.db",
Conf4: dhcpd.V4ServerConf{
@@ -290,8 +298,9 @@ func TestClientsAddExisting(t *testing.T) {
clients.dhcpServer, err = dhcpd.Create(config)
require.NoError(t, err)
t.Cleanup(func() { _ = os.Remove("leases.db") })
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return os.Remove("leases.db")
})
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
@@ -309,8 +318,7 @@ func TestClientsAddExisting(t *testing.T) {
require.NoError(t, err)
assert.True(t, ok)
// Add a new client with the IP from the first client's IP
// range.
// Add a new client with the IP from the first client's IP range.
ok, err = clients.Add(&Client{
IDs: []string{"2.2.2.2"},
Name: "client3",

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
)
@@ -58,7 +59,7 @@ type clientListJSON struct {
}
// respond with information about configured clients
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http.Request) {
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http.Request) {
data := clientListJSON{}
clients.lock.Lock()
@@ -106,7 +107,14 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data)
if e != nil {
httpError(w, http.StatusInternalServerError, "Failed to encode to json: %v", e)
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to encode to json: %v",
e,
)
return
}
}
@@ -154,7 +162,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
cj := clientJSON{}
err := json.NewDecoder(r.Body).Decode(&cj)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
return
}
@@ -162,11 +170,14 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
c := jsonToClient(cj)
ok, err := clients.Add(c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
if !ok {
httpError(w, http.StatusBadRequest, "Client already exists")
aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists")
return
}
@@ -178,19 +189,19 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
cj := clientJSON{}
err := json.NewDecoder(r.Body).Decode(&cj)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
return
}
if len(cj.Name) == 0 {
httpError(w, http.StatusBadRequest, "client's name must be non-empty")
aghhttp.Error(r, w, http.StatusBadRequest, "client's name must be non-empty")
return
}
if !clients.Del(cj.Name) {
httpError(w, http.StatusBadRequest, "Client not found")
aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
return
}
@@ -207,20 +218,22 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
dj := updateJSON{}
err := json.NewDecoder(r.Body).Decode(&dj)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
return
}
if len(dj.Name) == 0 {
httpError(w, http.StatusBadRequest, "Invalid request")
aghhttp.Error(r, w, http.StatusBadRequest, "Invalid request")
return
}
c := jsonToClient(dj.Data)
err = clients.Update(dj.Name, c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
@@ -256,7 +269,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write response: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write response: %s", err)
}
}

View File

@@ -7,6 +7,7 @@ import (
"path/filepath"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -82,10 +83,12 @@ type configuration struct {
WhitelistFilters []filter `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
DHCP dhcpd.ServerConfig `yaml:"dhcp"`
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
// Note: this array is filled only before file read/write and then it's cleared
Clients []clientObject `yaml:"clients"`
// Clients contains the YAML representations of the persistent clients.
// This field is only used for reading and writing persistent client data.
// Keep this field sorted to ensure consistent ordering.
Clients []*clientObject `yaml:"clients"`
logSettings `yaml:",inline"`
@@ -120,11 +123,6 @@ type dnsConfig struct {
// UpstreamTimeout is the timeout for querying upstream servers.
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
// LocalDomainName is the domain name used for known internal hosts.
// For example, a machine called "myhost" can be addressed as
// "myhost.lan" when LocalDomainName is "lan".
LocalDomainName string `yaml:"local_domain_name"`
// ResolveClients enables and disables resolving clients with RDNS.
ResolveClients bool `yaml:"resolve_clients"`
@@ -140,7 +138,7 @@ type dnsConfig struct {
type tlsConfigSettings struct {
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
ForceHTTPS bool `yaml:"force_https" json:"force_https,omitempty"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
PortHTTPS int `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
PortDNSOverTLS int `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DoT will be disabled
PortDNSOverQUIC int `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"` // DNS-over-QUIC port. If 0, DoQ will be disabled
@@ -163,7 +161,7 @@ type tlsConfigSettings struct {
// config is the global configuration structure.
//
// TODO(a.garipov, e.burkov): This global is afwul and must be removed.
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
var config = &configuration{
BindPort: 3000,
BetaBindPort: 0,
@@ -196,7 +194,6 @@ var config = &configuration{
FilteringEnabled: true, // whether or not use filter lists
FiltersUpdateIntervalHours: 24,
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
LocalDomainName: "lan",
ResolveClients: true,
UsePrivateRDNS: true,
},
@@ -205,6 +202,9 @@ var config = &configuration{
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
PortDNSOverQUIC: defaultPortQUIC,
},
DHCP: &dhcpd.ServerConfig{
LocalDomainName: "lan",
},
logSettings: logSettings{
LogCompress: false,
LogLocalTime: false,
@@ -232,9 +232,9 @@ func initConfig() {
config.DNS.DnsfilterConf.CacheTime = 30
config.Filters = defaultFilters()
config.DHCP.Conf4.LeaseDuration = 86400
config.DHCP.Conf4.ICMPTimeout = 1000
config.DHCP.Conf6.LeaseDuration = 86400
config.DHCP.Conf4.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
config.DHCP.Conf4.ICMPTimeout = dhcpd.DefaultDHCPTimeoutICMP
config.DHCP.Conf6.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
if ch := version.Channel(); ch == version.ChannelEdge || ch == version.ChannelDevelopment {
config.BetaBindPort = 3001
@@ -272,20 +272,40 @@ func getLogSettings() logSettings {
}
// parseConfig loads configuration from the YAML file
func parseConfig() error {
configFile := config.getConfigFilename()
log.Debug("Reading config file: %s", configFile)
yamlFile, err := readConfigFile()
func parseConfig() (err error) {
var fileData []byte
fileData, err = readConfigFile()
if err != nil {
return err
}
config.fileData = nil
err = yaml.Unmarshal(yamlFile, &config)
err = yaml.Unmarshal(fileData, &config)
if err != nil {
log.Error("Couldn't parse config file: %s", err)
return err
}
uc := aghalg.UniqChecker{}
addPorts(
uc,
config.BindPort,
config.BetaBindPort,
config.DNS.Port,
)
if config.TLS.Enabled {
addPorts(
uc,
config.TLS.PortHTTPS,
config.TLS.PortDNSOverTLS,
config.TLS.PortDNSOverQUIC,
config.TLS.PortDNSCrypt,
)
}
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
return fmt.Errorf("validating ports: %w", err)
}
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
config.DNS.FiltersUpdateIntervalHours = 24
}
@@ -297,18 +317,26 @@ func parseConfig() error {
return nil
}
// readConfigFile reads config file contents if it exists
func readConfigFile() ([]byte, error) {
if len(config.fileData) != 0 {
// addPorts is a helper for ports validation. It skips zero ports.
func addPorts(uc aghalg.UniqChecker, ports ...int) {
for _, p := range ports {
if p != 0 {
uc.Add(p)
}
}
}
// readConfigFile reads configuration file contents.
func readConfigFile() (fileData []byte, err error) {
if len(config.fileData) > 0 {
return config.fileData, nil
}
configFile := config.getConfigFilename()
d, err := os.ReadFile(configFile)
if err != nil {
return nil, fmt.Errorf("couldn't read config file %s: %w", configFile, err)
}
return d, nil
name := config.getConfigFilename()
log.Debug("reading config file: %s", name)
// Do not wrap the error because it's informative enough as is.
return os.ReadFile(name)
}
// Saves configuration to the YAML file and also saves the user filter contents to a file
@@ -316,8 +344,6 @@ func (c *configuration) write() error {
c.Lock()
defer c.Unlock()
Context.clients.WriteDiskConfig(&config.Clients)
if Context.auth != nil {
config.Users = Context.auth.GetUsers()
}
@@ -360,15 +386,16 @@ func (c *configuration) write() error {
}
if Context.dhcpServer != nil {
c := dhcpd.ServerConfig{}
Context.dhcpServer.WriteDiskConfig(&c)
c := &dhcpd.ServerConfig{}
Context.dhcpServer.WriteDiskConfig(c)
config.DHCP = c
}
config.Clients = Context.clients.forConfig()
configFile := config.getConfigFilename()
log.Debug("Writing YAML file: %s", configFile)
yamlText, err := yaml.Marshal(&config)
config.Clients = nil
if err != nil {
log.Error("Couldn't generate YAML file: %s", err)

View File

@@ -9,6 +9,7 @@ import (
"runtime"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/version"
@@ -17,23 +18,6 @@ import (
"github.com/NYTimes/gziphandler"
)
// ----------------
// helper functions
// ----------------
func returnOK(w http.ResponseWriter) {
_, err := fmt.Fprintf(w, "OK\n")
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info(text)
http.Error(w, text, code)
}
// appendDNSAddrs is a convenient helper for appending a formatted form of DNS
// addresses to a slice of strings.
func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) {
@@ -53,7 +37,7 @@ func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) {
// appendDNSAddrsWithIfaces formats and appends all DNS addresses from src to
// dst. It also adds the IP addresses of all network interfaces if src contains
// an unspecified IP addresss.
// an unspecified IP address.
func appendDNSAddrsWithIfaces(dst []string, src []net.IP) (res []string, err error) {
ifacesAdded := false
for _, h := range src {
@@ -125,12 +109,12 @@ type statusResponse struct {
Language string `json:"language"`
}
func handleStatus(w http.ResponseWriter, _ *http.Request) {
func handleStatus(w http.ResponseWriter, r *http.Request) {
dnsAddrs, err := collectDNSAddresses()
if err != nil {
// Don't add a lot of formatting, since the error is already
// wrapped by collectDNSAddresses.
httpError(w, http.StatusInternalServerError, "%s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
@@ -165,7 +149,7 @@ func handleStatus(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
@@ -182,7 +166,7 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
data, err := json.Marshal(pj)
if err != nil {
httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Marshal: %s", err)
return
}
_, _ = w.Write(data)
@@ -295,7 +279,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
host, err := netutil.SplitHost(r.Host)
if err != nil {
httpError(w, http.StatusBadRequest, "bad host: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "bad host: %s", err)
return false
}

View File

@@ -12,6 +12,7 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
@@ -49,7 +50,8 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
fj := filterAddJSON{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return
}
@@ -63,7 +65,8 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// Check for duplicates
if filterExists(fj.URL) {
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
return
}
@@ -79,17 +82,35 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// Download the filter contents
ok, err := f.update(&filt)
if err != nil {
httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err)
return
}
if !ok {
httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL)
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"Couldn't fetch filter from url %s: %s",
filt.URL,
err,
)
return
}
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
if !ok {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"Filter at the url %s is invalid (maybe it points to blank page?)",
filt.URL,
)
return
}
// URL is assumed valid so append it to filters, update config, write new
// file and reload it to engines.
if !filterAdd(filt) {
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
return
}
@@ -98,7 +119,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
@@ -111,7 +132,8 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
req := request{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to parse request body json: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse request body json: %s", err)
return
}
@@ -152,7 +174,7 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
_, err = fmt.Fprintf(w, "OK %d rules\n", deleted.RulesCount)
if err != nil {
httpError(w, http.StatusInternalServerError, "couldn't write body: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "couldn't write body: %s", err)
}
}
@@ -172,7 +194,8 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
fj := filterURLReq{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
@@ -228,7 +251,8 @@ func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Reque
// This use of ReadAll is safe, because request's body is now limited.
body, err := io.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
return
}
@@ -250,7 +274,8 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
req := Req{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
@@ -270,13 +295,15 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false)
}()
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
js, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
@@ -335,13 +362,14 @@ func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request
jsonVal, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(w, http.StatusInternalServerError, "http write: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err)
}
}
@@ -350,12 +378,14 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request
req := filteringConfig{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
if !checkFiltersUpdateIntervalHours(req.Interval) {
httpError(w, http.StatusBadRequest, "Unsupported interval")
aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
return
}
@@ -404,10 +434,19 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true
setts.ProtectionEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil {
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"couldn't apply filtering: %s: %s",
host,
err,
)
return
}
@@ -432,7 +471,8 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
js, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")

View File

@@ -13,28 +13,44 @@ import (
"runtime"
"strings"
"time"
"unicode/utf8"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// getAddrsResponse is the response for /install/get_addresses endpoint.
type getAddrsResponse struct {
WebPort int `json:"web_port"`
DNSPort int `json:"dns_port"`
Interfaces map[string]*aghnet.NetInterface `json:"interfaces"`
// Version is the version of AdGuard Home.
//
// TODO(a.garipov): In the new API, rename this endpoint to something more
// general, since there will be more information here than just network
// interfaces.
Version string `json:"version"`
WebPort int `json:"web_port"`
DNSPort int `json:"dns_port"`
}
// handleInstallGetAddresses is the handler for /install/get_addresses endpoint.
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
data := getAddrsResponse{}
data.WebPort = defaultPortHTTP
data.DNSPort = defaultPortDNS
data := getAddrsResponse{
Version: version.Version(),
WebPort: defaultPortHTTP,
DNSPort: defaultPortDNS,
}
ifaces, err := aghnet.GetValidNetInterfacesForWeb()
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return
}
@@ -46,24 +62,31 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err)
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to marshal default addresses to json: %s",
err,
)
return
}
}
type checkConfigReqEnt struct {
Port int `json:"port"`
type checkConfReqEnt struct {
IP net.IP `json:"ip"`
Port int `json:"port"`
Autofix bool `json:"autofix"`
}
type checkConfigReq struct {
Web checkConfigReqEnt `json:"web"`
DNS checkConfigReqEnt `json:"dns"`
SetStaticIP bool `json:"set_static_ip"`
type checkConfReq struct {
Web checkConfReqEnt `json:"web"`
DNS checkConfReqEnt `json:"dns"`
SetStaticIP bool `json:"set_static_ip"`
}
type checkConfigRespEnt struct {
type checkConfRespEnt struct {
Status string `json:"status"`
CanAutofix bool `json:"can_autofix"`
}
@@ -74,63 +97,111 @@ type staticIPJSON struct {
Error string `json:"error"`
}
type checkConfigResp struct {
Web checkConfigRespEnt `json:"web"`
DNS checkConfigRespEnt `json:"dns"`
StaticIP staticIPJSON `json:"static_ip"`
type checkConfResp struct {
StaticIP staticIPJSON `json:"static_ip"`
Web checkConfRespEnt `json:"web"`
DNS checkConfRespEnt `json:"dns"`
}
// Check if ports are available, respond with results
// validateWeb returns error is the web part if the initial configuration can't
// be set.
func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
port := req.Web.Port
addPorts(uc, config.BetaBindPort, port)
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
// Avoid duplicating the error into the status of DNS.
uc[port] = 1
return err
}
switch port {
case 0, config.BindPort:
return nil
default:
// Go on and check the port binding only if it's not zero or won't be
// unbound after install.
}
return aghnet.CheckPort("tcp", req.Web.IP, port)
}
// validateDNS returns error if the DNS part of the initial configuration can't
// be set. canAutofix is true if the port can be unbound by AdGuard Home
// automatically.
func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
port := req.DNS.Port
addPorts(uc, port)
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
return false, err
}
switch port {
case 0:
return false, nil
case config.BindPort:
// Go on and only check the UDP port since the TCP one is already bound
// by AdGuard Home for web interface.
default:
// Check TCP as well.
err = aghnet.CheckPort("tcp", req.DNS.IP, port)
if err != nil {
return false, err
}
}
err = aghnet.CheckPort("udp", req.DNS.IP, port)
if !aghnet.IsAddrInUse(err) {
return false, err
}
// Try to fix automatically.
canAutofix = checkDNSStubListener()
if canAutofix && req.DNS.Autofix {
if derr := disableDNSStubListener(); derr != nil {
log.Error("disabling DNSStubListener: %s", err)
}
err = aghnet.CheckPort("udp", req.DNS.IP, port)
canAutofix = false
}
return canAutofix, err
}
// handleInstallCheckConfig handles the /check_config endpoint.
func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
reqData := checkConfigReq{}
respData := checkConfigResp{}
err := json.NewDecoder(r.Body).Decode(&reqData)
req := &checkConfReq{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "decoding the request: %s", err)
return
}
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort && reqData.Web.Port != config.BetaBindPort {
err = aghnet.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port)
if err != nil {
respData.Web.Status = err.Error()
}
resp := &checkConfResp{}
uc := aghalg.UniqChecker{}
if err = req.validateWeb(uc); err != nil {
resp.Web.Status = err.Error()
}
if reqData.DNS.Port != 0 {
err = aghnet.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
if aghnet.ErrorIsAddrInUse(err) {
canAutofix := checkDNSStubListener()
if canAutofix && reqData.DNS.Autofix {
err = disableDNSStubListener()
if err != nil {
log.Error("Couldn't disable DNSStubListener: %s", err)
}
err = aghnet.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
canAutofix = false
}
respData.DNS.CanAutofix = canAutofix
}
if err == nil {
err = aghnet.CheckPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
}
if err != nil {
respData.DNS.Status = err.Error()
} else if !reqData.DNS.IP.IsUnspecified() {
respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP)
}
if resp.DNS.CanAutofix, err = req.validateDNS(uc); err != nil {
resp.DNS.Status = err.Error()
} else if !req.DNS.IP.IsUnspecified() {
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(respData)
err = json.NewEncoder(w).Encode(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal JSON: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding the response: %s", err)
return
}
}
@@ -251,10 +322,11 @@ type applyConfigReqEnt struct {
}
type applyConfigReq struct {
Web applyConfigReqEnt `json:"web"`
DNS applyConfigReqEnt `json:"dns"`
Username string `json:"username"`
Password string `json:"password"`
Username string `json:"username"`
Password string `json:"password"`
Web applyConfigReqEnt `json:"web"`
DNS applyConfigReqEnt `json:"dns"`
}
// copyInstallSettings copies the installation parameters between two
@@ -270,40 +342,58 @@ func copyInstallSettings(dst, src *configuration) {
// shutdownTimeout is the timeout for shutting HTTP server down operation.
const shutdownTimeout = 5 * time.Second
func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) {
func shutdownSrv(ctx context.Context, srv *http.Server) {
defer log.OnPanic("")
if srv == nil {
return
}
defer cancel()
err := srv.Shutdown(ctx)
if err != nil {
log.Error("error while shutting down http server %q: %s", srv.Addr, err)
const msgFmt = "shutting down http server %q: %s"
if errors.Is(err, context.Canceled) {
log.Debug(msgFmt, srv.Addr, err)
} else {
log.Error(msgFmt, srv.Addr, err)
}
}
}
// PasswordMinRunes is the minimum length of user's password in runes.
const PasswordMinRunes = 8
// Apply new configuration, start DNS server, restart Web server
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
req, restartHTTP, err := decodeApplyConfigReq(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
err = aghnet.CheckPacketPortAvailable(req.DNS.IP, req.DNS.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
if utf8.RuneCountInString(req.Password) < PasswordMinRunes {
aghhttp.Error(
r,
w,
http.StatusUnprocessableEntity,
"password must be at least %d symbols long",
PasswordMinRunes,
)
return
}
err = aghnet.CheckPortAvailable(req.DNS.IP, req.DNS.Port)
err = aghnet.CheckPort("udp", req.DNS.IP, req.DNS.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
err = aghnet.CheckPort("tcp", req.DNS.IP, req.DNS.Port)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
@@ -317,28 +407,29 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
config.DNS.BindHosts = []net.IP{req.DNS.IP}
config.DNS.Port = req.DNS.Port
// TODO(e.burkov): StartMods() should be put in a separate goroutine at
// the moment we'll allow setting up TLS in the initial configuration or
// the configuration itself will use HTTPS protocol, because the
// underlying functions potentially restart the HTTPS server.
// TODO(e.burkov): StartMods() should be put in a separate goroutine at the
// moment we'll allow setting up TLS in the initial configuration or the
// configuration itself will use HTTPS protocol, because the underlying
// functions potentially restart the HTTPS server.
err = StartMods()
if err != nil {
Context.firstRun = true
copyInstallSettings(config, curConfig)
httpError(w, http.StatusInternalServerError, "%s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
u := User{}
u.Name = req.Username
Context.auth.UserAdd(&u, req.Password)
u := &User{
Name: req.Username,
}
Context.auth.UserAdd(u, req.Password)
err = config.write()
if err != nil {
Context.firstRun = true
copyInstallSettings(config, curConfig)
httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write config: %s", err)
return
}
@@ -349,19 +440,27 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
registerControlHandlers()
returnOK(w)
aghhttp.OK(w)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// Method http.(*Server).Shutdown needs to be called in a separate
// goroutine and with its own context, because it waits until all
// requests are handled and will be blocked by it's own caller.
if restartHTTP {
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
go shutdownSrv(ctx, cancel, web.httpServer)
go shutdownSrv(ctx, cancel, web.httpServerBeta)
if !restartHTTP {
return
}
// Method http.(*Server).Shutdown needs to be called in a separate goroutine
// and with its own context, because it waits until all requests are handled
// and will be blocked by it's own caller.
go func(timeout time.Duration) {
defer log.OnPanic("web")
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
shutdownSrv(ctx, web.httpServer)
shutdownSrv(ctx, web.httpServerBeta)
}(shutdownTimeout)
}
// decodeApplyConfigReq decodes the configuration, validates some parameters,
@@ -380,7 +479,7 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
restartHTTP = !config.BindHost.Equal(req.Web.IP) || config.BindPort != req.Web.Port
if restartHTTP {
err = aghnet.CheckPortAvailable(req.Web.IP, req.Web.Port)
err = aghnet.CheckPort("tcp", req.Web.IP, req.Web.Port)
if err != nil {
return nil, false, fmt.Errorf(
"checking address %s:%d: %w",
@@ -406,8 +505,8 @@ func (web *Web) registerInstallHandlers() {
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default checkConfigReqEnt.
type checkConfigReqEntBeta struct {
Port int `json:"port"`
IP []net.IP `json:"ip"`
Port int `json:"port"`
Autofix bool `json:"autofix"`
}
@@ -431,24 +530,26 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ
reqData := checkConfigReqBeta{}
err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
return
}
if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 {
httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
aghhttp.Error(r, w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
return
}
nonBetaReqData := checkConfigReq{
Web: checkConfigReqEnt{
Port: reqData.Web.Port,
nonBetaReqData := checkConfReq{
Web: checkConfReqEnt{
IP: reqData.Web.IP[0],
Port: reqData.Web.Port,
Autofix: reqData.Web.Autofix,
},
DNS: checkConfigReqEnt{
Port: reqData.DNS.Port,
DNS: checkConfReqEnt{
IP: reqData.DNS.IP[0],
Port: reqData.DNS.Port,
Autofix: reqData.DNS.Autofix,
},
SetStaticIP: reqData.SetStaticIP,
@@ -458,7 +559,14 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ
err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err)
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"Failed to encode 'check_config' JSON data: %s",
err,
)
return
}
body := nonBetaReqBody.String()
@@ -484,10 +592,11 @@ type applyConfigReqEntBeta struct {
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default applyConfigReq.
type applyConfigReqBeta struct {
Web applyConfigReqEntBeta `json:"web"`
DNS applyConfigReqEntBeta `json:"dns"`
Username string `json:"username"`
Password string `json:"password"`
Username string `json:"username"`
Password string `json:"password"`
Web applyConfigReqEntBeta `json:"web"`
DNS applyConfigReqEntBeta `json:"dns"`
}
// handleInstallConfigureBeta is a substitution of /install/configure handler
@@ -499,12 +608,14 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques
reqData := applyConfigReqBeta{}
err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
return
}
if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 {
httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
aghhttp.Error(r, w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
return
}
@@ -525,7 +636,14 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques
err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err)
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"Failed to encode 'check_config' JSON data: %s",
err,
)
return
}
body := nonBetaReqBody.String()
@@ -541,9 +659,9 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default firstRunData.
type getAddrsResponseBeta struct {
Interfaces []*aghnet.NetInterface `json:"interfaces"`
WebPort int `json:"web_port"`
DNSPort int `json:"dns_port"`
Interfaces []*aghnet.NetInterface `json:"interfaces"`
}
// handleInstallConfigureBeta is a substitution of /install/get_addresses
@@ -552,13 +670,15 @@ type getAddrsResponseBeta struct {
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handleInstallGetAddresses.
func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Request) {
data := getAddrsResponseBeta{}
data.WebPort = defaultPortHTTP
data.DNSPort = defaultPortDNS
data := getAddrsResponseBeta{
WebPort: defaultPortHTTP,
DNSPort: defaultPortDNS,
}
ifaces, err := aghnet.GetValidNetInterfacesForWeb()
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return
}
@@ -567,7 +687,14 @@ func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Req
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err)
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to marshal default addresses to json: %s",
err,
)
return
}
}

View File

@@ -11,6 +11,7 @@ import (
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/golibs/errors"
@@ -43,7 +44,8 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
if r.ContentLength != 0 {
err = json.NewDecoder(r.Body).Decode(req)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "JSON parse: %s", err)
return
}
}
@@ -77,7 +79,15 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
if err != nil {
vcu := Context.updater.VersionCheckURL()
// TODO(a.garipov): Figure out the purpose of %T verb.
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", vcu, err, err)
aghhttp.Error(
r,
w,
http.StatusBadGateway,
"Couldn't get version check json from %s: %T %s\n",
vcu,
err,
err,
)
return
}
@@ -87,24 +97,26 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
// handleUpdate performs an update to the latest available version procedure.
func handleUpdate(w http.ResponseWriter, _ *http.Request) {
func handleUpdate(w http.ResponseWriter, r *http.Request) {
if Context.updater.NewVersion() == "" {
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
aghhttp.Error(r, w, http.StatusBadRequest, "/update request isn't allowed now")
return
}
err := Context.updater.Update()
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
returnOK(w)
aghhttp.OK(w)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}

View File

@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
@@ -36,20 +37,24 @@ func onConfigModified() {
// initDNSServer creates an instance of the dnsforward.Server
// Please note that we must do it even if we don't start it
// so that we had access to the query log and the stats
func initDNSServer() error {
var err error
func initDNSServer() (err error) {
baseDir := Context.getDataDir()
var anonFunc aghnet.IPMutFunc
if config.DNS.AnonymizeClientIP {
anonFunc = querylog.AnonymizeIP
}
anonymizer := aghnet.NewIPMut(anonFunc)
statsConf := stats.Config{
Filename: filepath.Join(baseDir, "stats.db"),
LimitDays: config.DNS.StatsInterval,
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
Filename: filepath.Join(baseDir, "stats.db"),
LimitDays: config.DNS.StatsInterval,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
}
Context.stats, err = stats.New(statsConf)
if err != nil {
return fmt.Errorf("couldn't initialize statistics module")
return fmt.Errorf("init stats: %w", err)
}
conf := querylog.Config{
@@ -62,6 +67,7 @@ func initDNSServer() error {
Enabled: config.DNS.QueryLogEnabled,
FileEnabled: config.DNS.QueryLogFileEnabled,
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
Anonymizer: anonymizer,
}
Context.queryLog = querylog.New(conf)
@@ -76,7 +82,8 @@ func initDNSServer() error {
Stats: Context.stats,
QueryLog: Context.queryLog,
SubnetDetector: Context.subnetDetector,
LocalDomain: config.DNS.LocalDomainName,
Anonymizer: anonymizer,
LocalDomain: config.DHCP.LocalDomainName,
}
if Context.dhcpServer != nil {
p.DHCPServer = Context.dhcpServer
@@ -90,7 +97,8 @@ func initDNSServer() error {
}
Context.clients.dnsServer = Context.dnsServer
dnsConfig, err := generateServerConfig()
var dnsConfig dnsforward.ServerConfig
dnsConfig, err = generateServerConfig()
if err != nil {
closeDNSServer()
@@ -100,6 +108,7 @@ func initDNSServer() error {
err = Context.dnsServer.Prepare(&dnsConfig)
if err != nil {
closeDNSServer()
return fmt.Errorf("dnsServer.Prepare: %w", err)
}
@@ -202,7 +211,6 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
}
newConf.TLSv12Roots = Context.tlsRoots
newConf.TLSCiphers = Context.tlsCiphers
newConf.TLSAllowUnencryptedDoH = tlsConf.AllowUnencryptedDoH
newConf.FilterHandler = applyAdditionalFiltering
@@ -310,7 +318,7 @@ func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *filteri
}
}
log.Debug("using settings for client %s with ip %s and id %q", c.Name, clientAddr, clientID)
log.Debug("using settings for client %s with ip %s and clientid %q", c.Name, clientAddr, clientID)
if c.UseOwnBlockedServices {
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)

View File

@@ -710,8 +710,8 @@ func enableFilters(async bool) {
}
func enableFiltersLocked(async bool) {
var whiteFilters []filtering.Filter
filters := []filtering.Filter{{
ID: filtering.CustomListID,
Data: []byte(strings.Join(config.UserRules, "\n")),
}}
@@ -725,18 +725,20 @@ func enableFiltersLocked(async bool) {
FilePath: filter.Path(),
})
}
var allowFilters []filtering.Filter
for _, filter := range config.WhitelistFilters {
if !filter.Enabled {
continue
}
whiteFilters = append(whiteFilters, filtering.Filter{
allowFilters = append(allowFilters, filtering.Filter{
ID: filter.ID,
FilePath: filter.Path(),
})
}
if err := Context.dnsFilter.SetFilters(filters, whiteFilters, async); err != nil {
if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil {
log.Debug("enabling filters: %s", err)
}

View File

@@ -12,6 +12,7 @@ import (
"time"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -21,7 +22,7 @@ const testFltsFileName = "1.txt"
func testStartFilterListener(t *testing.T, fltContent *[]byte) (l net.Listener) {
t.Helper()
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
n, werr := w.Write(*fltContent)
require.NoError(t, werr)
require.Equal(t, len(*fltContent), n)
@@ -34,9 +35,7 @@ func testStartFilterListener(t *testing.T, fltContent *[]byte) (l net.Listener)
go func() {
_ = http.Serve(l, h)
}()
t.Cleanup(func() {
require.NoError(t, l.Close())
})
testutil.CleanupAndRequireSuccess(t, l.Close)
return l
}
@@ -100,9 +99,7 @@ func TestFilters(t *testing.T) {
t.Run("refresh_actually", func(t *testing.T) {
fltContent = []byte(`||example.com^`)
t.Cleanup(func() {
fltContent = []byte(content)
})
t.Cleanup(func() { fltContent = []byte(content) })
updateAndAssert(t, require.True, 1)
})

View File

@@ -19,8 +19,10 @@ import (
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -44,20 +46,25 @@ type homeContext struct {
// Modules
// --
clients clientsContainer // per-client-settings module
stats stats.Stats // statistics module
queryLog querylog.QueryLog // query log module
dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module
whois *WHOIS // WHOIS module
dnsFilter *filtering.DNSFilter // DNS filtering module
dhcpServer *dhcpd.Server // DHCP module
auth *Auth // HTTP authentication module
filters Filtering // DNS filtering module
web *Web // Web (HTTP, HTTPS) module
tls *TLSMod // TLS module
etcHosts *aghnet.EtcHostsContainer // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files
updater *updater.Updater
clients clientsContainer // per-client-settings module
stats stats.Stats // statistics module
queryLog querylog.QueryLog // query log module
dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module
whois *WHOIS // WHOIS module
dnsFilter *filtering.DNSFilter // DNS filtering module
dhcpServer *dhcpd.Server // DHCP module
auth *Auth // HTTP authentication module
filters Filtering // DNS filtering module
web *Web // Web (HTTP, HTTPS) module
tls *TLSMod // TLS module
// etcHosts is an IP-hostname pairs set taken from system configuration
// (e.g. /etc/hosts) files.
etcHosts *aghnet.HostsContainer
// hostsWatcher is the watcher to detect changes in the hosts files.
hostsWatcher aghos.FSWatcher
updater *updater.Updater
subnetDetector *aghnet.SubnetDetector
@@ -74,7 +81,6 @@ type homeContext struct {
disableUpdate bool // If set, don't check for updates
controlLock sync.Mutex
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
tlsCiphers []uint16 // list of TLS ciphers to use
transport *http.Transport
client *http.Client
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
@@ -139,13 +145,13 @@ func setupContext(args options) {
initConfig()
Context.tlsRoots = LoadSystemRootCAs()
Context.tlsCiphers = InitTLSCiphers()
Context.transport = &http.Transport{
DialContext: customDialContext,
Proxy: getHTTPProxy,
TLSClientConfig: &tls.Config{
RootCAs: Context.tlsRoots,
MinVersion: tls.VersionTLS12,
RootCAs: Context.tlsRoots,
CipherSuites: aghtls.SaferCipherSuites(),
MinVersion: tls.VersionTLS12,
},
}
Context.client = &http.Client{
@@ -154,14 +160,11 @@ func setupContext(args options) {
}
if !Context.firstRun {
// Do the upgrade if necessary
// Do the upgrade if necessary.
err := upgradeConfig()
if err != nil {
log.Fatal(err)
}
fatalOnError(err)
err = parseConfig()
if err != nil {
if err = parseConfig(); err != nil {
log.Error("parsing configuration file: %s", err)
os.Exit(1)
@@ -179,15 +182,15 @@ func setupContext(args options) {
// logIfUnsupported logs a formatted warning if the error is one of the
// unsupported errors and returns nil. If err is nil, logIfUnsupported returns
// nil. Otherise, it returns err.
// nil. Otherwise, it returns err.
func logIfUnsupported(msg string, err error) (outErr error) {
if unsupErr := (&aghos.UnsupportedError{}); errors.As(err, &unsupErr) {
if errors.As(err, new(*aghos.UnsupportedError)) {
log.Debug(msg, err)
} else if err != nil {
return err
return nil
}
return nil
return err
}
// configureOS sets the OS-related configuration.
@@ -230,6 +233,34 @@ func configureOS(conf *configuration) (err error) {
return nil
}
// setupHostsContainer initializes the structures to keep up-to-date the hosts
// provided by the OS.
func setupHostsContainer() (err error) {
Context.hostsWatcher, err = aghos.NewOSWritesWatcher()
if err != nil {
return fmt.Errorf("initing hosts watcher: %w", err)
}
Context.etcHosts, err = aghnet.NewHostsContainer(
filtering.SysHostsListID,
aghos.RootDirFS(),
Context.hostsWatcher,
aghnet.DefaultHostsPaths()...,
)
if err != nil {
cerr := Context.hostsWatcher.Close()
if errors.Is(err, aghnet.ErrNoHostsPaths) && cerr == nil {
log.Info("warning: initing hosts container: %s", err)
return nil
}
return errors.WithDeferred(fmt.Errorf("initing hosts container: %w", err), cerr)
}
return nil
}
func setupConfig(args options) (err error) {
config.DHCP.WorkDir = Context.workDir
config.DHCP.HTTPRegister = httpRegister
@@ -257,19 +288,41 @@ func setupConfig(args options) (err error) {
})
if !args.noEtcHosts {
Context.etcHosts = &aghnet.EtcHostsContainer{}
Context.etcHosts.Init("")
if err = setupHostsContainer(); err != nil {
return err
}
}
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts)
config.Clients = nil
if args.bindPort != 0 {
uc := aghalg.UniqChecker{}
addPorts(
uc,
args.bindPort,
config.BetaBindPort,
config.DNS.Port,
)
if config.TLS.Enabled {
addPorts(
uc,
config.TLS.PortHTTPS,
config.TLS.PortDNSOverTLS,
config.TLS.PortDNSOverQUIC,
config.TLS.PortDNSCrypt,
)
}
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
return fmt.Errorf("validating ports: %w", err)
}
config.BindPort = args.bindPort
}
// override bind host/port from the console
if args.bindHost != nil {
config.BindHost = args.bindHost
}
if args.bindPort != 0 {
config.BindPort = args.bindPort
}
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
Context.pidFileName = args.pidFile
}
@@ -324,7 +377,7 @@ func fatalOnError(err error) {
}
}
// run performs configurating and starts AdGuard Home.
// run configures and starts AdGuard Home.
func run(args options, clientBuildFS fs.FS) {
var err error
@@ -340,9 +393,9 @@ func run(args options, clientBuildFS fs.FS) {
// Go memory hacks
memoryUsage(args)
// print the first message after logger is configured
// Print the first message after logger is configured.
log.Println(version.Full())
log.Debug("Current working directory is %s", Context.workDir)
log.Debug("current working directory is %s", Context.workDir)
if args.runningAsService {
log.Info("AdGuard Home is running as a service")
}
@@ -424,7 +477,6 @@ func run(args options, clientBuildFS fs.FS) {
fatalOnError(err)
Context.tls.Start()
Context.etcHosts.Start()
go func() {
serr := startDNSServer()
@@ -579,13 +631,13 @@ func configureLogger(args options) {
log.SetLevel(log.DEBUG)
}
// Make sure that we see the microseconds in logs, as networking stuff
// can happen pretty quickly.
// Make sure that we see the microseconds in logs, as networking stuff can
// happen pretty quickly.
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if nothing else is configured
// Otherwise, we'll simply loose the log output
// When running as a Windows service, use eventlog by default if nothing
// else is configured. Otherwise, we'll simply lose the log output.
ls.LogFile = configSyslog
}
@@ -647,7 +699,18 @@ func cleanup(ctx context.Context) {
}
}
Context.etcHosts.Close()
if Context.etcHosts != nil {
// Currently Context.hostsWatcher is only used in Context.etcHosts and
// needs closing only in case of the successful initialization of
// Context.etcHosts.
if err = Context.hostsWatcher.Close(); err != nil {
log.Error("closing hosts watcher: %s", err)
}
if err = Context.etcHosts.Close(); err != nil {
log.Error("closing hosts container: %s", err)
}
}
if Context.tls != nil {
Context.tls.Close()
@@ -722,8 +785,7 @@ func printHTTPAddresses(proto string) {
port = tlsConf.PortHTTPS
}
// TODO(e.burkov): Inspect and perhaps merge with the previous
// condition.
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
if proto == schemeHTTPS && tlsConf.ServerName != "" {
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0)

View File

@@ -6,11 +6,12 @@ import (
"net/http"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
)
// TODO(a.garipov): Get rid of a global variable?
// TODO(a.garipov): Get rid of a global or generate from .twosky.json.
var allowedLanguages = stringutil.NewSet(
"be",
"bg",
@@ -20,6 +21,7 @@ var allowedLanguages = stringutil.NewSet(
"en",
"es",
"fa",
"fi",
"fr",
"hr",
"hu",
@@ -41,6 +43,7 @@ var allowedLanguages = stringutil.NewSet(
"sv",
"th",
"tr",
"uk",
"vi",
"zh-cn",
"zh-hk",
@@ -94,5 +97,5 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
}()
onConfigModified()
returnOK(w)
aghhttp.OK(w)
}

View File

@@ -46,7 +46,7 @@ func TestLimitRequestBody(t *testing.T) {
var b []byte
b, *err = io.ReadAll(r.Body)
_, werr := w.Write(b)
require.Nil(t, werr)
require.NoError(t, werr)
})
}

Some files were not shown because too many files have changed in this diff Show More