Merge branch 'master' into 4403-internal-proxy
This commit is contained in:
@@ -1,32 +1,45 @@
|
||||
// Package aghalg contains common generic algorithms and data structures.
|
||||
//
|
||||
// TODO(a.garipov): Update to use type parameters in Go 1.18.
|
||||
// TODO(a.garipov): Move parts of this into golibs.
|
||||
package aghalg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// comparable is an alias for interface{}. Values passed as arguments of this
|
||||
// type alias must be comparable.
|
||||
//
|
||||
// TODO(a.garipov): Remove in Go 1.18.
|
||||
type comparable = interface{}
|
||||
// Coalesce returns the first non-zero value. It is named after the function
|
||||
// COALESCE in SQL. If values or all its elements are empty, it returns a zero
|
||||
// value.
|
||||
func Coalesce[T comparable](values ...T) (res T) {
|
||||
var zero T
|
||||
for _, v := range values {
|
||||
if v != zero {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return zero
|
||||
}
|
||||
|
||||
// UniqChecker allows validating uniqueness of comparable items.
|
||||
type UniqChecker map[comparable]int64
|
||||
//
|
||||
// TODO(a.garipov): The Ordered constraint is only really necessary in Validate.
|
||||
// Consider ways of making this constraint comparable instead.
|
||||
type UniqChecker[T constraints.Ordered] map[T]int64
|
||||
|
||||
// Add adds a value to the validator. v must not be nil.
|
||||
func (uc UniqChecker) Add(elems ...comparable) {
|
||||
func (uc UniqChecker[T]) Add(elems ...T) {
|
||||
for _, e := range elems {
|
||||
uc[e]++
|
||||
}
|
||||
}
|
||||
|
||||
// Merge returns a checker containing data from both uc and other.
|
||||
func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
|
||||
merged = make(UniqChecker, len(uc)+len(other))
|
||||
func (uc UniqChecker[T]) Merge(other UniqChecker[T]) (merged UniqChecker[T]) {
|
||||
merged = make(UniqChecker[T], len(uc)+len(other))
|
||||
for elem, num := range uc {
|
||||
merged[elem] += num
|
||||
}
|
||||
@@ -39,10 +52,8 @@ func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
|
||||
}
|
||||
|
||||
// Validate returns an error enumerating all elements that aren't unique.
|
||||
// isBefore is an optional sorting function to make the error message
|
||||
// deterministic.
|
||||
func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err error) {
|
||||
var dup []comparable
|
||||
func (uc UniqChecker[T]) Validate() (err error) {
|
||||
var dup []T
|
||||
for elem, num := range uc {
|
||||
if num > 1 {
|
||||
dup = append(dup, elem)
|
||||
@@ -53,23 +64,7 @@ func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err
|
||||
return nil
|
||||
}
|
||||
|
||||
if isBefore != nil {
|
||||
sort.Slice(dup, func(i, j int) (less bool) {
|
||||
return isBefore(dup[i], dup[j])
|
||||
})
|
||||
}
|
||||
slices.Sort(dup)
|
||||
|
||||
return fmt.Errorf("duplicated values: %v", dup)
|
||||
}
|
||||
|
||||
// IntIsBefore is a helper sort function for UniqChecker.Validate.
|
||||
// a and b must be of type int.
|
||||
func IntIsBefore(a, b comparable) (less bool) {
|
||||
return a.(int) < b.(int)
|
||||
}
|
||||
|
||||
// StringIsBefore is a helper sort function for UniqChecker.Validate.
|
||||
// a and b must be of type string.
|
||||
func StringIsBefore(a, b comparable) (less bool) {
|
||||
return a.(string) < b.(string)
|
||||
}
|
||||
|
||||
67
internal/aghalg/nullbool.go
Normal file
67
internal/aghalg/nullbool.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package aghalg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// NullBool is a nullable boolean. Use these in JSON requests and responses
|
||||
// instead of pointers to bool.
|
||||
type NullBool uint8
|
||||
|
||||
// NullBool values
|
||||
const (
|
||||
NBNull NullBool = iota
|
||||
NBTrue
|
||||
NBFalse
|
||||
)
|
||||
|
||||
// String implements the fmt.Stringer interface for NullBool.
|
||||
func (nb NullBool) String() (s string) {
|
||||
switch nb {
|
||||
case NBNull:
|
||||
return "null"
|
||||
case NBTrue:
|
||||
return "true"
|
||||
case NBFalse:
|
||||
return "false"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("!invalid NullBool %d", uint8(nb))
|
||||
}
|
||||
|
||||
// BoolToNullBool converts a bool into a NullBool.
|
||||
func BoolToNullBool(cond bool) (nb NullBool) {
|
||||
if cond {
|
||||
return NBTrue
|
||||
}
|
||||
|
||||
return NBFalse
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ json.Marshaler = NBNull
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for NullBool.
|
||||
func (nb NullBool) MarshalJSON() (b []byte, err error) {
|
||||
return []byte(nb.String()), nil
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ json.Unmarshaler = (*NullBool)(nil)
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for *NullBool.
|
||||
func (nb *NullBool) UnmarshalJSON(b []byte) (err error) {
|
||||
if len(b) == 0 || bytes.Equal(b, []byte("null")) {
|
||||
*nb = NBNull
|
||||
} else if bytes.Equal(b, []byte("true")) {
|
||||
*nb = NBTrue
|
||||
} else if bytes.Equal(b, []byte("false")) {
|
||||
*nb = NBFalse
|
||||
} else {
|
||||
return fmt.Errorf("unmarshalling json data into aghalg.NullBool: bad value %q", b)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
113
internal/aghalg/nullbool_test.go
Normal file
113
internal/aghalg/nullbool_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package aghalg_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNullBool_MarshalJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
want []byte
|
||||
in aghalg.NullBool
|
||||
}{{
|
||||
name: "null",
|
||||
wantErrMsg: "",
|
||||
want: []byte("null"),
|
||||
in: aghalg.NBNull,
|
||||
}, {
|
||||
name: "true",
|
||||
wantErrMsg: "",
|
||||
want: []byte("true"),
|
||||
in: aghalg.NBTrue,
|
||||
}, {
|
||||
name: "false",
|
||||
wantErrMsg: "",
|
||||
want: []byte("false"),
|
||||
in: aghalg.NBFalse,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := tc.in.MarshalJSON()
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
in := &struct {
|
||||
A aghalg.NullBool
|
||||
}{
|
||||
A: aghalg.NBTrue,
|
||||
}
|
||||
|
||||
got, err := json.Marshal(in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []byte(`{"A":true}`), got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNullBool_UnmarshalJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
data []byte
|
||||
want aghalg.NullBool
|
||||
}{{
|
||||
name: "empty",
|
||||
wantErrMsg: "",
|
||||
data: []byte{},
|
||||
want: aghalg.NBNull,
|
||||
}, {
|
||||
name: "null",
|
||||
wantErrMsg: "",
|
||||
data: []byte("null"),
|
||||
want: aghalg.NBNull,
|
||||
}, {
|
||||
name: "true",
|
||||
wantErrMsg: "",
|
||||
data: []byte("true"),
|
||||
want: aghalg.NBTrue,
|
||||
}, {
|
||||
name: "false",
|
||||
wantErrMsg: "",
|
||||
data: []byte("false"),
|
||||
want: aghalg.NBFalse,
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErrMsg: `unmarshalling json data into aghalg.NullBool: bad value "invalid"`,
|
||||
data: []byte("invalid"),
|
||||
want: aghalg.NBNull,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var got aghalg.NullBool
|
||||
err := got.UnmarshalJSON(tc.data)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
want := aghalg.NBTrue
|
||||
var got struct {
|
||||
A aghalg.NullBool
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(`{"A":true}`), &got)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, want, got.A)
|
||||
})
|
||||
}
|
||||
@@ -9,6 +9,12 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// RegisterFunc is the function that sets the handler to handle the URL for the
|
||||
// method.
|
||||
//
|
||||
// TODO(e.burkov, a.garipov): Get rid of it.
|
||||
type RegisterFunc func(method, url string, handler http.HandlerFunc)
|
||||
|
||||
// OK responds with word OK.
|
||||
func OK(w http.ResponseWriter) {
|
||||
if _, err := io.WriteString(w, "OK\n"); err != nil {
|
||||
@@ -17,7 +23,7 @@ func OK(w http.ResponseWriter) {
|
||||
}
|
||||
|
||||
// Error writes formatted message to w and also logs it.
|
||||
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...any) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Error("%s %s: %s", r.Method, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
|
||||
@@ -2,13 +2,11 @@ package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
@@ -27,15 +25,8 @@ type ARPDB interface {
|
||||
}
|
||||
|
||||
// NewARPDB returns the ARPDB properly initialized for the OS.
|
||||
func NewARPDB() (arp ARPDB, err error) {
|
||||
arp = newARPDB()
|
||||
|
||||
err = arp.Refresh()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("arpdb initial refresh: %w", err)
|
||||
}
|
||||
|
||||
return arp, nil
|
||||
func NewARPDB() (arp ARPDB) {
|
||||
return newARPDB()
|
||||
}
|
||||
|
||||
// Empty ARPDB implementation
|
||||
@@ -123,50 +114,33 @@ func (ns *neighs) reset(with []Neighbor) {
|
||||
// of Neighbors.
|
||||
type parseNeighsFunc func(sc *bufio.Scanner, lenHint int) (ns []Neighbor)
|
||||
|
||||
// runCmdFunc is the function that runs some command and returns its output
|
||||
// wrapped to be a io.Reader.
|
||||
type runCmdFunc func() (r io.Reader, err error)
|
||||
|
||||
// cmdARPDB is the implementation of the ARPDB that uses command line to
|
||||
// retrieve data.
|
||||
type cmdARPDB struct {
|
||||
parse parseNeighsFunc
|
||||
runcmd runCmdFunc
|
||||
ns *neighs
|
||||
parse parseNeighsFunc
|
||||
ns *neighs
|
||||
cmd string
|
||||
args []string
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ ARPDB = (*cmdARPDB)(nil)
|
||||
|
||||
// runCmd runs the cmd with it's args and returns the result wrapped to be an
|
||||
// io.Reader. The error is returned either if the exit code retured by command
|
||||
// not equals 0 or the execution itself failed.
|
||||
func runCmd(cmd string, args ...string) (r io.Reader, err error) {
|
||||
var code int
|
||||
var out string
|
||||
code, out, err = aghos.RunCommand(cmd, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if code != 0 {
|
||||
return nil, fmt.Errorf("unexpected exit code %d", code)
|
||||
}
|
||||
|
||||
return strings.NewReader(out), nil
|
||||
}
|
||||
|
||||
// Refresh implements the ARPDB interface for *cmdARPDB.
|
||||
func (arp *cmdARPDB) Refresh() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "cmd arpdb: %w") }()
|
||||
|
||||
var r io.Reader
|
||||
r, err = arp.runcmd()
|
||||
code, out, err := aghosRunCommand(arp.cmd, arp.args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("running command: %w", err)
|
||||
} else if code != 0 {
|
||||
return fmt.Errorf("running command: unexpected exit code %d", code)
|
||||
}
|
||||
|
||||
sc := bufio.NewScanner(r)
|
||||
sc := bufio.NewScanner(bytes.NewReader(out))
|
||||
ns := arp.parse(sc, arp.ns.len())
|
||||
if err = sc.Err(); err != nil {
|
||||
// TODO(e.burkov): This error seems unreachable. Investigate.
|
||||
return fmt.Errorf("scanning the output: %w", err)
|
||||
}
|
||||
|
||||
@@ -187,8 +161,7 @@ func (arp *cmdARPDB) Neighbors() (ns []Neighbor) {
|
||||
type arpdbs struct {
|
||||
// arps is the set of ARPDB implementations to range through.
|
||||
arps []ARPDB
|
||||
// last is the last succeeded ARPDB index.
|
||||
last int
|
||||
neighs
|
||||
}
|
||||
|
||||
// newARPDBs returns a properly initialized *arpdbs. It begins refreshing from
|
||||
@@ -196,7 +169,10 @@ type arpdbs struct {
|
||||
func newARPDBs(arps ...ARPDB) (arp *arpdbs) {
|
||||
return &arpdbs{
|
||||
arps: arps,
|
||||
last: 0,
|
||||
neighs: neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,20 +182,18 @@ var _ ARPDB = (*arpdbs)(nil)
|
||||
// Refresh implements the ARPDB interface for *arpdbs.
|
||||
func (arp *arpdbs) Refresh() (err error) {
|
||||
var errs []error
|
||||
l := len(arp.arps)
|
||||
// Start from the last succeeded implementation.
|
||||
for i := 0; i < l; i++ {
|
||||
cur := (arp.last + i) % l
|
||||
err = arp.arps[cur].Refresh()
|
||||
if err == nil {
|
||||
// The succeeded implementation found so update the last succeeded
|
||||
// index.
|
||||
arp.last = cur
|
||||
|
||||
return nil
|
||||
for _, a := range arp.arps {
|
||||
err = a.Refresh()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
errs = append(errs, err)
|
||||
arp.reset(a.Neighbors())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
@@ -230,10 +204,8 @@ func (arp *arpdbs) Refresh() (err error) {
|
||||
}
|
||||
|
||||
// Neighbors implements the ARPDB interface for *arpdbs.
|
||||
//
|
||||
// TODO(e.burkov): Think of a way to avoid cloning the slice twice.
|
||||
func (arp *arpdbs) Neighbors() (ns []Neighbor) {
|
||||
if l := len(arp.arps); l > 0 && arp.last < l {
|
||||
return arp.arps[arp.last].Neighbors()
|
||||
}
|
||||
|
||||
return nil
|
||||
return arp.clone()
|
||||
}
|
||||
|
||||
@@ -13,18 +13,24 @@ import (
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
func newARPDB() *cmdARPDB {
|
||||
func newARPDB() (arp *cmdARPDB) {
|
||||
return &cmdARPDB{
|
||||
parse: parseArpA,
|
||||
runcmd: rcArpA,
|
||||
parse: parseArpA,
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
},
|
||||
cmd: "arp",
|
||||
// Use -n flag to avoid resolving the hostnames of the neighbors. By
|
||||
// default ARP attempts to resolve the hostnames via DNS. See man 8
|
||||
// arp.
|
||||
//
|
||||
// See also https://github.com/AdguardTeam/AdGuardHome/issues/3157.
|
||||
args: []string{"-a", "-n"},
|
||||
}
|
||||
}
|
||||
|
||||
// parseArpA parses the output of the "arp -a" command on macOS and FreeBSD.
|
||||
// parseArpA parses the output of the "arp -a -n" command on macOS and FreeBSD.
|
||||
// The expected input format:
|
||||
//
|
||||
// host.name (192.168.0.1) at ff:ff:ff:ff:ff:ff on en0 ifscope [ethernet]
|
||||
|
||||
@@ -8,8 +8,12 @@ import (
|
||||
)
|
||||
|
||||
const arpAOutput = `
|
||||
invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet]
|
||||
invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet]
|
||||
invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet]
|
||||
hostname.one (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet]
|
||||
hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [ethernet]
|
||||
? (::1234) at aa:bb:cc:dd:ee:ff on ej0 expires in 1918 seconds [ethernet]
|
||||
`
|
||||
|
||||
var wantNeighs = []Neighbor{{
|
||||
@@ -20,4 +24,8 @@ var wantNeighs = []Neighbor{{
|
||||
Name: "hostname.two",
|
||||
IP: net.ParseIP("::ffff:ffff"),
|
||||
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
|
||||
}, {
|
||||
Name: "",
|
||||
IP: net.ParseIP("::1234"),
|
||||
MAC: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
|
||||
}}
|
||||
|
||||
@@ -6,7 +6,6 @@ package aghnet
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"strings"
|
||||
@@ -34,11 +33,30 @@ func newARPDB() (arp *arpdbs) {
|
||||
|
||||
return newARPDBs(
|
||||
// Try /proc/net/arp first.
|
||||
&fsysARPDB{ns: ns, fsys: aghos.RootDirFS(), filename: "proc/net/arp"},
|
||||
// Try "arp -a" then.
|
||||
&cmdARPDB{parse: parseF, runcmd: rcArpA, ns: ns},
|
||||
// Try "ip neigh" finally.
|
||||
&cmdARPDB{parse: parseIPNeigh, runcmd: rcIPNeigh, ns: ns},
|
||||
&fsysARPDB{
|
||||
ns: ns,
|
||||
fsys: rootDirFS,
|
||||
filename: "proc/net/arp",
|
||||
},
|
||||
// Then, try "arp -a -n".
|
||||
&cmdARPDB{
|
||||
parse: parseF,
|
||||
ns: ns,
|
||||
cmd: "arp",
|
||||
// Use -n flag to avoid resolving the hostnames of the neighbors.
|
||||
// By default ARP attempts to resolve the hostnames via DNS. See
|
||||
// man 8 arp.
|
||||
//
|
||||
// See also https://github.com/AdguardTeam/AdGuardHome/issues/3157.
|
||||
args: []string{"-a", "-n"},
|
||||
},
|
||||
// Finally, try "ip neigh".
|
||||
&cmdARPDB{
|
||||
parse: parseIPNeigh,
|
||||
ns: ns,
|
||||
cmd: "ip",
|
||||
args: []string{"neigh"},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -96,11 +114,11 @@ func (arp *fsysARPDB) Neighbors() (ns []Neighbor) {
|
||||
return arp.ns.clone()
|
||||
}
|
||||
|
||||
// parseArpAWrt parses the output of the "arp -a" command on OpenWrt. The
|
||||
// parseArpAWrt parses the output of the "arp -a -n" command on OpenWrt. The
|
||||
// expected input format:
|
||||
//
|
||||
// IP address HW type Flags HW address Mask Device
|
||||
// 192.168.11.98 0x1 0x2 5a:92:df:a9:7e:28 * wan
|
||||
// IP address HW type Flags HW address Mask Device
|
||||
// 192.168.11.98 0x1 0x2 5a:92:df:a9:7e:28 * wan
|
||||
//
|
||||
func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
if !sc.Scan() {
|
||||
@@ -140,8 +158,8 @@ func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
return ns
|
||||
}
|
||||
|
||||
// parseArpA parses the output of the "arp -a" command on Linux. The expected
|
||||
// input format:
|
||||
// parseArpA parses the output of the "arp -a -n" command on Linux. The
|
||||
// expected input format:
|
||||
//
|
||||
// hostname (192.168.1.1) at ab:cd:ef:ab:cd:ef [ether] on enp0s3
|
||||
//
|
||||
@@ -187,11 +205,6 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
return ns
|
||||
}
|
||||
|
||||
// rcIPNeigh runs "ip neigh".
|
||||
func rcIPNeigh() (r io.Reader, err error) {
|
||||
return runCmd("ip", "neigh")
|
||||
}
|
||||
|
||||
// parseIPNeigh parses the output of the "ip neigh" command on Linux. The
|
||||
// expected input format:
|
||||
//
|
||||
|
||||
@@ -4,11 +4,10 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -16,14 +15,21 @@ import (
|
||||
|
||||
const arpAOutputWrt = `
|
||||
IP address HW type Flags HW address Mask Device
|
||||
1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan
|
||||
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan
|
||||
192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan
|
||||
::ffff:ffff 0x1 0x2 ef:cd:ab:ef:cd:ab * wan`
|
||||
|
||||
const arpAOutput = `
|
||||
invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet]
|
||||
invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet]
|
||||
invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet]
|
||||
? (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet]
|
||||
? (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 100 seconds [ethernet]`
|
||||
|
||||
const ipNeighOutput = `
|
||||
1.2.3.4.5 dev enp0s3 lladdr aa:bb:cc:dd:ee:ff DELAY
|
||||
1.2.3.4 dev enp0s3 lladdr 12:34:56:78:910 DELAY
|
||||
192.168.1.2 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef DELAY
|
||||
::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE`
|
||||
|
||||
@@ -36,6 +42,8 @@ var wantNeighs = []Neighbor{{
|
||||
}}
|
||||
|
||||
func TestFSysARPDB(t *testing.T) {
|
||||
require.NoError(t, fstest.TestFS(testdata, "proc_net_arp"))
|
||||
|
||||
a := &fsysARPDB{
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
@@ -52,33 +60,43 @@ func TestFSysARPDB(t *testing.T) {
|
||||
assert.Equal(t, wantNeighs, ns)
|
||||
}
|
||||
|
||||
func TestCmdARPDB_arpawrt(t *testing.T) {
|
||||
a := &cmdARPDB{
|
||||
parse: parseArpAWrt,
|
||||
runcmd: func() (r io.Reader, err error) { return strings.NewReader(arpAOutputWrt), nil },
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
},
|
||||
func TestCmdARPDB_linux(t *testing.T) {
|
||||
sh := mapShell{
|
||||
"arp -a": {err: nil, out: arpAOutputWrt, code: 0},
|
||||
"ip neigh": {err: nil, out: ipNeighOutput, code: 0},
|
||||
}
|
||||
substShell(t, sh.RunCmd)
|
||||
|
||||
err := a.Refresh()
|
||||
require.NoError(t, err)
|
||||
t.Run("wrt", func(t *testing.T) {
|
||||
a := &cmdARPDB{
|
||||
parse: parseArpAWrt,
|
||||
cmd: "arp",
|
||||
args: []string{"-a"},
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, wantNeighs, a.Neighbors())
|
||||
}
|
||||
|
||||
func TestCmdARPDB_ipneigh(t *testing.T) {
|
||||
a := &cmdARPDB{
|
||||
parse: parseIPNeigh,
|
||||
runcmd: func() (r io.Reader, err error) { return strings.NewReader(ipNeighOutput), nil },
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
},
|
||||
}
|
||||
err := a.Refresh()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantNeighs, a.Neighbors())
|
||||
err := a.Refresh()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantNeighs, a.Neighbors())
|
||||
})
|
||||
|
||||
t.Run("ip_neigh", func(t *testing.T) {
|
||||
a := &cmdARPDB{
|
||||
parse: parseIPNeigh,
|
||||
cmd: "ip",
|
||||
args: []string{"neigh"},
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
},
|
||||
}
|
||||
err := a.Refresh()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantNeighs, a.Neighbors())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,19 +12,25 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
func newARPDB() *cmdARPDB {
|
||||
func newARPDB() (arp *cmdARPDB) {
|
||||
return &cmdARPDB{
|
||||
runcmd: rcArpA,
|
||||
parse: parseArpA,
|
||||
parse: parseArpA,
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
},
|
||||
cmd: "arp",
|
||||
// Use -n flag to avoid resolving the hostnames of the neighbors. By
|
||||
// default ARP attempts to resolve the hostnames via DNS. See man 8
|
||||
// arp.
|
||||
//
|
||||
// See also https://github.com/AdguardTeam/AdGuardHome/issues/3157.
|
||||
args: []string{"-a", "-n"},
|
||||
}
|
||||
}
|
||||
|
||||
// parseArpA parses the output of the "arp -a" command on OpenBSD. The expected
|
||||
// input format:
|
||||
// parseArpA parses the output of the "arp -a -n" command on OpenBSD. The
|
||||
// expected input format:
|
||||
//
|
||||
// Host Ethernet Address Netif Expire Flags
|
||||
// 192.168.1.1 ab:cd:ef:ab:cd:ef em0 19m59s
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
|
||||
const arpAOutput = `
|
||||
Host Ethernet Address Netif Expire Flags
|
||||
1.2.3.4.5 aa:bb:cc:dd:ee:ff em0 permanent
|
||||
1.2.3.4 12:34:56:78:910 em0 permanent
|
||||
192.168.1.2 ab:cd:ef:ab:cd:ef em0 19m56s
|
||||
::ffff:ffff ef:cd:ab:ef:cd:ab em0 permanent l
|
||||
`
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -13,6 +11,13 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewARPDB(t *testing.T) {
|
||||
var a ARPDB
|
||||
require.NotPanics(t, func() { a = NewARPDB() })
|
||||
|
||||
assert.NotNil(t, a)
|
||||
}
|
||||
|
||||
// TestARPDB is the mock implementation of ARPDB to use in tests.
|
||||
type TestARPDB struct {
|
||||
OnRefresh func() (err error)
|
||||
@@ -125,11 +130,11 @@ func TestARPDBS(t *testing.T) {
|
||||
assert.Equal(t, 1, succRefrCount)
|
||||
assert.NotEmpty(t, a.Neighbors())
|
||||
|
||||
// Only the last succeeded ARPDB should be used.
|
||||
// Unstable ARPDB should refresh successfully again.
|
||||
err = a.Refresh()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, succRefrCount)
|
||||
assert.Equal(t, 1, succRefrCount)
|
||||
assert.NotEmpty(t, a.Neighbors())
|
||||
})
|
||||
|
||||
@@ -143,6 +148,7 @@ func TestARPDBS(t *testing.T) {
|
||||
|
||||
func TestCmdARPDB_arpa(t *testing.T) {
|
||||
a := &cmdARPDB{
|
||||
cmd: "cmd",
|
||||
parse: parseArpA,
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
@@ -151,7 +157,8 @@ func TestCmdARPDB_arpa(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("arp_a", func(t *testing.T) {
|
||||
a.runcmd = func() (r io.Reader, err error) { return strings.NewReader(arpAOutput), nil }
|
||||
sh := theOnlyCmd("cmd", 0, arpAOutput, nil)
|
||||
substShell(t, sh.RunCmd)
|
||||
|
||||
err := a.Refresh()
|
||||
require.NoError(t, err)
|
||||
@@ -160,9 +167,50 @@ func TestCmdARPDB_arpa(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("runcmd_error", func(t *testing.T) {
|
||||
a.runcmd = func() (r io.Reader, err error) { return nil, errors.Error("can't run") }
|
||||
sh := theOnlyCmd("cmd", 0, "", errors.Error("can't run"))
|
||||
substShell(t, sh.RunCmd)
|
||||
|
||||
err := a.Refresh()
|
||||
testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err)
|
||||
})
|
||||
|
||||
t.Run("bad_code", func(t *testing.T) {
|
||||
sh := theOnlyCmd("cmd", 1, "", nil)
|
||||
substShell(t, sh.RunCmd)
|
||||
|
||||
err := a.Refresh()
|
||||
testutil.AssertErrorMsg(t, "cmd arpdb: running command: unexpected exit code 1", err)
|
||||
})
|
||||
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
sh := theOnlyCmd("cmd", 0, "", nil)
|
||||
substShell(t, sh.RunCmd)
|
||||
|
||||
err := a.Refresh()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, a.Neighbors())
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmptyARPDB(t *testing.T) {
|
||||
a := EmptyARPDB{}
|
||||
|
||||
t.Run("refresh", func(t *testing.T) {
|
||||
var err error
|
||||
require.NotPanics(t, func() {
|
||||
err = a.Refresh()
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("neighbors", func(t *testing.T) {
|
||||
var ns []Neighbor
|
||||
require.NotPanics(t, func() {
|
||||
ns = a.Neighbors()
|
||||
})
|
||||
|
||||
assert.Empty(t, ns)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// rcArpA runs "arp -a".
|
||||
func rcArpA() (r io.Reader, err error) {
|
||||
return runCmd("arp", "-a")
|
||||
}
|
||||
@@ -5,28 +5,23 @@ package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func newARPDB() *cmdARPDB {
|
||||
func newARPDB() (arp *cmdARPDB) {
|
||||
return &cmdARPDB{
|
||||
runcmd: rcArpA,
|
||||
parse: parseArpA,
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
},
|
||||
parse: parseArpA,
|
||||
cmd: "arp",
|
||||
args: []string{"/a"},
|
||||
}
|
||||
}
|
||||
|
||||
// rcArpA runs "arp /a".
|
||||
func rcArpA() (r io.Reader, err error) {
|
||||
return runCmd("arp", "/a")
|
||||
}
|
||||
|
||||
// parseArpA parses the output of the "arp /a" command on Windows. The expected
|
||||
// input format (the first line is empty):
|
||||
//
|
||||
|
||||
@@ -156,7 +156,7 @@ func tryConn4(req *dhcpv4.DHCPv4, c net.PacketConn, iface *net.Interface) (ok, n
|
||||
b := make([]byte, 1500)
|
||||
n, _, err := c.ReadFrom(b)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
log.Debug("dhcpv4: didn't receive dhcp response")
|
||||
|
||||
return false, false, nil
|
||||
@@ -176,20 +176,21 @@ func tryConn4(req *dhcpv4.DHCPv4, c net.PacketConn, iface *net.Interface) (ok, n
|
||||
|
||||
log.Debug("dhcpv4: received message from server: %s", response.Summary())
|
||||
|
||||
if !(response.OpCode == dhcpv4.OpcodeBootReply &&
|
||||
response.HWType == iana.HWTypeEthernet &&
|
||||
bytes.Equal(response.ClientHWAddr, iface.HardwareAddr) &&
|
||||
bytes.Equal(response.TransactionID[:], req.TransactionID[:]) &&
|
||||
response.Options.Has(dhcpv4.OptionDHCPMessageType)) {
|
||||
|
||||
log.Debug("dhcpv4: received message from server doesn't match our request")
|
||||
switch {
|
||||
case
|
||||
response.OpCode != dhcpv4.OpcodeBootReply,
|
||||
response.HWType != iana.HWTypeEthernet,
|
||||
!bytes.Equal(response.ClientHWAddr, iface.HardwareAddr),
|
||||
response.TransactionID != req.TransactionID,
|
||||
!response.Options.Has(dhcpv4.OptionDHCPMessageType):
|
||||
log.Debug("dhcpv4: received response doesn't match the request")
|
||||
|
||||
return false, true, nil
|
||||
default:
|
||||
log.Tracef("dhcpv4: the packet is from an active dhcp server")
|
||||
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
log.Tracef("dhcpv4: the packet is from an active dhcp server")
|
||||
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
// checkOtherDHCPv6 sends a DHCP request to the specified network interface, and
|
||||
@@ -275,7 +276,7 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error)
|
||||
|
||||
n, _, err := c.ReadFrom(b)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
log.Debug("dhcpv6: didn't receive dhcp response")
|
||||
|
||||
return false, false, nil
|
||||
@@ -318,15 +319,3 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error)
|
||||
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
// isTimeout returns true if err is an operation timeout error from net package.
|
||||
//
|
||||
// TODO(e.burkov): Consider moving into netutil.
|
||||
func isTimeout(err error) (ok bool) {
|
||||
var operr *net.OpError
|
||||
if errors.As(err, &operr) {
|
||||
return operr.Timeout()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ const (
|
||||
ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010")
|
||||
)
|
||||
|
||||
// generateIPv4Hostname generates the hostname for specific IP version.
|
||||
// generateIPv4Hostname generates the hostname by IP address version 4.
|
||||
func generateIPv4Hostname(ipv4 net.IP) (hostname string) {
|
||||
hnData := make([]byte, 0, ipv4HostnameMaxLen)
|
||||
for i, part := range ipv4 {
|
||||
@@ -24,7 +24,7 @@ func generateIPv4Hostname(ipv4 net.IP) (hostname string) {
|
||||
return string(hnData)
|
||||
}
|
||||
|
||||
// generateIPv6Hostname generates the hostname for specific IP version.
|
||||
// generateIPv6Hostname generates the hostname by IP address version 6.
|
||||
func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
|
||||
hnData := make([]byte, 0, ipv6HostnameMaxLen)
|
||||
for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ {
|
||||
@@ -51,12 +51,11 @@ func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
|
||||
//
|
||||
// ff80-f076-0000-0000-0000-0000-0000-0010
|
||||
//
|
||||
// ip must be either an IPv4 or an IPv6.
|
||||
func GenerateHostname(ip net.IP) (hostname string) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return generateIPv4Hostname(ipv4)
|
||||
} else if ipv6 := ip.To16(); ipv6 != nil {
|
||||
return generateIPv6Hostname(ipv6)
|
||||
}
|
||||
|
||||
return ""
|
||||
return generateIPv6Hostname(ip)
|
||||
}
|
||||
|
||||
@@ -8,41 +8,57 @@ import (
|
||||
)
|
||||
|
||||
func TestGenerateHostName(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
ip net.IP
|
||||
}{{
|
||||
name: "good_ipv4",
|
||||
want: "127-0-0-1",
|
||||
ip: net.IP{127, 0, 0, 1},
|
||||
}, {
|
||||
name: "bad_ipv4",
|
||||
want: "",
|
||||
ip: net.IP{127, 0, 0, 1, 0},
|
||||
}, {
|
||||
name: "good_ipv6",
|
||||
want: "fe00-0000-0000-0000-0000-0000-0000-0001",
|
||||
ip: net.ParseIP("fe00::1"),
|
||||
}, {
|
||||
name: "bad_ipv6",
|
||||
want: "",
|
||||
ip: net.IP{
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff,
|
||||
},
|
||||
}, {
|
||||
name: "nil",
|
||||
want: "",
|
||||
ip: nil,
|
||||
}}
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
ip net.IP
|
||||
}{{
|
||||
name: "good_ipv4",
|
||||
want: "127-0-0-1",
|
||||
ip: net.IP{127, 0, 0, 1},
|
||||
}, {
|
||||
name: "good_ipv6",
|
||||
want: "fe00-0000-0000-0000-0000-0000-0000-0001",
|
||||
ip: net.ParseIP("fe00::1"),
|
||||
}, {
|
||||
name: "4to6",
|
||||
want: "1-2-3-4",
|
||||
ip: net.ParseIP("::ffff:1.2.3.4"),
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
hostname := GenerateHostname(tc.ip)
|
||||
assert.Equal(t, tc.want, hostname)
|
||||
})
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
hostname := GenerateHostname(tc.ip)
|
||||
assert.Equal(t, tc.want, hostname)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
ip net.IP
|
||||
}{{
|
||||
name: "bad_ipv4",
|
||||
ip: net.IP{127, 0, 0, 1, 0},
|
||||
}, {
|
||||
name: "bad_ipv6",
|
||||
ip: net.IP{
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff,
|
||||
},
|
||||
}, {
|
||||
name: "nil",
|
||||
ip: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Panics(t, func() { GenerateHostname(tc.ip) })
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ type requestMatcher struct {
|
||||
//
|
||||
// It's safe for concurrent use.
|
||||
func (rm *requestMatcher) MatchRequest(
|
||||
req urlfilter.DNSRequest,
|
||||
req *urlfilter.DNSRequest,
|
||||
) (res *urlfilter.DNSResult, ok bool) {
|
||||
switch req.DNSType {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
|
||||
@@ -198,7 +198,7 @@ func (hc *HostsContainer) Close() (err error) {
|
||||
}
|
||||
|
||||
// Upd returns the channel into which the updates are sent. The receivable
|
||||
// map's values are guaranteed to be of type of *stringutil.Set.
|
||||
// map's values are guaranteed to be of type of *HostsRecord.
|
||||
func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) {
|
||||
return hc.updates
|
||||
}
|
||||
@@ -290,7 +290,7 @@ func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err
|
||||
continue
|
||||
}
|
||||
|
||||
hp.addPairs(ip, hosts)
|
||||
hp.addRecord(ip, hosts)
|
||||
}
|
||||
|
||||
return nil, true, s.Err()
|
||||
@@ -335,41 +335,68 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
|
||||
return ip, hosts
|
||||
}
|
||||
|
||||
// addPair puts the pair of ip and host to the rules builder if needed. For
|
||||
// each ip the first member of hosts will become the main one.
|
||||
func (hp *hostsParser) addPairs(ip net.IP, hosts []string) {
|
||||
// HostsRecord represents a single hosts file record.
|
||||
type HostsRecord struct {
|
||||
Aliases *stringutil.Set
|
||||
Canonical string
|
||||
}
|
||||
|
||||
// Equal returns true if all fields of rec are equal to field in other or they
|
||||
// both are nil.
|
||||
func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) {
|
||||
if rec == nil {
|
||||
return other == nil
|
||||
}
|
||||
|
||||
return rec.Canonical == other.Canonical && rec.Aliases.Equal(other.Aliases)
|
||||
}
|
||||
|
||||
// addRecord puts the record for the IP address to the rules builder if needed.
|
||||
// The first host is considered to be the canonical name for the IP address.
|
||||
// hosts must have at least one name.
|
||||
func (hp *hostsParser) addRecord(ip net.IP, hosts []string) {
|
||||
line := strings.Join(append([]string{ip.String()}, hosts...), " ")
|
||||
|
||||
var rec *HostsRecord
|
||||
v, ok := hp.table.Get(ip)
|
||||
if !ok {
|
||||
// This ip is added at the first time.
|
||||
v = stringutil.NewSet()
|
||||
hp.table.Set(ip, v)
|
||||
rec = &HostsRecord{
|
||||
Aliases: stringutil.NewSet(),
|
||||
}
|
||||
|
||||
rec.Canonical, hosts = hosts[0], hosts[1:]
|
||||
hp.addRules(ip, rec.Canonical, line)
|
||||
hp.table.Set(ip, rec)
|
||||
} else {
|
||||
rec, ok = v.(*HostsRecord)
|
||||
if !ok {
|
||||
log.Error("%s: adding pairs: unexpected type %T", hostsContainerPref, v)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var set *stringutil.Set
|
||||
set, ok = v.(*stringutil.Set)
|
||||
if !ok {
|
||||
log.Debug("%s: adding pairs: unexpected value type %T", hostsContainerPref, v)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
processed := strings.Join(append([]string{ip.String()}, hosts...), " ")
|
||||
for _, h := range hosts {
|
||||
if set.Has(h) {
|
||||
for _, host := range hosts {
|
||||
if rec.Canonical == host || rec.Aliases.Has(host) {
|
||||
continue
|
||||
}
|
||||
|
||||
set.Add(h)
|
||||
rec.Aliases.Add(host)
|
||||
|
||||
rule, rulePtr := hp.writeRules(h, ip)
|
||||
hp.translations[rule], hp.translations[rulePtr] = processed, processed
|
||||
|
||||
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, h)
|
||||
hp.addRules(ip, host, line)
|
||||
}
|
||||
}
|
||||
|
||||
// writeRules writes the actual rule for the qtype and the PTR for the
|
||||
// host-ip pair into internal builders.
|
||||
// addRules adds rules and rule translations for the line.
|
||||
func (hp *hostsParser) addRules(ip net.IP, host, line string) {
|
||||
rule, rulePtr := hp.writeRules(host, ip)
|
||||
hp.translations[rule], hp.translations[rulePtr] = line, line
|
||||
|
||||
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, host)
|
||||
}
|
||||
|
||||
// writeRules writes the actual rule for the qtype and the PTR for the host-ip
|
||||
// pair into internal builders.
|
||||
func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) {
|
||||
arpa, err := netutil.IPToReversedAddr(ip)
|
||||
if err != nil {
|
||||
@@ -417,6 +444,7 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string)
|
||||
}
|
||||
|
||||
// equalSet returns true if the internal hosts table just parsed equals target.
|
||||
// target's values must be of type *HostsRecord.
|
||||
func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) {
|
||||
if target == nil {
|
||||
// hp.table shouldn't appear nil since it's initialized on each refresh.
|
||||
@@ -427,22 +455,35 @@ func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
hp.table.Range(func(ip net.IP, b interface{}) (cont bool) {
|
||||
// ok is set to true if the target doesn't contain ip or if the
|
||||
// appropriate hosts set isn't equal to the checked one.
|
||||
if a, hasIP := target.Get(ip); !hasIP {
|
||||
ok = true
|
||||
} else if hosts, aok := a.(*stringutil.Set); aok {
|
||||
ok = !hosts.Equal(b.(*stringutil.Set))
|
||||
hp.table.Range(func(ip net.IP, recVal any) (cont bool) {
|
||||
var targetVal any
|
||||
targetVal, ok = target.Get(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Continue only if maps has no discrepancies.
|
||||
return !ok
|
||||
var rec *HostsRecord
|
||||
rec, ok = recVal.(*HostsRecord)
|
||||
if !ok {
|
||||
log.Error("%s: comparing: unexpected type %T", hostsContainerPref, recVal)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
var targetRec *HostsRecord
|
||||
targetRec, ok = targetVal.(*HostsRecord)
|
||||
if !ok {
|
||||
log.Error("%s: comparing: target: unexpected type %T", hostsContainerPref, targetVal)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
ok = rec.Equal(targetRec)
|
||||
|
||||
return ok
|
||||
})
|
||||
|
||||
// Return true if every value from the IP map has no discrepancies with the
|
||||
// appropriate one from the target.
|
||||
return !ok
|
||||
return ok
|
||||
}
|
||||
|
||||
// sendUpd tries to send the parsed data to the ch.
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
@@ -159,31 +160,47 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||
|
||||
checkRefresh := func(t *testing.T, wantHosts *stringutil.Set) {
|
||||
upd, ok := <-hc.Upd()
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upd)
|
||||
checkRefresh := func(t *testing.T, want *HostsRecord) {
|
||||
t.Helper()
|
||||
|
||||
var ok bool
|
||||
var upd *netutil.IPMap
|
||||
select {
|
||||
case upd, ok = <-hc.Upd():
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upd)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("did not receive after 1s")
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, upd.Len())
|
||||
|
||||
v, ok := upd.Get(ip)
|
||||
require.True(t, ok)
|
||||
|
||||
var set *stringutil.Set
|
||||
set, ok = v.(*stringutil.Set)
|
||||
require.True(t, ok)
|
||||
require.IsType(t, (*HostsRecord)(nil), v)
|
||||
|
||||
assert.True(t, set.Equal(wantHosts))
|
||||
rec, _ := v.(*HostsRecord)
|
||||
require.NotNil(t, rec)
|
||||
|
||||
assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want)
|
||||
}
|
||||
|
||||
t.Run("initial_refresh", func(t *testing.T) {
|
||||
checkRefresh(t, stringutil.NewSet("hostname"))
|
||||
checkRefresh(t, &HostsRecord{
|
||||
Aliases: stringutil.NewSet(),
|
||||
Canonical: "hostname",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("second_refresh", func(t *testing.T) {
|
||||
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
|
||||
eventsCh <- event{}
|
||||
checkRefresh(t, stringutil.NewSet("hostname", "alias"))
|
||||
|
||||
checkRefresh(t, &HostsRecord{
|
||||
Aliases: stringutil.NewSet("alias"),
|
||||
Canonical: "hostname",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("double_refresh", func(t *testing.T) {
|
||||
@@ -193,7 +210,7 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||
|
||||
// Require the changes are written.
|
||||
require.Eventually(t, func() bool {
|
||||
res, ok := hc.MatchRequest(urlfilter.DNSRequest{
|
||||
res, ok := hc.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: "hostname",
|
||||
DNSType: dns.TypeA,
|
||||
})
|
||||
@@ -207,7 +224,7 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||
|
||||
// Require the changes are written.
|
||||
require.Eventually(t, func() bool {
|
||||
res, ok := hc.MatchRequest(urlfilter.DNSRequest{
|
||||
res, ok := hc.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: "hostname",
|
||||
DNSType: dns.TypeA,
|
||||
})
|
||||
@@ -286,6 +303,8 @@ func TestHostsContainer_Translate(t *testing.T) {
|
||||
OnClose: func() (err error) { panic("not implemented") },
|
||||
}
|
||||
|
||||
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
|
||||
|
||||
hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||
@@ -358,11 +377,18 @@ func TestHostsContainer_Translate(t *testing.T) {
|
||||
func TestHostsContainer(t *testing.T) {
|
||||
const listID = 1234
|
||||
|
||||
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
|
||||
|
||||
testCases := []struct {
|
||||
want []*rules.DNSRewrite
|
||||
req *urlfilter.DNSRequest
|
||||
name string
|
||||
req urlfilter.DNSRequest
|
||||
want []*rules.DNSRewrite
|
||||
}{{
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "simplehost",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "simple",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 1),
|
||||
@@ -372,27 +398,12 @@ func TestHostsContainer(t *testing.T) {
|
||||
Value: net.ParseIP("::1"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "simple",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "simplehost",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 0),
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.ParseIP("::"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "hello_alias",
|
||||
req: urlfilter.DNSRequest{
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "hello.world",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
name: "hello_alias",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 0),
|
||||
@@ -402,26 +413,41 @@ func TestHostsContainer(t *testing.T) {
|
||||
Value: net.ParseIP("::"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "other_line_alias",
|
||||
req: urlfilter.DNSRequest{
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "hello.world.again",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "other_line_alias",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 0),
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.ParseIP("::"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "hello_subdomain",
|
||||
req: urlfilter.DNSRequest{
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "say.hello",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
name: "hello_subdomain",
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "hello_alias_subdomain",
|
||||
req: urlfilter.DNSRequest{
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "say.hello.world",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "hello_alias_subdomain",
|
||||
want: []*rules.DNSRewrite{},
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "for.testing",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "lots_of_aliases",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
@@ -431,37 +457,37 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::2"),
|
||||
}},
|
||||
name: "lots_of_aliases",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "for.testing",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "1.0.0.1.in-addr.arpa",
|
||||
DNSType: dns.TypePTR,
|
||||
},
|
||||
name: "reverse",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypePTR,
|
||||
Value: "simplehost.",
|
||||
}},
|
||||
name: "reverse",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "1.0.0.1.in-addr.arpa",
|
||||
DNSType: dns.TypePTR,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "non-existing",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "nonexisting",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "nonexistent.example",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "non-existing",
|
||||
want: []*rules.DNSRewrite{},
|
||||
}, {
|
||||
want: nil,
|
||||
name: "bad_type",
|
||||
req: urlfilter.DNSRequest{
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "1.0.0.1.in-addr.arpa",
|
||||
DNSType: dns.TypeSRV,
|
||||
},
|
||||
name: "bad_type",
|
||||
want: nil,
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "issue_4216_4_6",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
@@ -471,12 +497,12 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::42"),
|
||||
}},
|
||||
name: "issue_4216_4_6",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "domain",
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain4",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
name: "issue_4216_4",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
@@ -486,12 +512,12 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeA,
|
||||
Value: net.IPv4(1, 3, 5, 7),
|
||||
}},
|
||||
name: "issue_4216_4",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "domain4",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain6",
|
||||
DNSType: dns.TypeAAAA,
|
||||
},
|
||||
name: "issue_4216_6",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeAAAA,
|
||||
@@ -501,11 +527,6 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::31"),
|
||||
}},
|
||||
name: "issue_4216_6",
|
||||
req: urlfilter.DNSRequest{
|
||||
Hostname: "domain6",
|
||||
DNSType: dns.TypeAAAA,
|
||||
},
|
||||
}}
|
||||
|
||||
stubWatcher := aghtest.FSWatcher{
|
||||
|
||||
@@ -2,19 +2,31 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
// Variables and functions to substitute in tests.
|
||||
var (
|
||||
// aghosRunCommand is the function to run shell commands.
|
||||
aghosRunCommand = aghos.RunCommand
|
||||
|
||||
// netInterfaces is the function to get the available network interfaces.
|
||||
netInterfaceAddrs = net.InterfaceAddrs
|
||||
|
||||
// rootDirFS is the filesystem pointing to the root directory.
|
||||
rootDirFS = aghos.RootDirFS()
|
||||
)
|
||||
|
||||
// ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about
|
||||
// the IP being static is available.
|
||||
const ErrNoStaticIPInfo errors.Error = "no information about static ip"
|
||||
@@ -32,39 +44,29 @@ func IfaceSetStaticIP(ifaceName string) (err error) {
|
||||
}
|
||||
|
||||
// GatewayIP returns IP address of interface's gateway.
|
||||
func GatewayIP(ifaceName string) net.IP {
|
||||
cmd := exec.Command("ip", "route", "show", "dev", ifaceName)
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
d, err := cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
//
|
||||
// TODO(e.burkov): Investigate if the gateway address may be fetched in another
|
||||
// way since not every machine has the software installed.
|
||||
func GatewayIP(ifaceName string) (ip net.IP) {
|
||||
code, out, err := aghosRunCommand("ip", "route", "show", "dev", ifaceName)
|
||||
if err != nil {
|
||||
log.Debug("%s", err)
|
||||
|
||||
return nil
|
||||
} else if code != 0 {
|
||||
log.Debug("fetching gateway ip: unexpected exit code: %d", code)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
fields := strings.Fields(string(d))
|
||||
fields := bytes.Fields(out)
|
||||
// The meaningful "ip route" command output should contain the word
|
||||
// "default" at first field and default gateway IP address at third field.
|
||||
if len(fields) < 3 || fields[0] != "default" {
|
||||
if len(fields) < 3 || string(fields[0]) != "default" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return net.ParseIP(fields[2])
|
||||
}
|
||||
|
||||
// CanBindPort checks if we can bind to the given port.
|
||||
func CanBindPort(port int) (can bool, err error) {
|
||||
var addr *net.TCPAddr
|
||||
addr, err = net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var listener *net.TCPListener
|
||||
listener, err = net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_ = listener.Close()
|
||||
return true, nil
|
||||
return net.ParseIP(string(fields[2]))
|
||||
}
|
||||
|
||||
// CanBindPrivilegedPorts checks if current process can bind to privileged
|
||||
@@ -99,19 +101,19 @@ func (iface NetInterface) MarshalJSON() ([]byte, error) {
|
||||
})
|
||||
}
|
||||
|
||||
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and WEB only
|
||||
// we do not return link-local addresses here
|
||||
func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
|
||||
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and
|
||||
// WEB only we do not return link-local addresses here.
|
||||
//
|
||||
// TODO(e.burkov): Can't properly test the function since it's nontrivial to
|
||||
// substitute net.Interface.Addrs and the net.InterfaceAddrs can't be used.
|
||||
func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't get interfaces: %w", err)
|
||||
}
|
||||
if len(ifaces) == 0 {
|
||||
} else if len(ifaces) == 0 {
|
||||
return nil, errors.Error("couldn't find any legible interface")
|
||||
}
|
||||
|
||||
var netInterfaces []*NetInterface
|
||||
|
||||
for _, iface := range ifaces {
|
||||
var addrs []net.Addr
|
||||
addrs, err = iface.Addrs()
|
||||
@@ -131,27 +133,34 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if !ok {
|
||||
// Should be net.IPNet, this is weird.
|
||||
return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
|
||||
return nil, fmt.Errorf("got %s that is not net.IPNet, it is %T", addr, addr)
|
||||
}
|
||||
|
||||
// Ignore link-local.
|
||||
if ipNet.IP.IsLinkLocalUnicast() {
|
||||
continue
|
||||
}
|
||||
|
||||
netIface.Addresses = append(netIface.Addresses, ipNet.IP)
|
||||
netIface.Subnets = append(netIface.Subnets, ipNet)
|
||||
}
|
||||
|
||||
// Discard interfaces with no addresses.
|
||||
if len(netIface.Addresses) != 0 {
|
||||
netInterfaces = append(netInterfaces, netIface)
|
||||
netIfaces = append(netIfaces, netIface)
|
||||
}
|
||||
}
|
||||
|
||||
return netInterfaces, nil
|
||||
return netIfaces, nil
|
||||
}
|
||||
|
||||
// GetInterfaceByIP returns the name of interface containing provided ip.
|
||||
func GetInterfaceByIP(ip net.IP) string {
|
||||
// InterfaceByIP returns the name of the interface bound to ip.
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): This function is technically incorrect, since one
|
||||
// IP address can be shared by multiple interfaces in some configurations.
|
||||
//
|
||||
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
|
||||
func InterfaceByIP(ip net.IP) (ifaceName string) {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
return ""
|
||||
@@ -170,6 +179,8 @@ func GetInterfaceByIP(ip net.IP) string {
|
||||
|
||||
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if
|
||||
// the search fails.
|
||||
//
|
||||
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
|
||||
func GetSubnet(ifaceName string) *net.IPNet {
|
||||
netIfaces, err := GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
@@ -220,29 +231,21 @@ func IsAddrInUse(err error) (ok bool) {
|
||||
// CollectAllIfacesAddrs returns the slice of all network interfaces IP
|
||||
// addresses without port number.
|
||||
func CollectAllIfacesAddrs() (addrs []string, err error) {
|
||||
var ifaces []net.Interface
|
||||
ifaces, err = net.Interfaces()
|
||||
var ifaceAddrs []net.Addr
|
||||
ifaceAddrs, err = netInterfaceAddrs()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting network interfaces: %w", err)
|
||||
return nil, fmt.Errorf("getting interfaces addresses: %w", err)
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
var ifaceAddrs []net.Addr
|
||||
ifaceAddrs, err = iface.Addrs()
|
||||
for _, addr := range ifaceAddrs {
|
||||
cidr := addr.String()
|
||||
var ip net.IP
|
||||
ip, _, err = net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting addresses for %q: %w", iface.Name, err)
|
||||
return nil, fmt.Errorf("parsing cidr: %w", err)
|
||||
}
|
||||
|
||||
for _, addr := range ifaceAddrs {
|
||||
cidr := addr.String()
|
||||
var ip net.IP
|
||||
ip, _, err = net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing cidr: %w", err)
|
||||
}
|
||||
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
|
||||
10
internal/aghnet/net_bsd.go
Normal file
10
internal/aghnet/net_bsd.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build darwin || freebsd || openbsd
|
||||
// +build darwin freebsd openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
import "github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
|
||||
func canBindPrivilegedPorts() (can bool, err error) {
|
||||
return aghos.HaveAdminRights()
|
||||
}
|
||||
@@ -4,10 +4,11 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"io"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -23,11 +24,7 @@ type hardwarePortInfo struct {
|
||||
static bool
|
||||
}
|
||||
|
||||
func canBindPrivilegedPorts() (can bool, err error) {
|
||||
return aghos.HaveAdminRights()
|
||||
}
|
||||
|
||||
func ifaceHasStaticIP(ifaceName string) (bool, error) {
|
||||
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
||||
portInfo, err := getCurrentHardwarePortInfo(ifaceName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -36,9 +33,10 @@ func ifaceHasStaticIP(ifaceName string) (bool, error) {
|
||||
return portInfo.static, nil
|
||||
}
|
||||
|
||||
// getCurrentHardwarePortInfo gets information for the specified network interface.
|
||||
// getCurrentHardwarePortInfo gets information for the specified network
|
||||
// interface.
|
||||
func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
|
||||
// First of all we should find hardware port name
|
||||
// First of all we should find hardware port name.
|
||||
m := getNetworkSetupHardwareReports()
|
||||
hardwarePort, ok := m[ifaceName]
|
||||
if !ok {
|
||||
@@ -48,6 +46,10 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
|
||||
return getHardwarePortInfo(hardwarePort)
|
||||
}
|
||||
|
||||
// hardwareReportsReg is the regular expression matching the lines of
|
||||
// networksetup command output lines containing the interface information.
|
||||
var hardwareReportsReg = regexp.MustCompile("Hardware Port: (.*?)\nDevice: (.*?)\n")
|
||||
|
||||
// getNetworkSetupHardwareReports parses the output of the `networksetup
|
||||
// -listallhardwareports` command it returns a map where the key is the
|
||||
// interface name, and the value is the "hardware port" returns nil if it fails
|
||||
@@ -56,54 +58,44 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
|
||||
// TODO(e.burkov): There should be more proper approach than parsing the
|
||||
// command output. For example, see
|
||||
// https://developer.apple.com/documentation/systemconfiguration.
|
||||
func getNetworkSetupHardwareReports() map[string]string {
|
||||
_, out, err := aghos.RunCommand("networksetup", "-listallhardwareports")
|
||||
func getNetworkSetupHardwareReports() (reports map[string]string) {
|
||||
_, out, err := aghosRunCommand("networksetup", "-listallhardwareports")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
re, err := regexp.Compile("Hardware Port: (.*?)\nDevice: (.*?)\n")
|
||||
if err != nil {
|
||||
return nil
|
||||
reports = make(map[string]string)
|
||||
|
||||
matches := hardwareReportsReg.FindAllSubmatch(out, -1)
|
||||
for _, m := range matches {
|
||||
reports[string(m[2])] = string(m[1])
|
||||
}
|
||||
|
||||
m := make(map[string]string)
|
||||
|
||||
matches := re.FindAllStringSubmatch(out, -1)
|
||||
for i := range matches {
|
||||
port := matches[i][1]
|
||||
device := matches[i][2]
|
||||
m[device] = port
|
||||
}
|
||||
|
||||
return m
|
||||
return reports
|
||||
}
|
||||
|
||||
func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
|
||||
h := hardwarePortInfo{}
|
||||
// hardwarePortReg is the regular expression matching the lines of networksetup
|
||||
// command output lines containing the port information.
|
||||
var hardwarePortReg = regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n")
|
||||
|
||||
_, out, err := aghos.RunCommand("networksetup", "-getinfo", hardwarePort)
|
||||
func getHardwarePortInfo(hardwarePort string) (h hardwarePortInfo, err error) {
|
||||
_, out, err := aghosRunCommand("networksetup", "-getinfo", hardwarePort)
|
||||
if err != nil {
|
||||
return h, err
|
||||
}
|
||||
|
||||
re := regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n")
|
||||
|
||||
match := re.FindStringSubmatch(out)
|
||||
if len(match) == 0 {
|
||||
match := hardwarePortReg.FindSubmatch(out)
|
||||
if len(match) != 4 {
|
||||
return h, errors.Error("could not find hardware port info")
|
||||
}
|
||||
|
||||
h.name = hardwarePort
|
||||
h.ip = match[1]
|
||||
h.subnet = match[2]
|
||||
h.gatewayIP = match[3]
|
||||
|
||||
if strings.Index(out, "Manual Configuration") == 0 {
|
||||
h.static = true
|
||||
}
|
||||
|
||||
return h, nil
|
||||
return hardwarePortInfo{
|
||||
name: hardwarePort,
|
||||
ip: string(match[1]),
|
||||
subnet: string(match[2]),
|
||||
gatewayIP: string(match[3]),
|
||||
static: bytes.Index(out, []byte("Manual Configuration")) == 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ifaceSetStaticIP(ifaceName string) (err error) {
|
||||
@@ -113,7 +105,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
|
||||
}
|
||||
|
||||
if portInfo.static {
|
||||
return errors.Error("IP address is already static")
|
||||
return errors.Error("ip address is already static")
|
||||
}
|
||||
|
||||
dnsAddrs, err := getEtcResolvConfServers()
|
||||
@@ -121,50 +113,62 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
args := make([]string, 0)
|
||||
args = append(args, "-setdnsservers", portInfo.name)
|
||||
args = append(args, dnsAddrs...)
|
||||
args := append([]string{"-setdnsservers", portInfo.name}, dnsAddrs...)
|
||||
|
||||
// Setting DNS servers is necessary when configuring a static IP
|
||||
code, _, err := aghos.RunCommand("networksetup", args...)
|
||||
code, _, err := aghosRunCommand("networksetup", args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if code != 0 {
|
||||
} else if code != 0 {
|
||||
return fmt.Errorf("failed to set DNS servers, code=%d", code)
|
||||
}
|
||||
|
||||
// Actually configures hardware port to have static IP
|
||||
code, _, err = aghos.RunCommand("networksetup", "-setmanual",
|
||||
portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP)
|
||||
code, _, err = aghosRunCommand(
|
||||
"networksetup",
|
||||
"-setmanual",
|
||||
portInfo.name,
|
||||
portInfo.ip,
|
||||
portInfo.subnet,
|
||||
portInfo.gatewayIP,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if code != 0 {
|
||||
} else if code != 0 {
|
||||
return fmt.Errorf("failed to set DNS servers, code=%d", code)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// etcResolvConfReg is the regular expression matching the lines of resolv.conf
|
||||
// file containing a name server information.
|
||||
var etcResolvConfReg = regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)")
|
||||
|
||||
// getEtcResolvConfServers returns a list of nameservers configured in
|
||||
// /etc/resolv.conf.
|
||||
func getEtcResolvConfServers() ([]string, error) {
|
||||
body, err := os.ReadFile("/etc/resolv.conf")
|
||||
func getEtcResolvConfServers() (addrs []string, err error) {
|
||||
const filename = "etc/resolv.conf"
|
||||
|
||||
_, err = aghos.FileWalker(func(r io.Reader) (_ []string, _ bool, err error) {
|
||||
sc := bufio.NewScanner(r)
|
||||
for sc.Scan() {
|
||||
matches := etcResolvConfReg.FindAllStringSubmatch(sc.Text(), -1)
|
||||
if len(matches) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, m := range matches {
|
||||
addrs = append(addrs, m[1])
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false, sc.Err()
|
||||
}).Walk(rootDirFS, filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
re := regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)")
|
||||
|
||||
matches := re.FindAllStringSubmatch(string(body), -1)
|
||||
if len(matches) == 0 {
|
||||
return nil, errors.Error("found no DNS servers in /etc/resolv.conf")
|
||||
}
|
||||
|
||||
addrs := make([]string, 0)
|
||||
for i := range matches {
|
||||
addrs = append(addrs, matches[i][1])
|
||||
return nil, fmt.Errorf("parsing etc/resolv.conf file: %w", err)
|
||||
} else if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("found no dns servers in %s", filename)
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
|
||||
261
internal/aghnet/net_darwin_test.go
Normal file
261
internal/aghnet/net_darwin_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIfaceHasStaticIP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
shell mapShell
|
||||
ifaceName string
|
||||
wantHas assert.BoolAssertionFunc
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "success",
|
||||
shell: mapShell{
|
||||
"networksetup -listallhardwareports": {
|
||||
err: nil,
|
||||
out: "Hardware Port: hwport\nDevice: en0\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -getinfo hwport": {
|
||||
err: nil,
|
||||
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
|
||||
code: 0,
|
||||
},
|
||||
},
|
||||
ifaceName: "en0",
|
||||
wantHas: assert.False,
|
||||
wantErrMsg: ``,
|
||||
}, {
|
||||
name: "success_static",
|
||||
shell: mapShell{
|
||||
"networksetup -listallhardwareports": {
|
||||
err: nil,
|
||||
out: "Hardware Port: hwport\nDevice: en0\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -getinfo hwport": {
|
||||
err: nil,
|
||||
out: "Manual Configuration\nIP address: 1.2.3.4\n" +
|
||||
"Subnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
|
||||
code: 0,
|
||||
},
|
||||
},
|
||||
ifaceName: "en0",
|
||||
wantHas: assert.True,
|
||||
wantErrMsg: ``,
|
||||
}, {
|
||||
name: "reports_error",
|
||||
shell: theOnlyCmd(
|
||||
"networksetup -listallhardwareports",
|
||||
0,
|
||||
"",
|
||||
errors.Error("can't list"),
|
||||
),
|
||||
ifaceName: "en0",
|
||||
wantHas: assert.False,
|
||||
wantErrMsg: `could not find hardware port for en0`,
|
||||
}, {
|
||||
name: "port_error",
|
||||
shell: mapShell{
|
||||
"networksetup -listallhardwareports": {
|
||||
err: nil,
|
||||
out: "Hardware Port: hwport\nDevice: en0\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -getinfo hwport": {
|
||||
err: errors.Error("can't get"),
|
||||
out: ``,
|
||||
code: 0,
|
||||
},
|
||||
},
|
||||
ifaceName: "en0",
|
||||
wantHas: assert.False,
|
||||
wantErrMsg: `can't get`,
|
||||
}, {
|
||||
name: "port_bad_output",
|
||||
shell: mapShell{
|
||||
"networksetup -listallhardwareports": {
|
||||
err: nil,
|
||||
out: "Hardware Port: hwport\nDevice: en0\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -getinfo hwport": {
|
||||
err: nil,
|
||||
out: "nothing meaningful",
|
||||
code: 0,
|
||||
},
|
||||
},
|
||||
ifaceName: "en0",
|
||||
wantHas: assert.False,
|
||||
wantErrMsg: `could not find hardware port info`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
substShell(t, tc.shell.RunCmd)
|
||||
|
||||
has, err := IfaceHasStaticIP(tc.ifaceName)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
tc.wantHas(t, has)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIfaceSetStaticIP(t *testing.T) {
|
||||
succFsys := fstest.MapFS{
|
||||
"etc/resolv.conf": &fstest.MapFile{
|
||||
Data: []byte(`nameserver 1.1.1.1`),
|
||||
},
|
||||
}
|
||||
panicFsys := &aghtest.FS{
|
||||
OnOpen: func(name string) (fs.File, error) { panic("not implemented") },
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
shell mapShell
|
||||
fsys fs.FS
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "success",
|
||||
shell: mapShell{
|
||||
"networksetup -listallhardwareports": {
|
||||
err: nil,
|
||||
out: "Hardware Port: hwport\nDevice: en0\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -getinfo hwport": {
|
||||
err: nil,
|
||||
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -setdnsservers hwport 1.1.1.1": {
|
||||
err: nil,
|
||||
out: "",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -setmanual hwport 1.2.3.4 255.255.255.0 1.2.3.1": {
|
||||
err: nil,
|
||||
out: "",
|
||||
code: 0,
|
||||
},
|
||||
},
|
||||
fsys: succFsys,
|
||||
wantErrMsg: ``,
|
||||
}, {
|
||||
name: "static_already",
|
||||
shell: mapShell{
|
||||
"networksetup -listallhardwareports": {
|
||||
err: nil,
|
||||
out: "Hardware Port: hwport\nDevice: en0\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -getinfo hwport": {
|
||||
err: nil,
|
||||
out: "Manual Configuration\nIP address: 1.2.3.4\n" +
|
||||
"Subnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
|
||||
code: 0,
|
||||
},
|
||||
},
|
||||
fsys: panicFsys,
|
||||
wantErrMsg: `ip address is already static`,
|
||||
}, {
|
||||
name: "reports_error",
|
||||
shell: theOnlyCmd(
|
||||
"networksetup -listallhardwareports",
|
||||
0,
|
||||
"",
|
||||
errors.Error("can't list"),
|
||||
),
|
||||
fsys: panicFsys,
|
||||
wantErrMsg: `could not find hardware port for en0`,
|
||||
}, {
|
||||
name: "resolv_conf_error",
|
||||
shell: mapShell{
|
||||
"networksetup -listallhardwareports": {
|
||||
err: nil,
|
||||
out: "Hardware Port: hwport\nDevice: en0\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -getinfo hwport": {
|
||||
err: nil,
|
||||
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
|
||||
code: 0,
|
||||
},
|
||||
},
|
||||
fsys: fstest.MapFS{
|
||||
"etc/resolv.conf": &fstest.MapFile{
|
||||
Data: []byte("this resolv.conf is invalid"),
|
||||
},
|
||||
},
|
||||
wantErrMsg: `found no dns servers in etc/resolv.conf`,
|
||||
}, {
|
||||
name: "set_dns_error",
|
||||
shell: mapShell{
|
||||
"networksetup -listallhardwareports": {
|
||||
err: nil,
|
||||
out: "Hardware Port: hwport\nDevice: en0\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -getinfo hwport": {
|
||||
err: nil,
|
||||
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -setdnsservers hwport 1.1.1.1": {
|
||||
err: errors.Error("can't set"),
|
||||
out: "",
|
||||
code: 0,
|
||||
},
|
||||
},
|
||||
fsys: succFsys,
|
||||
wantErrMsg: `can't set`,
|
||||
}, {
|
||||
name: "set_manual_error",
|
||||
shell: mapShell{
|
||||
"networksetup -listallhardwareports": {
|
||||
err: nil,
|
||||
out: "Hardware Port: hwport\nDevice: en0\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -getinfo hwport": {
|
||||
err: nil,
|
||||
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -setdnsservers hwport 1.1.1.1": {
|
||||
err: nil,
|
||||
out: "",
|
||||
code: 0,
|
||||
},
|
||||
"networksetup -setmanual hwport 1.2.3.4 255.255.255.0 1.2.3.1": {
|
||||
err: errors.Error("can't set"),
|
||||
out: "",
|
||||
code: 0,
|
||||
},
|
||||
},
|
||||
fsys: succFsys,
|
||||
wantErrMsg: `can't set`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
substShell(t, tc.shell.RunCmd)
|
||||
substRootDirFS(t, tc.fsys)
|
||||
|
||||
err := IfaceSetStaticIP("en0")
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -13,16 +13,12 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
)
|
||||
|
||||
func canBindPrivilegedPorts() (can bool, err error) {
|
||||
return aghos.HaveAdminRights()
|
||||
}
|
||||
|
||||
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
||||
const rcConfFilename = "etc/rc.conf"
|
||||
|
||||
walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig)
|
||||
|
||||
return walker.Walk(aghos.RootDirFS(), rcConfFilename)
|
||||
return walker.Walk(rootDirFS, rcConfFilename)
|
||||
}
|
||||
|
||||
// rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to
|
||||
|
||||
@@ -4,56 +4,74 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"io/fs"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRcConfStaticConfig(t *testing.T) {
|
||||
const iface interfaceName = `em0`
|
||||
const nl = "\n"
|
||||
func TestIfaceHasStaticIP(t *testing.T) {
|
||||
const (
|
||||
ifaceName = `em0`
|
||||
rcConf = "etc/rc.conf"
|
||||
)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
rcconfData string
|
||||
wantCont bool
|
||||
name string
|
||||
rootFsys fs.FS
|
||||
wantHas assert.BoolAssertionFunc
|
||||
}{{
|
||||
name: "simple",
|
||||
rcconfData: `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl,
|
||||
wantCont: false,
|
||||
name: "simple",
|
||||
rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
|
||||
Data: []byte(`ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl),
|
||||
}},
|
||||
wantHas: assert.True,
|
||||
}, {
|
||||
name: "case_insensitiveness",
|
||||
rcconfData: `ifconfig_em0="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl,
|
||||
wantCont: false,
|
||||
name: "case_insensitiveness",
|
||||
rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
|
||||
Data: []byte(`ifconfig_` + ifaceName + `="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl),
|
||||
}},
|
||||
wantHas: assert.True,
|
||||
}, {
|
||||
name: "comments_and_trash",
|
||||
rcconfData: `# comment 1` + nl +
|
||||
`` + nl +
|
||||
`# comment 2` + nl +
|
||||
`ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl,
|
||||
wantCont: false,
|
||||
rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
|
||||
Data: []byte(`# comment 1` + nl +
|
||||
`` + nl +
|
||||
`# comment 2` + nl +
|
||||
`ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl,
|
||||
),
|
||||
}},
|
||||
wantHas: assert.True,
|
||||
}, {
|
||||
name: "aliases",
|
||||
rcconfData: `ifconfig_em0_alias="inet 127.0.0.1/24"` + nl +
|
||||
`ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl,
|
||||
wantCont: false,
|
||||
rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
|
||||
Data: []byte(`ifconfig_` + ifaceName + `_alias="inet 127.0.0.1/24"` + nl +
|
||||
`ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl,
|
||||
),
|
||||
}},
|
||||
wantHas: assert.True,
|
||||
}, {
|
||||
name: "incorrect_config",
|
||||
rcconfData: `ifconfig_em0="inet6 127.0.0.253 netmask 0xffffffff"` + nl +
|
||||
`ifconfig_em0="inet 256.256.256.256 netmask 0xffffffff"` + nl +
|
||||
`ifconfig_em0=""` + nl,
|
||||
wantCont: true,
|
||||
rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
|
||||
Data: []byte(
|
||||
`ifconfig_` + ifaceName + `="inet6 127.0.0.253 netmask 0xffffffff"` + nl +
|
||||
`ifconfig_` + ifaceName + `="inet 256.256.256.256 netmask 0xffffffff"` + nl +
|
||||
`ifconfig_` + ifaceName + `=""` + nl,
|
||||
),
|
||||
}},
|
||||
wantHas: assert.False,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
r := strings.NewReader(tc.rcconfData)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, cont, err := iface.rcConfStaticConfig(r)
|
||||
substRootDirFS(t, tc.rootFsys)
|
||||
|
||||
has, err := IfaceHasStaticIP(ifaceName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantCont, cont)
|
||||
tc.wantHas(t, has)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,16 +13,44 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// dhcpсdConf is the name of /etc/dhcpcd.conf file in the root filesystem.
|
||||
const dhcpcdConf = "etc/dhcpcd.conf"
|
||||
|
||||
func canBindPrivilegedPorts() (can bool, err error) {
|
||||
res, err := unix.PrctlRetInt(
|
||||
unix.PR_CAP_AMBIENT,
|
||||
unix.PR_CAP_AMBIENT_IS_SET,
|
||||
unix.CAP_NET_BIND_SERVICE,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINVAL) {
|
||||
// Older versions of Linux kernel do not support this. Print a
|
||||
// warning and check admin rights.
|
||||
log.Info("warning: cannot check capability cap_net_bind_service: %s", err)
|
||||
} else {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Don't check the error because it's always nil on Linux.
|
||||
adm, _ := aghos.HaveAdminRights()
|
||||
|
||||
return res == 1 || adm, nil
|
||||
}
|
||||
|
||||
// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
|
||||
// have a static IP.
|
||||
func (n interfaceName) dhcpcdStaticConfig(r io.Reader) (subsources []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
ifaceFound := findIfaceLine(s, string(n))
|
||||
if !ifaceFound {
|
||||
if !findIfaceLine(s, string(n)) {
|
||||
return nil, true, s.Err()
|
||||
}
|
||||
|
||||
@@ -61,9 +89,9 @@ func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool,
|
||||
fields := strings.Fields(line)
|
||||
fieldsNum := len(fields)
|
||||
|
||||
// Man page interfaces(5) declares that interface definition
|
||||
// should consist of the key word "iface" followed by interface
|
||||
// name, and method at fourth field.
|
||||
// Man page interfaces(5) declares that interface definition should
|
||||
// consist of the key word "iface" followed by interface name, and
|
||||
// method at fourth field.
|
||||
if fieldsNum >= 4 &&
|
||||
fields[0] == "iface" && fields[1] == string(n) && fields[3] == "static" {
|
||||
return nil, false, nil
|
||||
@@ -78,10 +106,10 @@ func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool,
|
||||
}
|
||||
|
||||
func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
|
||||
// TODO(a.garipov): Currently, this function returns the first
|
||||
// definitive result. So if /etc/dhcpcd.conf has a static IP while
|
||||
// /etc/network/interfaces doesn't, it will return true. Perhaps this
|
||||
// is not the most desirable behavior.
|
||||
// TODO(a.garipov): Currently, this function returns the first definitive
|
||||
// result. So if /etc/dhcpcd.conf has and /etc/network/interfaces has no
|
||||
// static IP configuration, it will return true. Perhaps this is not the
|
||||
// most desirable behavior.
|
||||
|
||||
iface := interfaceName(ifaceName)
|
||||
|
||||
@@ -90,17 +118,15 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
|
||||
filename string
|
||||
}{{
|
||||
FileWalker: iface.dhcpcdStaticConfig,
|
||||
filename: "etc/dhcpcd.conf",
|
||||
filename: dhcpcdConf,
|
||||
}, {
|
||||
FileWalker: iface.ifacesStaticConfig,
|
||||
filename: "etc/network/interfaces",
|
||||
}} {
|
||||
has, err = pair.Walk(aghos.RootDirFS(), pair.filename)
|
||||
has, err = pair.Walk(rootDirFS, pair.filename)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if has {
|
||||
} else if has {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
@@ -108,14 +134,6 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
|
||||
return false, ErrNoStaticIPInfo
|
||||
}
|
||||
|
||||
func canBindPrivilegedPorts() (can bool, err error) {
|
||||
cnbs, err := unix.PrctlRetInt(unix.PR_CAP_AMBIENT, unix.PR_CAP_AMBIENT_IS_SET, unix.CAP_NET_BIND_SERVICE, 0, 0)
|
||||
// Don't check the error because it's always nil on Linux.
|
||||
adm, _ := aghos.HaveAdminRights()
|
||||
|
||||
return cnbs == 1 || adm, err
|
||||
}
|
||||
|
||||
// findIfaceLine scans s until it finds the line that declares an interface with
|
||||
// the given name. If findIfaceLine can't find the line, it returns false.
|
||||
func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
|
||||
@@ -131,23 +149,23 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
|
||||
}
|
||||
|
||||
// ifaceSetStaticIP configures the system to retain its current IP on the
|
||||
// interface through dhcpdc.conf.
|
||||
// interface through dhcpcd.conf.
|
||||
func ifaceSetStaticIP(ifaceName string) (err error) {
|
||||
ipNet := GetSubnet(ifaceName)
|
||||
if ipNet.IP == nil {
|
||||
return errors.Error("can't get IP address")
|
||||
}
|
||||
|
||||
gatewayIP := GatewayIP(ifaceName)
|
||||
add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP, ipNet.IP)
|
||||
|
||||
body, err := os.ReadFile("/etc/dhcpcd.conf")
|
||||
body, err := os.ReadFile(dhcpcdConf)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
|
||||
gatewayIP := GatewayIP(ifaceName)
|
||||
add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP)
|
||||
|
||||
body = append(body, []byte(add)...)
|
||||
err = maybe.WriteFile("/etc/dhcpcd.conf", body, 0o644)
|
||||
err = maybe.WriteFile(dhcpcdConf, body, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing conf: %w", err)
|
||||
}
|
||||
@@ -157,22 +175,24 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
|
||||
|
||||
// dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that
|
||||
// configure the interface to have a static IP.
|
||||
func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gatewayIP, dnsIP net.IP) (conf string) {
|
||||
var body []byte
|
||||
|
||||
add := fmt.Sprintf(
|
||||
"\n# %[1]s added by AdGuard Home.\ninterface %[1]s\nstatic ip_address=%s\n",
|
||||
func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf string) {
|
||||
b := &strings.Builder{}
|
||||
stringutil.WriteToBuilder(
|
||||
b,
|
||||
"\n# ",
|
||||
ifaceName,
|
||||
ipNet)
|
||||
body = append(body, []byte(add)...)
|
||||
" added by AdGuard Home.\ninterface ",
|
||||
ifaceName,
|
||||
"\nstatic ip_address=",
|
||||
ipNet.String(),
|
||||
"\n",
|
||||
)
|
||||
|
||||
if gatewayIP != nil {
|
||||
add = fmt.Sprintf("static routers=%s\n", gatewayIP)
|
||||
body = append(body, []byte(add)...)
|
||||
if gwIP != nil {
|
||||
stringutil.WriteToBuilder(b, "static routers=", gwIP.String(), "\n")
|
||||
}
|
||||
|
||||
add = fmt.Sprintf("static domain_name_servers=%s\n\n", dnsIP)
|
||||
body = append(body, []byte(add)...)
|
||||
stringutil.WriteToBuilder(b, "static domain_name_servers=", ipNet.IP.String(), "\n\n")
|
||||
|
||||
return string(body)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
@@ -4,152 +4,124 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"io/fs"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDHCPCDStaticConfig(t *testing.T) {
|
||||
const iface interfaceName = `wlan0`
|
||||
func TestHasStaticIP(t *testing.T) {
|
||||
const ifaceName = "wlan0"
|
||||
|
||||
const (
|
||||
dhcpcd = "etc/dhcpcd.conf"
|
||||
netifaces = "etc/network/interfaces"
|
||||
)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantCont bool
|
||||
}{{
|
||||
name: "has_not",
|
||||
data: []byte(`#comment` + nl +
|
||||
`# comment` + nl +
|
||||
`interface eth0` + nl +
|
||||
`static ip_address=192.168.0.1/24` + nl +
|
||||
`# interface ` + iface + nl +
|
||||
`static ip_address=192.168.1.1/24` + nl +
|
||||
`# comment` + nl,
|
||||
),
|
||||
wantCont: true,
|
||||
}, {
|
||||
name: "has",
|
||||
data: []byte(`#comment` + nl +
|
||||
`# comment` + nl +
|
||||
`interface eth0` + nl +
|
||||
`static ip_address=192.168.0.1/24` + nl +
|
||||
`# interface ` + iface + nl +
|
||||
`static ip_address=192.168.1.1/24` + nl +
|
||||
`# comment` + nl +
|
||||
`interface ` + iface + nl +
|
||||
`# comment` + nl +
|
||||
`static ip_address=192.168.2.1/24` + nl,
|
||||
),
|
||||
wantCont: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := bytes.NewReader(tc.data)
|
||||
_, cont, err := iface.dhcpcdStaticConfig(r)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantCont, cont)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIfacesStaticConfig(t *testing.T) {
|
||||
const iface interfaceName = `enp0s3`
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantCont bool
|
||||
wantPatterns []string
|
||||
}{{
|
||||
name: "has_not",
|
||||
data: []byte(`allow-hotplug ` + iface + nl +
|
||||
`#iface enp0s3 inet static` + nl +
|
||||
`# address 192.168.0.200` + nl +
|
||||
`# netmask 255.255.255.0` + nl +
|
||||
`# gateway 192.168.0.1` + nl +
|
||||
`iface ` + iface + ` inet dhcp` + nl,
|
||||
),
|
||||
wantCont: true,
|
||||
wantPatterns: []string{},
|
||||
}, {
|
||||
name: "has",
|
||||
data: []byte(`allow-hotplug ` + iface + nl +
|
||||
`iface ` + iface + ` inet static` + nl +
|
||||
` address 192.168.0.200` + nl +
|
||||
` netmask 255.255.255.0` + nl +
|
||||
` gateway 192.168.0.1` + nl +
|
||||
`#iface ` + iface + ` inet dhcp` + nl,
|
||||
),
|
||||
wantCont: false,
|
||||
wantPatterns: []string{},
|
||||
}, {
|
||||
name: "return_patterns",
|
||||
data: []byte(`source hello` + nl +
|
||||
`source world` + nl +
|
||||
`#iface ` + iface + ` inet static` + nl,
|
||||
),
|
||||
wantCont: true,
|
||||
wantPatterns: []string{"hello", "world"},
|
||||
}, {
|
||||
// This one tests if the first found valid interface prevents
|
||||
// checking files under the `source` directive.
|
||||
name: "ignore_patterns",
|
||||
data: []byte(`source hello` + nl +
|
||||
`source world` + nl +
|
||||
`iface ` + iface + ` inet static` + nl,
|
||||
),
|
||||
wantCont: false,
|
||||
wantPatterns: []string{},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
r := bytes.NewReader(tc.data)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
patterns, has, err := iface.ifacesStaticConfig(r)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantCont, has)
|
||||
assert.ElementsMatch(t, tc.wantPatterns, patterns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetStaticIPdhcpcdConf(t *testing.T) {
|
||||
testCases := []struct {
|
||||
rootFsys fs.FS
|
||||
name string
|
||||
dhcpcdConf string
|
||||
routers net.IP
|
||||
wantHas assert.BoolAssertionFunc
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "with_gateway",
|
||||
dhcpcdConf: nl + `# wlan0 added by AdGuard Home.` + nl +
|
||||
`interface wlan0` + nl +
|
||||
`static ip_address=192.168.0.2/24` + nl +
|
||||
`static routers=192.168.0.1` + nl +
|
||||
`static domain_name_servers=192.168.0.2` + nl + nl,
|
||||
routers: net.IP{192, 168, 0, 1},
|
||||
rootFsys: fstest.MapFS{
|
||||
dhcpcd: &fstest.MapFile{
|
||||
Data: []byte(`#comment` + nl +
|
||||
`# comment` + nl +
|
||||
`interface eth0` + nl +
|
||||
`static ip_address=192.168.0.1/24` + nl +
|
||||
`# interface ` + ifaceName + nl +
|
||||
`static ip_address=192.168.1.1/24` + nl +
|
||||
`# comment` + nl,
|
||||
),
|
||||
},
|
||||
},
|
||||
name: "dhcpcd_has_not",
|
||||
wantHas: assert.False,
|
||||
wantErrMsg: `no information about static ip`,
|
||||
}, {
|
||||
name: "without_gateway",
|
||||
dhcpcdConf: nl + `# wlan0 added by AdGuard Home.` + nl +
|
||||
`interface wlan0` + nl +
|
||||
`static ip_address=192.168.0.2/24` + nl +
|
||||
`static domain_name_servers=192.168.0.2` + nl + nl,
|
||||
routers: nil,
|
||||
rootFsys: fstest.MapFS{
|
||||
dhcpcd: &fstest.MapFile{
|
||||
Data: []byte(`#comment` + nl +
|
||||
`# comment` + nl +
|
||||
`interface ` + ifaceName + nl +
|
||||
`static ip_address=192.168.0.1/24` + nl +
|
||||
`# interface ` + ifaceName + nl +
|
||||
`static ip_address=192.168.1.1/24` + nl +
|
||||
`# comment` + nl,
|
||||
),
|
||||
},
|
||||
},
|
||||
name: "dhcpcd_has",
|
||||
wantHas: assert.True,
|
||||
wantErrMsg: ``,
|
||||
}, {
|
||||
rootFsys: fstest.MapFS{
|
||||
netifaces: &fstest.MapFile{
|
||||
Data: []byte(`allow-hotplug ` + ifaceName + nl +
|
||||
`#iface enp0s3 inet static` + nl +
|
||||
`# address 192.168.0.200` + nl +
|
||||
`# netmask 255.255.255.0` + nl +
|
||||
`# gateway 192.168.0.1` + nl +
|
||||
`iface ` + ifaceName + ` inet dhcp` + nl,
|
||||
),
|
||||
},
|
||||
},
|
||||
name: "netifaces_has_not",
|
||||
wantHas: assert.False,
|
||||
wantErrMsg: `no information about static ip`,
|
||||
}, {
|
||||
rootFsys: fstest.MapFS{
|
||||
netifaces: &fstest.MapFile{
|
||||
Data: []byte(`allow-hotplug ` + ifaceName + nl +
|
||||
`iface ` + ifaceName + ` inet static` + nl +
|
||||
` address 192.168.0.200` + nl +
|
||||
` netmask 255.255.255.0` + nl +
|
||||
` gateway 192.168.0.1` + nl +
|
||||
`#iface ` + ifaceName + ` inet dhcp` + nl,
|
||||
),
|
||||
},
|
||||
},
|
||||
name: "netifaces_has",
|
||||
wantHas: assert.True,
|
||||
wantErrMsg: ``,
|
||||
}, {
|
||||
rootFsys: fstest.MapFS{
|
||||
netifaces: &fstest.MapFile{
|
||||
Data: []byte(`source hello` + nl +
|
||||
`#iface ` + ifaceName + ` inet static` + nl,
|
||||
),
|
||||
},
|
||||
"hello": &fstest.MapFile{
|
||||
Data: []byte(`iface ` + ifaceName + ` inet static` + nl),
|
||||
},
|
||||
},
|
||||
name: "netifaces_another_file",
|
||||
wantHas: assert.True,
|
||||
wantErrMsg: ``,
|
||||
}, {
|
||||
rootFsys: fstest.MapFS{
|
||||
netifaces: &fstest.MapFile{
|
||||
Data: []byte(`source hello` + nl +
|
||||
`iface ` + ifaceName + ` inet static` + nl,
|
||||
),
|
||||
},
|
||||
},
|
||||
name: "netifaces_ignore_another",
|
||||
wantHas: assert.True,
|
||||
wantErrMsg: ``,
|
||||
}}
|
||||
|
||||
ipNet := &net.IPNet{
|
||||
IP: net.IP{192, 168, 0, 2},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := dhcpcdConfIface("wlan0", ipNet, tc.routers, net.IP{192, 168, 0, 2})
|
||||
assert.Equal(t, tc.dhcpcdConf, s)
|
||||
substRootDirFS(t, tc.rootFsys)
|
||||
|
||||
has, err := IfaceHasStaticIP(ifaceName)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
tc.wantHas(t, has)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,14 +13,10 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
)
|
||||
|
||||
func canBindPrivilegedPorts() (can bool, err error) {
|
||||
return aghos.HaveAdminRights()
|
||||
}
|
||||
|
||||
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
||||
filename := fmt.Sprintf("etc/hostname.%s", ifaceName)
|
||||
|
||||
return aghos.FileWalker(hostnameIfStaticConfig).Walk(aghos.RootDirFS(), filename)
|
||||
return aghos.FileWalker(hostnameIfStaticConfig).Walk(rootDirFS, filename)
|
||||
}
|
||||
|
||||
// hostnameIfStaticConfig checks if the interface is configured by
|
||||
|
||||
@@ -4,49 +4,69 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHostnameIfStaticConfig(t *testing.T) {
|
||||
const nl = "\n"
|
||||
func TestIfaceHasStaticIP(t *testing.T) {
|
||||
const ifaceName = "em0"
|
||||
|
||||
confFile := fmt.Sprintf("etc/hostname.%s", ifaceName)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
rcconfData string
|
||||
wantHas bool
|
||||
name string
|
||||
rootFsys fs.FS
|
||||
wantHas assert.BoolAssertionFunc
|
||||
}{{
|
||||
name: "simple",
|
||||
rcconfData: `inet 127.0.0.253` + nl,
|
||||
wantHas: true,
|
||||
name: "simple",
|
||||
rootFsys: fstest.MapFS{
|
||||
confFile: &fstest.MapFile{
|
||||
Data: []byte(`inet 127.0.0.253` + nl),
|
||||
},
|
||||
},
|
||||
wantHas: assert.True,
|
||||
}, {
|
||||
name: "case_sensitiveness",
|
||||
rcconfData: `InEt 127.0.0.253` + nl,
|
||||
wantHas: false,
|
||||
name: "case_sensitiveness",
|
||||
rootFsys: fstest.MapFS{
|
||||
confFile: &fstest.MapFile{
|
||||
Data: []byte(`InEt 127.0.0.253` + nl),
|
||||
},
|
||||
},
|
||||
wantHas: assert.False,
|
||||
}, {
|
||||
name: "comments_and_trash",
|
||||
rcconfData: `# comment 1` + nl +
|
||||
`` + nl +
|
||||
`# inet 127.0.0.253` + nl +
|
||||
`inet` + nl,
|
||||
wantHas: false,
|
||||
rootFsys: fstest.MapFS{
|
||||
confFile: &fstest.MapFile{
|
||||
Data: []byte(`# comment 1` + nl + nl +
|
||||
`# inet 127.0.0.253` + nl +
|
||||
`inet` + nl,
|
||||
),
|
||||
},
|
||||
},
|
||||
wantHas: assert.False,
|
||||
}, {
|
||||
name: "incorrect_config",
|
||||
rcconfData: `inet6 127.0.0.253` + nl +
|
||||
`inet 256.256.256.256` + nl,
|
||||
wantHas: false,
|
||||
rootFsys: fstest.MapFS{
|
||||
confFile: &fstest.MapFile{
|
||||
Data: []byte(`inet6 127.0.0.253` + nl + `inet 256.256.256.256` + nl),
|
||||
},
|
||||
},
|
||||
wantHas: assert.False,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
r := strings.NewReader(tc.rcconfData)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, has, err := hostnameIfStaticConfig(r)
|
||||
substRootDirFS(t, tc.rootFsys)
|
||||
|
||||
has, err := IfaceHasStaticIP(ifaceName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantHas, has)
|
||||
tc.wantHas(t, has)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,26 +1,138 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testdata is the filesystem containing data for testing the package.
|
||||
var testdata fs.FS = os.DirFS("./testdata")
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func TestGetInterfaceByIP(t *testing.T) {
|
||||
// testdata is the filesystem containing data for testing the package.
|
||||
var testdata fs.FS = os.DirFS("./testdata")
|
||||
|
||||
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
|
||||
// package with fsys for tests ran under t.
|
||||
func substRootDirFS(t testing.TB, fsys fs.FS) {
|
||||
t.Helper()
|
||||
|
||||
prev := rootDirFS
|
||||
t.Cleanup(func() { rootDirFS = prev })
|
||||
rootDirFS = fsys
|
||||
}
|
||||
|
||||
// RunCmdFunc is the signature of aghos.RunCommand function.
|
||||
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
|
||||
|
||||
// substShell replaces the the aghos.RunCommand function used throughout the
|
||||
// package with rc for tests ran under t.
|
||||
func substShell(t testing.TB, rc RunCmdFunc) {
|
||||
t.Helper()
|
||||
|
||||
prev := aghosRunCommand
|
||||
t.Cleanup(func() { aghosRunCommand = prev })
|
||||
aghosRunCommand = rc
|
||||
}
|
||||
|
||||
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
|
||||
// execution result. It's only needed to simplify testing.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
|
||||
type mapShell map[string]struct {
|
||||
err error
|
||||
out string
|
||||
code int
|
||||
}
|
||||
|
||||
// theOnlyCmd returns mapShell that only handles a single command and arguments
|
||||
// combination from cmd.
|
||||
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
|
||||
return mapShell{cmd: {code: code, out: out, err: err}}
|
||||
}
|
||||
|
||||
// RunCmd is a RunCmdFunc handled by s.
|
||||
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
|
||||
key := strings.Join(append([]string{cmd}, args...), " ")
|
||||
ret, ok := s[key]
|
||||
if !ok {
|
||||
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
|
||||
}
|
||||
|
||||
return ret.code, []byte(ret.out), ret.err
|
||||
}
|
||||
|
||||
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
|
||||
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
|
||||
|
||||
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
|
||||
// throughout the package with f for tests ran under t.
|
||||
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
|
||||
t.Helper()
|
||||
|
||||
prev := netInterfaceAddrs
|
||||
t.Cleanup(func() { netInterfaceAddrs = prev })
|
||||
netInterfaceAddrs = f
|
||||
}
|
||||
|
||||
func TestGatewayIP(t *testing.T) {
|
||||
const ifaceName = "ifaceName"
|
||||
const cmd = "ip route show dev " + ifaceName
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
shell mapShell
|
||||
want net.IP
|
||||
}{{
|
||||
name: "success_v4",
|
||||
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
|
||||
want: net.IP{1, 2, 3, 4}.To16(),
|
||||
}, {
|
||||
name: "success_v6",
|
||||
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
|
||||
want: net.IP{
|
||||
0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0xFF, 0xFF,
|
||||
},
|
||||
}, {
|
||||
name: "bad_output",
|
||||
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
|
||||
want: nil,
|
||||
}, {
|
||||
name: "err_runcmd",
|
||||
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
|
||||
want: nil,
|
||||
}, {
|
||||
name: "bad_code",
|
||||
shell: theOnlyCmd(cmd, 1, "", nil),
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
substShell(t, tc.shell.RunCmd)
|
||||
|
||||
assert.Equal(t, tc.want, GatewayIP(ifaceName))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaceByIP(t *testing.T) {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ifaces)
|
||||
@@ -30,7 +142,7 @@ func TestGetInterfaceByIP(t *testing.T) {
|
||||
require.NotEmpty(t, iface.Addresses)
|
||||
|
||||
for _, ip := range iface.Addresses {
|
||||
ifaceName := GetInterfaceByIP(ip)
|
||||
ifaceName := InterfaceByIP(ip)
|
||||
require.Equal(t, iface.Name, ifaceName)
|
||||
}
|
||||
})
|
||||
@@ -130,3 +242,107 @@ func TestCheckPort(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollectAllIfacesAddrs(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
addrs []net.Addr
|
||||
wantAddrs []string
|
||||
}{{
|
||||
name: "success",
|
||||
wantErrMsg: ``,
|
||||
addrs: []net.Addr{&net.IPNet{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
|
||||
}, &net.IPNet{
|
||||
IP: net.IP{4, 3, 2, 1},
|
||||
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
|
||||
}},
|
||||
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
|
||||
}, {
|
||||
name: "not_cidr",
|
||||
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
|
||||
addrs: []net.Addr{&net.IPAddr{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
}},
|
||||
wantAddrs: nil,
|
||||
}, {
|
||||
name: "empty",
|
||||
wantErrMsg: ``,
|
||||
addrs: []net.Addr{},
|
||||
wantAddrs: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
|
||||
|
||||
addrs, err := CollectAllIfacesAddrs()
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.wantAddrs, addrs)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("internal_error", func(t *testing.T) {
|
||||
const errAddrs errors.Error = "can't get addresses"
|
||||
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
|
||||
|
||||
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
|
||||
|
||||
_, err := CollectAllIfacesAddrs()
|
||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsAddrInUse(t *testing.T) {
|
||||
t.Run("addr_in_use", func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "0.0.0.0:0")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
|
||||
assert.True(t, IsAddrInUse(err))
|
||||
})
|
||||
|
||||
t.Run("another", func(t *testing.T) {
|
||||
const anotherErr errors.Error = "not addr in use"
|
||||
|
||||
assert.False(t, IsAddrInUse(anotherErr))
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetInterface_MarshalJSON(t *testing.T) {
|
||||
const want = `{` +
|
||||
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
|
||||
`"flags":"up|multicast",` +
|
||||
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
|
||||
`"name":"iface0",` +
|
||||
`"mtu":1500` +
|
||||
`}` + "\n"
|
||||
|
||||
ip4, ip6 := net.IP{1, 2, 3, 4}, net.IP{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
mask4, mask6 := net.CIDRMask(24, netutil.IPv4BitLen), net.CIDRMask(8, netutil.IPv6BitLen)
|
||||
|
||||
iface := &NetInterface{
|
||||
Addresses: []net.IP{ip4, ip6},
|
||||
Subnets: []*net.IPNet{{
|
||||
IP: ip4.Mask(mask4),
|
||||
Mask: mask4,
|
||||
}, {
|
||||
IP: ip6.Mask(mask6),
|
||||
Mask: mask6,
|
||||
}},
|
||||
Name: "iface0",
|
||||
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
|
||||
Flags: net.FlagUp | net.FlagMulticast,
|
||||
MTU: 1500,
|
||||
}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err := json.NewEncoder(b).Encode(iface)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, want, b.String())
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
//go:build !(linux || darwin || freebsd || openbsd)
|
||||
// +build !linux,!darwin,!freebsd,!openbsd
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
func canBindPrivilegedPorts() (can bool, err error) {
|
||||
return aghos.HaveAdminRights()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func ifaceHasStaticIP(string) (ok bool, err error) {
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// SubnetDetector describes IP address properties.
|
||||
type SubnetDetector struct {
|
||||
// spNets is the collection of special-purpose address registries as defined
|
||||
// by RFC 6890.
|
||||
spNets []*net.IPNet
|
||||
|
||||
// locServedNets is the collection of locally-served networks as defined by
|
||||
// RFC 6303.
|
||||
locServedNets []*net.IPNet
|
||||
}
|
||||
|
||||
// NewSubnetDetector returns a new IP detector.
|
||||
//
|
||||
// TODO(a.garipov): Decide whether an error is actually needed.
|
||||
func NewSubnetDetector() (snd *SubnetDetector, err error) {
|
||||
spNets := []string{
|
||||
// "This" network.
|
||||
"0.0.0.0/8",
|
||||
// Private-Use Networks.
|
||||
"10.0.0.0/8",
|
||||
// Shared Address Space.
|
||||
"100.64.0.0/10",
|
||||
// Loopback.
|
||||
"127.0.0.0/8",
|
||||
// Link Local.
|
||||
"169.254.0.0/16",
|
||||
// Private-Use Networks.
|
||||
"172.16.0.0/12",
|
||||
// IETF Protocol Assignments.
|
||||
"192.0.0.0/24",
|
||||
// DS-Lite.
|
||||
"192.0.0.0/29",
|
||||
// TEST-NET-1
|
||||
"192.0.2.0/24",
|
||||
// 6to4 Relay Anycast.
|
||||
"192.88.99.0/24",
|
||||
// Private-Use Networks.
|
||||
"192.168.0.0/16",
|
||||
// Network Interconnect Device Benchmark Testing.
|
||||
"198.18.0.0/15",
|
||||
// TEST-NET-2.
|
||||
"198.51.100.0/24",
|
||||
// TEST-NET-3.
|
||||
"203.0.113.0/24",
|
||||
// Reserved for Future Use.
|
||||
"240.0.0.0/4",
|
||||
// Limited Broadcast.
|
||||
"255.255.255.255/32",
|
||||
|
||||
// Loopback.
|
||||
"::1/128",
|
||||
// Unspecified.
|
||||
"::/128",
|
||||
// IPv4-IPv6 Translation Address.
|
||||
"64:ff9b::/96",
|
||||
|
||||
// IPv4-Mapped Address. Since this network is used for mapping
|
||||
// IPv4 addresses, we don't include it.
|
||||
// "::ffff:0:0/96",
|
||||
|
||||
// Discard-Only Prefix.
|
||||
"100::/64",
|
||||
// IETF Protocol Assignments.
|
||||
"2001::/23",
|
||||
// TEREDO.
|
||||
"2001::/32",
|
||||
// Benchmarking.
|
||||
"2001:2::/48",
|
||||
// Documentation.
|
||||
"2001:db8::/32",
|
||||
// ORCHID.
|
||||
"2001:10::/28",
|
||||
// 6to4.
|
||||
"2002::/16",
|
||||
// Unique-Local.
|
||||
"fc00::/7",
|
||||
// Linked-Scoped Unicast.
|
||||
"fe80::/10",
|
||||
}
|
||||
|
||||
// TODO(e.burkov): It's a subslice of the slice above. Should be done
|
||||
// smarter.
|
||||
locServedNets := []string{
|
||||
// IPv4.
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"127.0.0.0/8",
|
||||
"169.254.0.0/16",
|
||||
"192.0.2.0/24",
|
||||
"198.51.100.0/24",
|
||||
"203.0.113.0/24",
|
||||
"255.255.255.255/32",
|
||||
// IPv6.
|
||||
"::/128",
|
||||
"::1/128",
|
||||
"fe80::/10",
|
||||
"2001:db8::/32",
|
||||
"fd00::/8",
|
||||
}
|
||||
|
||||
snd = &SubnetDetector{
|
||||
spNets: make([]*net.IPNet, len(spNets)),
|
||||
locServedNets: make([]*net.IPNet, len(locServedNets)),
|
||||
}
|
||||
for i, ipnetStr := range spNets {
|
||||
var ipnet *net.IPNet
|
||||
_, ipnet, err = net.ParseCIDR(ipnetStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
snd.spNets[i] = ipnet
|
||||
}
|
||||
for i, ipnetStr := range locServedNets {
|
||||
var ipnet *net.IPNet
|
||||
_, ipnet, err = net.ParseCIDR(ipnetStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
snd.locServedNets[i] = ipnet
|
||||
}
|
||||
|
||||
return snd, nil
|
||||
}
|
||||
|
||||
// anyNetContains ranges through the given ipnets slice searching for the one
|
||||
// which contains the ip. For internal use only.
|
||||
//
|
||||
// TODO(e.burkov): Think about memoization.
|
||||
func anyNetContains(ipnets *[]*net.IPNet, ip net.IP) (is bool) {
|
||||
for _, ipnet := range *ipnets {
|
||||
if ipnet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsSpecialNetwork returns true if IP address is contained by any of
|
||||
// special-purpose IP address registries. It's safe for concurrent use.
|
||||
func (snd *SubnetDetector) IsSpecialNetwork(ip net.IP) (is bool) {
|
||||
return anyNetContains(&snd.spNets, ip)
|
||||
}
|
||||
|
||||
// IsLocallyServedNetwork returns true if IP address is contained by any of
|
||||
// locally-served IP address registries. It's safe for concurrent use.
|
||||
func (snd *SubnetDetector) IsLocallyServedNetwork(ip net.IP) (is bool) {
|
||||
return anyNetContains(&snd.locServedNets, ip)
|
||||
}
|
||||
@@ -1,252 +0,0 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSubnetDetector_DetectSpecialNetwork(t *testing.T) {
|
||||
snd, err := NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
ip net.IP
|
||||
want bool
|
||||
}{{
|
||||
name: "not_specific",
|
||||
ip: net.ParseIP("8.8.8.8"),
|
||||
want: false,
|
||||
}, {
|
||||
name: "this_host_on_this_network",
|
||||
ip: net.ParseIP("0.0.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "private-Use",
|
||||
ip: net.ParseIP("10.0.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "shared_address_space",
|
||||
ip: net.ParseIP("100.64.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "loopback",
|
||||
ip: net.ParseIP("127.0.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "link_local",
|
||||
ip: net.ParseIP("169.254.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "private-use",
|
||||
ip: net.ParseIP("172.16.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "ietf_protocol_assignments",
|
||||
ip: net.ParseIP("192.0.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "ds-lite",
|
||||
ip: net.ParseIP("192.0.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "documentation_(test-net-1)",
|
||||
ip: net.ParseIP("192.0.2.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "6to4_relay_anycast",
|
||||
ip: net.ParseIP("192.88.99.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "private-use",
|
||||
ip: net.ParseIP("192.168.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "benchmarking",
|
||||
ip: net.ParseIP("198.18.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "documentation_(test-net-2)",
|
||||
ip: net.ParseIP("198.51.100.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "documentation_(test-net-3)",
|
||||
ip: net.ParseIP("203.0.113.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "reserved",
|
||||
ip: net.ParseIP("240.0.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "limited_broadcast",
|
||||
ip: net.ParseIP("255.255.255.255"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "loopback_address",
|
||||
ip: net.ParseIP("::1"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "unspecified_address",
|
||||
ip: net.ParseIP("::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "ipv4-ipv6_translation",
|
||||
ip: net.ParseIP("64:ff9b::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "discard-only_address_block",
|
||||
ip: net.ParseIP("100::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "ietf_protocol_assignments",
|
||||
ip: net.ParseIP("2001::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "teredo",
|
||||
ip: net.ParseIP("2001::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "benchmarking",
|
||||
ip: net.ParseIP("2001:2::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "documentation",
|
||||
ip: net.ParseIP("2001:db8::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "orchid",
|
||||
ip: net.ParseIP("2001:10::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "6to4",
|
||||
ip: net.ParseIP("2002::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "unique-local",
|
||||
ip: net.ParseIP("fc00::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "linked-scoped_unicast",
|
||||
ip: net.ParseIP("fe80::"),
|
||||
want: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, snd.IsSpecialNetwork(tc.ip))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubnetDetector_DetectLocallyServedNetwork(t *testing.T) {
|
||||
snd, err := NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
ip net.IP
|
||||
want bool
|
||||
}{{
|
||||
name: "not_specific",
|
||||
ip: net.ParseIP("8.8.8.8"),
|
||||
want: false,
|
||||
}, {
|
||||
name: "private-Use",
|
||||
ip: net.ParseIP("10.0.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "loopback",
|
||||
ip: net.ParseIP("127.0.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "link_local",
|
||||
ip: net.ParseIP("169.254.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "private-use",
|
||||
ip: net.ParseIP("172.16.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "documentation_(test-net-1)",
|
||||
ip: net.ParseIP("192.0.2.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "private-use",
|
||||
ip: net.ParseIP("192.168.0.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "documentation_(test-net-2)",
|
||||
ip: net.ParseIP("198.51.100.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "documentation_(test-net-3)",
|
||||
ip: net.ParseIP("203.0.113.0"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "limited_broadcast",
|
||||
ip: net.ParseIP("255.255.255.255"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "loopback_address",
|
||||
ip: net.ParseIP("::1"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "unspecified_address",
|
||||
ip: net.ParseIP("::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "documentation",
|
||||
ip: net.ParseIP("2001:db8::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "linked-scoped_unicast",
|
||||
ip: net.ParseIP("fe80::"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "locally_assigned",
|
||||
ip: net.ParseIP("fd00::1"),
|
||||
want: true,
|
||||
}, {
|
||||
name: "not_locally_assigned",
|
||||
ip: net.ParseIP("fc00::1"),
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, snd.IsLocallyServedNetwork(tc.ip))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubnetDetector_Detect_parallel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
snd, err := NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
|
||||
testFunc := func() {
|
||||
for _, ip := range []net.IP{
|
||||
net.IPv4allrouter,
|
||||
net.IPv4allsys,
|
||||
net.IPv4bcast,
|
||||
net.IPv4zero,
|
||||
net.IPv6interfacelocalallnodes,
|
||||
net.IPv6linklocalallnodes,
|
||||
net.IPv6linklocalallrouters,
|
||||
net.IPv6loopback,
|
||||
net.IPv6unspecified,
|
||||
} {
|
||||
_ = snd.IsSpecialNetwork(ip)
|
||||
_ = snd.IsLocallyServedNetwork(ip)
|
||||
}
|
||||
}
|
||||
|
||||
const goroutinesNum = 50
|
||||
for i := 0; i < goroutinesNum; i++ {
|
||||
go testFunc()
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,5 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// DefaultRefreshIvl is the default period of time between refreshing cached
|
||||
// addresses.
|
||||
// const DefaultRefreshIvl = 5 * time.Minute
|
||||
@@ -16,39 +10,21 @@ type HostGenFunc func() (host string)
|
||||
|
||||
// SystemResolvers helps to work with local resolvers' addresses provided by OS.
|
||||
type SystemResolvers interface {
|
||||
// Get returns the slice of local resolvers' addresses. It should be
|
||||
// safe for concurrent use.
|
||||
// Get returns the slice of local resolvers' addresses. It must be safe for
|
||||
// concurrent use.
|
||||
Get() (rs []string)
|
||||
// refresh refreshes the local resolvers' addresses cache. It should be
|
||||
// safe for concurrent use.
|
||||
// refresh refreshes the local resolvers' addresses cache. It must be safe
|
||||
// for concurrent use.
|
||||
refresh() (err error)
|
||||
}
|
||||
|
||||
// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
|
||||
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
|
||||
defer log.OnPanic("systemResolvers")
|
||||
|
||||
// TODO(e.burkov): Implement a functionality to stop ticker.
|
||||
for range tickCh {
|
||||
err := sr.refresh()
|
||||
if err != nil {
|
||||
log.Error("systemResolvers: error in refreshing goroutine: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("systemResolvers: local addresses cache is refreshed")
|
||||
}
|
||||
}
|
||||
|
||||
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
|
||||
// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If
|
||||
// defined by refreshIvl. It disables auto-refreshing if refreshIvl is 0. If
|
||||
// nil is passed for hostGenFunc, the default generator will be used.
|
||||
func NewSystemResolvers(
|
||||
refreshIvl time.Duration,
|
||||
hostGenFunc HostGenFunc,
|
||||
) (sr SystemResolvers, err error) {
|
||||
sr = newSystemResolvers(refreshIvl, hostGenFunc)
|
||||
sr = newSystemResolvers(hostGenFunc)
|
||||
|
||||
// Fill cache.
|
||||
err = sr.refresh()
|
||||
@@ -56,11 +32,5 @@ func NewSystemResolvers(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if refreshIvl > 0 {
|
||||
ticker := time.NewTicker(refreshIvl)
|
||||
|
||||
go refreshWithTicker(sr, ticker.C)
|
||||
}
|
||||
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
@@ -24,12 +24,15 @@ func defaultHostGen() (host string) {
|
||||
|
||||
// systemResolvers is a default implementation of SystemResolvers interface.
|
||||
type systemResolvers struct {
|
||||
resolver *net.Resolver
|
||||
hostGenFunc HostGenFunc
|
||||
|
||||
// addrs is the set that contains cached local resolvers' addresses.
|
||||
addrs *stringutil.Set
|
||||
// addrsLock protects addrs.
|
||||
addrsLock sync.RWMutex
|
||||
// addrs is the set that contains cached local resolvers' addresses.
|
||||
addrs *stringutil.Set
|
||||
|
||||
// resolver is used to fetch the resolvers' addresses.
|
||||
resolver *net.Resolver
|
||||
// hostGenFunc generates hosts to resolve.
|
||||
hostGenFunc HostGenFunc
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -44,6 +47,7 @@ const (
|
||||
errUnexpectedHostFormat errors.Error = "unexpected host format"
|
||||
)
|
||||
|
||||
// refresh implements the SystemResolvers interface for *systemResolvers.
|
||||
func (sr *systemResolvers) refresh() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
|
||||
|
||||
@@ -56,7 +60,7 @@ func (sr *systemResolvers) refresh() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr SystemResolvers) {
|
||||
func newSystemResolvers(hostGenFunc HostGenFunc) (sr SystemResolvers) {
|
||||
if hostGenFunc == nil {
|
||||
hostGenFunc = defaultHostGen
|
||||
}
|
||||
@@ -76,19 +80,18 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S
|
||||
func validateDialedHost(host string) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()
|
||||
|
||||
var ipStr string
|
||||
parts := strings.Split(host, "%")
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
ipStr = host
|
||||
// host
|
||||
case 2:
|
||||
// Remove the zone and check the IP address part.
|
||||
ipStr = parts[0]
|
||||
host = parts[0]
|
||||
default:
|
||||
return errUnexpectedHostFormat
|
||||
}
|
||||
|
||||
if net.ParseIP(ipStr) == nil {
|
||||
if _, err = netutil.ParseIP(host); err != nil {
|
||||
return errBadAddrPassed
|
||||
}
|
||||
|
||||
|
||||
@@ -6,37 +6,32 @@ package aghnet
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func createTestSystemResolversImp(
|
||||
func createTestSystemResolversImpl(
|
||||
t *testing.T,
|
||||
refreshDur time.Duration,
|
||||
hostGenFunc HostGenFunc,
|
||||
) (imp *systemResolvers) {
|
||||
t.Helper()
|
||||
|
||||
sr := createTestSystemResolvers(t, refreshDur, hostGenFunc)
|
||||
sr := createTestSystemResolvers(t, hostGenFunc)
|
||||
require.IsType(t, (*systemResolvers)(nil), sr)
|
||||
|
||||
var ok bool
|
||||
imp, ok = sr.(*systemResolvers)
|
||||
require.True(t, ok)
|
||||
|
||||
return imp
|
||||
return sr.(*systemResolvers)
|
||||
}
|
||||
|
||||
func TestSystemResolvers_Refresh(t *testing.T) {
|
||||
t.Run("expected_error", func(t *testing.T) {
|
||||
sr := createTestSystemResolvers(t, 0, nil)
|
||||
sr := createTestSystemResolvers(t, nil)
|
||||
|
||||
assert.NoError(t, sr.refresh())
|
||||
})
|
||||
|
||||
t.Run("unexpected_error", func(t *testing.T) {
|
||||
_, err := NewSystemResolvers(0, func() string {
|
||||
_, err := NewSystemResolvers(func() string {
|
||||
return "127.0.0.1::123"
|
||||
})
|
||||
assert.Error(t, err)
|
||||
@@ -44,7 +39,7 @@ func TestSystemResolvers_Refresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSystemResolvers_DialFunc(t *testing.T) {
|
||||
imp := createTestSystemResolversImp(t, 0, nil)
|
||||
imp := createTestSystemResolversImpl(t, nil)
|
||||
|
||||
testCases := []struct {
|
||||
want error
|
||||
@@ -52,7 +47,7 @@ func TestSystemResolvers_DialFunc(t *testing.T) {
|
||||
address string
|
||||
}{{
|
||||
want: errFakeDial,
|
||||
name: "valid",
|
||||
name: "valid_ipv4",
|
||||
address: "127.0.0.1",
|
||||
}, {
|
||||
want: errFakeDial,
|
||||
|
||||
@@ -2,7 +2,6 @@ package aghnet
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -10,13 +9,12 @@ import (
|
||||
|
||||
func createTestSystemResolvers(
|
||||
t *testing.T,
|
||||
refreshDur time.Duration,
|
||||
hostGenFunc HostGenFunc,
|
||||
) (sr SystemResolvers) {
|
||||
t.Helper()
|
||||
|
||||
var err error
|
||||
sr, err = NewSystemResolvers(refreshDur, hostGenFunc)
|
||||
sr, err = NewSystemResolvers(hostGenFunc)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sr)
|
||||
|
||||
@@ -24,8 +22,14 @@ func createTestSystemResolvers(
|
||||
}
|
||||
|
||||
func TestSystemResolvers_Get(t *testing.T) {
|
||||
sr := createTestSystemResolvers(t, 0, nil)
|
||||
assert.NotEmpty(t, sr.Get())
|
||||
sr := createTestSystemResolvers(t, nil)
|
||||
|
||||
var rs []string
|
||||
require.NotPanics(t, func() {
|
||||
rs = sr.Get()
|
||||
})
|
||||
|
||||
assert.NotEmpty(t, rs)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Write tests for refreshWithTicker.
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -27,7 +26,7 @@ type systemResolvers struct {
|
||||
addrsLock sync.RWMutex
|
||||
}
|
||||
|
||||
func newSystemResolvers(refreshIvl time.Duration, _ HostGenFunc) (sr SystemResolvers) {
|
||||
func newSystemResolvers(_ HostGenFunc) (sr SystemResolvers) {
|
||||
return &systemResolvers{}
|
||||
}
|
||||
|
||||
|
||||
4
internal/aghnet/testdata/proc_net_arp
vendored
4
internal/aghnet/testdata/proc_net_arp
vendored
@@ -1,4 +1,6 @@
|
||||
IP address HW type Flags HW address Mask Device
|
||||
192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan
|
||||
::ffff:ffff 0x1 0x0 ef:cd:ab:ef:cd:ab * br-lan
|
||||
0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec
|
||||
0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec
|
||||
1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan
|
||||
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan
|
||||
@@ -1,4 +1,4 @@
|
||||
package aghos
|
||||
package aghos_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
57
internal/aghos/filewalker_internal_test.go
Normal file
57
internal/aghos/filewalker_internal_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"path"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// errFS is an fs.FS implementation, method Open of which always returns
|
||||
// errFSOpen.
|
||||
type errFS struct{}
|
||||
|
||||
// errFSOpen is returned from errGlobFS.Open.
|
||||
const errFSOpen errors.Error = "test open error"
|
||||
|
||||
// Open implements the fs.FS interface for *errGlobFS. fsys is always nil and
|
||||
// err is always errFSOpen.
|
||||
func (efs *errFS) Open(name string) (fsys fs.File, err error) {
|
||||
return nil, errFSOpen
|
||||
}
|
||||
|
||||
func TestWalkerFunc_CheckFile(t *testing.T) {
|
||||
emptyFS := fstest.MapFS{}
|
||||
|
||||
t.Run("non-existing", func(t *testing.T) {
|
||||
_, ok, err := checkFile(emptyFS, nil, "lol")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("invalid_argument", func(t *testing.T) {
|
||||
_, ok, err := checkFile(&errFS{}, nil, "")
|
||||
require.ErrorIs(t, err, errFSOpen)
|
||||
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("ignore_dirs", func(t *testing.T) {
|
||||
const dirName = "dir"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
|
||||
}
|
||||
|
||||
patterns, ok, err := checkFile(testFS, nil, dirName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, patterns)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
@@ -1,13 +1,13 @@
|
||||
package aghos
|
||||
package aghos_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"io/fs"
|
||||
"path"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
func TestFileWalker_Walk(t *testing.T) {
|
||||
const attribute = `000`
|
||||
|
||||
makeFileWalker := func(_ string) (fw FileWalker) {
|
||||
makeFileWalker := func(_ string) (fw aghos.FileWalker) {
|
||||
return func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
@@ -113,7 +113,7 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
f := fstest.MapFS{
|
||||
filename: &fstest.MapFile{Data: []byte("[]")},
|
||||
}
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
patterns = append(patterns, s.Text())
|
||||
@@ -134,7 +134,7 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
|
||||
}
|
||||
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
||||
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
||||
return nil, true, rerr
|
||||
}).Walk(f, "*")
|
||||
require.ErrorIs(t, err, rerr)
|
||||
@@ -142,45 +142,3 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
type errFS struct {
|
||||
fs.GlobFS
|
||||
}
|
||||
|
||||
const errErrFSOpen errors.Error = "this error is always returned"
|
||||
|
||||
func (efs *errFS) Open(name string) (fs.File, error) {
|
||||
return nil, errErrFSOpen
|
||||
}
|
||||
|
||||
func TestWalkerFunc_CheckFile(t *testing.T) {
|
||||
emptyFS := fstest.MapFS{}
|
||||
|
||||
t.Run("non-existing", func(t *testing.T) {
|
||||
_, ok, err := checkFile(emptyFS, nil, "lol")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("invalid_argument", func(t *testing.T) {
|
||||
_, ok, err := checkFile(&errFS{}, nil, "")
|
||||
require.ErrorIs(t, err, errErrFSOpen)
|
||||
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("ignore_dirs", func(t *testing.T) {
|
||||
const dirName = "dir"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
|
||||
}
|
||||
|
||||
patterns, ok, err := checkFile(testFS, nil, dirName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, patterns)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -57,20 +57,22 @@ func HaveAdminRights() (bool, error) {
|
||||
const MaxCmdOutputSize = 64 * 1024
|
||||
|
||||
// RunCommand runs shell command.
|
||||
func RunCommand(command string, arguments ...string) (code int, output string, err error) {
|
||||
func RunCommand(command string, arguments ...string) (code int, output []byte, err error) {
|
||||
cmd := exec.Command(command, arguments...)
|
||||
out, err := cmd.Output()
|
||||
if len(out) > MaxCmdOutputSize {
|
||||
out = out[:MaxCmdOutputSize]
|
||||
}
|
||||
|
||||
if errors.As(err, new(*exec.ExitError)) {
|
||||
return cmd.ProcessState.ExitCode(), string(out), nil
|
||||
} else if err != nil {
|
||||
return 1, "", fmt.Errorf("command %q failed: %w: %s", command, err, out)
|
||||
if err != nil {
|
||||
if eerr := new(exec.ExitError); errors.As(err, &eerr) {
|
||||
return eerr.ExitCode(), eerr.Stderr, nil
|
||||
}
|
||||
|
||||
return 1, nil, fmt.Errorf("command %q failed: %w: %s", command, err, out)
|
||||
}
|
||||
|
||||
return cmd.ProcessState.ExitCode(), string(out), nil
|
||||
return cmd.ProcessState.ExitCode(), out, nil
|
||||
}
|
||||
|
||||
// PIDByCommand searches for process named command and returns its PID ignoring
|
||||
@@ -173,3 +175,13 @@ func RootDirFS() (fsys fs.FS) {
|
||||
// behavior is undocumented but it currently works.
|
||||
return os.DirFS("")
|
||||
}
|
||||
|
||||
// NotifyShutdownSignal notifies c on receiving shutdown signals.
|
||||
func NotifyShutdownSignal(c chan<- os.Signal) {
|
||||
notifyShutdownSignal(c)
|
||||
}
|
||||
|
||||
// IsShutdownSignal returns true if sig is a shutdown signal.
|
||||
func IsShutdownSignal(sig os.Signal) (ok bool) {
|
||||
return isShutdownSignal(sig)
|
||||
}
|
||||
|
||||
27
internal/aghos/os_unix.go
Normal file
27
internal/aghos/os_unix.go
Normal file
@@ -0,0 +1,27 @@
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
// +build darwin freebsd linux openbsd
|
||||
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func notifyShutdownSignal(c chan<- os.Signal) {
|
||||
signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
|
||||
}
|
||||
|
||||
func isShutdownSignal(sig os.Signal) (ok bool) {
|
||||
switch sig {
|
||||
case
|
||||
unix.SIGINT,
|
||||
unix.SIGQUIT,
|
||||
unix.SIGTERM:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,10 @@
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
@@ -35,3 +39,20 @@ func haveAdminRights() (bool, error) {
|
||||
func isOpenWrt() (ok bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
func notifyShutdownSignal(c chan<- os.Signal) {
|
||||
// syscall.SIGTERM is processed automatically. See go doc os/signal,
|
||||
// section Windows.
|
||||
signal.Notify(c, os.Interrupt)
|
||||
}
|
||||
|
||||
func isShutdownSignal(sig os.Signal) (ok bool) {
|
||||
switch sig {
|
||||
case
|
||||
os.Interrupt,
|
||||
syscall.SIGTERM:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Exchanger is a mock aghnet.Exchanger implementation for tests.
|
||||
type Exchanger struct {
|
||||
Ups upstream.Upstream
|
||||
}
|
||||
|
||||
// Exchange implements aghnet.Exchanger interface for *Exchanger.
|
||||
func (e *Exchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
if e.Ups == nil {
|
||||
e.Ups = &TestErrUpstream{}
|
||||
}
|
||||
|
||||
return e.Ups.Exchange(req)
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package aghtest
|
||||
|
||||
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
|
||||
type FSWatcher struct {
|
||||
OnEvents func() (e <-chan struct{})
|
||||
OnAdd func(name string) (err error)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Events implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||
return w.OnEvents()
|
||||
}
|
||||
|
||||
// Add implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Add(name string) (err error) {
|
||||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Close implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Close() (err error) {
|
||||
return w.OnClose()
|
||||
}
|
||||
135
internal/aghtest/interface.go
Normal file
135
internal/aghtest/interface.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Interface Mocks
|
||||
//
|
||||
// Keep entities in this file in alphabetic order.
|
||||
|
||||
// Standard Library
|
||||
|
||||
// type check
|
||||
var _ fs.FS = &FS{}
|
||||
|
||||
// FS is a mock [fs.FS] implementation for tests.
|
||||
type FS struct {
|
||||
OnOpen func(name string) (fs.File, error)
|
||||
}
|
||||
|
||||
// Open implements the [fs.FS] interface for *FS.
|
||||
func (fsys *FS) Open(name string) (fs.File, error) {
|
||||
return fsys.OnOpen(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.GlobFS = &GlobFS{}
|
||||
|
||||
// GlobFS is a mock [fs.GlobFS] implementation for tests.
|
||||
type GlobFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnGlob func(pattern string) ([]string, error)
|
||||
}
|
||||
|
||||
// Glob implements the [fs.GlobFS] interface for *GlobFS.
|
||||
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
|
||||
return fsys.OnGlob(pattern)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.StatFS = &StatFS{}
|
||||
|
||||
// StatFS is a mock [fs.StatFS] implementation for tests.
|
||||
type StatFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnStat func(name string) (fs.FileInfo, error)
|
||||
}
|
||||
|
||||
// Stat implements the [fs.StatFS] interface for *StatFS.
|
||||
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
|
||||
return fsys.OnStat(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ net.Listener = (*Listener)(nil)
|
||||
|
||||
// Listener is a mock [net.Listener] implementation for tests.
|
||||
type Listener struct {
|
||||
OnAccept func() (conn net.Conn, err error)
|
||||
OnAddr func() (addr net.Addr)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Accept implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
return l.OnAccept()
|
||||
}
|
||||
|
||||
// Addr implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Addr() (addr net.Addr) {
|
||||
return l.OnAddr()
|
||||
}
|
||||
|
||||
// Close implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Close() (err error) {
|
||||
return l.OnClose()
|
||||
}
|
||||
|
||||
// Module dnsproxy
|
||||
|
||||
// type check
|
||||
var _ upstream.Upstream = (*UpstreamMock)(nil)
|
||||
|
||||
// UpstreamMock is a mock [upstream.Upstream] implementation for tests.
|
||||
//
|
||||
// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and
|
||||
// rename it to just Upstream.
|
||||
type UpstreamMock struct {
|
||||
OnAddress func() (addr string)
|
||||
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
|
||||
}
|
||||
|
||||
// Address implements the [upstream.Upstream] interface for *UpstreamMock.
|
||||
func (u *UpstreamMock) Address() (addr string) {
|
||||
return u.OnAddress()
|
||||
}
|
||||
|
||||
// Exchange implements the [upstream.Upstream] interface for *UpstreamMock.
|
||||
func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return u.OnExchange(req)
|
||||
}
|
||||
|
||||
// Module AdGuardHome
|
||||
|
||||
// type check
|
||||
var _ aghos.FSWatcher = (*FSWatcher)(nil)
|
||||
|
||||
// FSWatcher is a mock [aghos.FSWatcher] implementation for tests.
|
||||
type FSWatcher struct {
|
||||
OnEvents func() (e <-chan struct{})
|
||||
OnAdd func(name string) (err error)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||
return w.OnEvents()
|
||||
}
|
||||
|
||||
// Add implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Add(name string) (err error) {
|
||||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Close() (err error) {
|
||||
return w.OnClose()
|
||||
}
|
||||
9
internal/aghtest/interface_test.go
Normal file
9
internal/aghtest/interface_test.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package aghtest_test
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
)
|
||||
|
||||
// type check
|
||||
var _ aghos.FSWatcher = (*aghtest.FSWatcher)(nil)
|
||||
@@ -1,46 +0,0 @@
|
||||
package aghtest
|
||||
|
||||
import "io/fs"
|
||||
|
||||
// type check
|
||||
var _ fs.FS = &FS{}
|
||||
|
||||
// FS is a mock fs.FS implementation to use in tests.
|
||||
type FS struct {
|
||||
OnOpen func(name string) (fs.File, error)
|
||||
}
|
||||
|
||||
// Open implements the fs.FS interface for *FS.
|
||||
func (fsys *FS) Open(name string) (fs.File, error) {
|
||||
return fsys.OnOpen(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.StatFS = &StatFS{}
|
||||
|
||||
// StatFS is a mock fs.StatFS implementation to use in tests.
|
||||
type StatFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnStat func(name string) (fs.FileInfo, error)
|
||||
}
|
||||
|
||||
// Stat implements the fs.StatFS interface for *StatFS.
|
||||
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
|
||||
return fsys.OnStat(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.GlobFS = &GlobFS{}
|
||||
|
||||
// GlobFS is a mock fs.GlobFS implementation to use in tests.
|
||||
type GlobFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnGlob func(pattern string) ([]string, error)
|
||||
}
|
||||
|
||||
// Glob implements the fs.GlobFS interface for *GlobFS.
|
||||
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
|
||||
return fsys.OnGlob(pattern)
|
||||
}
|
||||
@@ -6,12 +6,18 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Additional Upstream Testing Utilities
|
||||
|
||||
// Upstream is a mock implementation of upstream.Upstream.
|
||||
//
|
||||
// TODO(a.garipov): Replace with UpstreamMock and rename it to just Upstream.
|
||||
type Upstream struct {
|
||||
// CName is a map of hostname to canonical name.
|
||||
CName map[string][]string
|
||||
@@ -25,6 +31,43 @@ type Upstream struct {
|
||||
Addr string
|
||||
}
|
||||
|
||||
// RespondTo returns a response with answer if req has class cl, question type
|
||||
// qt, and target targ.
|
||||
func RespondTo(t testing.TB, req *dns.Msg, cl, qt uint16, targ, answer string) (resp *dns.Msg) {
|
||||
t.Helper()
|
||||
|
||||
require.NotNil(t, req)
|
||||
require.Len(t, req.Question, 1)
|
||||
|
||||
q := req.Question[0]
|
||||
targ = dns.Fqdn(targ)
|
||||
if q.Qclass != cl || q.Qtype != qt || q.Name != targ {
|
||||
return nil
|
||||
}
|
||||
|
||||
respHdr := dns.RR_Header{
|
||||
Name: targ,
|
||||
Rrtype: qt,
|
||||
Class: cl,
|
||||
Ttl: 60,
|
||||
}
|
||||
|
||||
resp = new(dns.Msg).SetReply(req)
|
||||
switch qt {
|
||||
case dns.TypePTR:
|
||||
resp.Answer = []dns.RR{
|
||||
&dns.PTR{
|
||||
Hdr: respHdr,
|
||||
Ptr: answer,
|
||||
},
|
||||
}
|
||||
default:
|
||||
t.Fatalf("unsupported question type: %s", dns.Type(qt))
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// Exchange implements the upstream.Upstream interface for *Upstream.
|
||||
//
|
||||
// TODO(a.garipov): Split further into handlers.
|
||||
@@ -76,74 +119,57 @@ func (u *Upstream) Address() string {
|
||||
return u.Addr
|
||||
}
|
||||
|
||||
// TestBlockUpstream implements upstream.Upstream interface for replacing real
|
||||
// upstream in tests.
|
||||
type TestBlockUpstream struct {
|
||||
Hostname string
|
||||
|
||||
// lock protects reqNum.
|
||||
lock sync.RWMutex
|
||||
reqNum int
|
||||
|
||||
Block bool
|
||||
}
|
||||
|
||||
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
|
||||
// pair.
|
||||
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
u.reqNum++
|
||||
|
||||
hash := sha256.Sum256([]byte(u.Hostname))
|
||||
hashToReturn := hex.EncodeToString(hash[:])
|
||||
if !u.Block {
|
||||
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
|
||||
// NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that
|
||||
// supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is
|
||||
// true, hostname's actual hash is returned, blocking it. Otherwise, it returns
|
||||
// a different hash.
|
||||
func NewBlockUpstream(hostname string, shouldBlock bool) (u *UpstreamMock) {
|
||||
hash := sha256.Sum256([]byte(hostname))
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
if !shouldBlock {
|
||||
hashStr = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
|
||||
}
|
||||
|
||||
m := &dns.Msg{}
|
||||
m.SetReply(r)
|
||||
m.Answer = []dns.RR{
|
||||
&dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: r.Question[0].Name,
|
||||
},
|
||||
Txt: []string{
|
||||
hashToReturn,
|
||||
},
|
||||
ans := &dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "",
|
||||
Rrtype: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
Txt: []string{hashStr},
|
||||
}
|
||||
respTmpl := &dns.Msg{
|
||||
Answer: []dns.RR{ans},
|
||||
}
|
||||
|
||||
return &UpstreamMock{
|
||||
OnAddress: func() (addr string) {
|
||||
return "sbpc.upstream.example"
|
||||
},
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = respTmpl.Copy()
|
||||
resp.SetReply(req)
|
||||
resp.Answer[0].(*dns.TXT).Hdr.Name = req.Question[0].Name
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Address always returns an empty string.
|
||||
func (u *TestBlockUpstream) Address() string {
|
||||
return ""
|
||||
}
|
||||
// ErrUpstream is the error returned from the [*UpstreamMock] created by
|
||||
// [NewErrorUpstream].
|
||||
const ErrUpstream errors.Error = "test upstream error"
|
||||
|
||||
// RequestsCount returns the number of handled requests. It's safe for
|
||||
// concurrent use.
|
||||
func (u *TestBlockUpstream) RequestsCount() int {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
return u.reqNum
|
||||
}
|
||||
|
||||
// TestErrUpstream implements upstream.Upstream interface for replacing real
|
||||
// upstream in tests.
|
||||
type TestErrUpstream struct {
|
||||
// The error returned by Exchange may be unwrapped to the Err.
|
||||
Err error
|
||||
}
|
||||
|
||||
// Exchange always returns nil Msg and non-nil error.
|
||||
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
|
||||
return nil, fmt.Errorf("errupstream: %w", u.Err)
|
||||
}
|
||||
|
||||
// Address always returns an empty string.
|
||||
func (u *TestErrUpstream) Address() string {
|
||||
return ""
|
||||
// NewErrorUpstream returns an [*UpstreamMock] that returns [ErrUpstream] from
|
||||
// its Exchange method.
|
||||
func NewErrorUpstream() (u *UpstreamMock) {
|
||||
return &UpstreamMock{
|
||||
OnAddress: func() (addr string) {
|
||||
return "error.upstream.example"
|
||||
},
|
||||
OnExchange: func(_ *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return nil, errors.Error("test upstream error")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
"github.com/mdlayher/ethernet"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
@@ -49,16 +51,15 @@ type dhcpConn struct {
|
||||
}
|
||||
|
||||
// newDHCPConn creates the special connection for DHCP server.
|
||||
func (s *v4Server) newDHCPConn(ifi *net.Interface) (c net.PacketConn, err error) {
|
||||
// Create the raw connection.
|
||||
func (s *v4Server) newDHCPConn(iface *net.Interface) (c net.PacketConn, err error) {
|
||||
var ucast net.PacketConn
|
||||
if ucast, err = raw.ListenPacket(ifi, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
|
||||
if ucast, err = raw.ListenPacket(iface, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
|
||||
return nil, fmt.Errorf("creating raw udp connection: %w", err)
|
||||
}
|
||||
|
||||
// Create the UDP connection.
|
||||
var bcast net.PacketConn
|
||||
bcast, err = server4.NewIPv4UDPConn(ifi.Name, &net.UDPAddr{
|
||||
bcast, err = server4.NewIPv4UDPConn(iface.Name, &net.UDPAddr{
|
||||
// TODO(e.burkov): Listening on zeroes makes the server handle
|
||||
// requests from all the interfaces. Inspect the ways to
|
||||
// specify the interface-specific listening addresses.
|
||||
@@ -75,7 +76,7 @@ func (s *v4Server) newDHCPConn(ifi *net.Interface) (c net.PacketConn, err error)
|
||||
udpConn: bcast,
|
||||
bcastIP: s.conf.broadcastIP,
|
||||
rawConn: ucast,
|
||||
srcMAC: ifi.HardwareAddr,
|
||||
srcMAC: iface.HardwareAddr,
|
||||
srcIP: s.conf.dnsIPAddrs[0],
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/mdlayher/raw"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
func TestDHCPConn_WriteTo_common(t *testing.T) {
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
@@ -126,7 +126,7 @@ type ServerConfig struct {
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
|
||||
Enabled bool `yaml:"enabled"`
|
||||
InterfaceName string `yaml:"interface_name"`
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -145,7 +146,7 @@ type dhcpServerConfigJSON struct {
|
||||
V4 *v4ServerConfJSON `json:"v4"`
|
||||
V6 *v6ServerConfJSON `json:"v6"`
|
||||
InterfaceName string `json:"interface_name"`
|
||||
Enabled nullBool `json:"enabled"`
|
||||
Enabled aghalg.NullBool `json:"enabled"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPSetConfigV4(
|
||||
@@ -156,7 +157,7 @@ func (s *Server) handleDHCPSetConfigV4(
|
||||
}
|
||||
|
||||
v4Conf := v4JSONToServerConf(conf.V4)
|
||||
v4Conf.Enabled = conf.Enabled == nbTrue
|
||||
v4Conf.Enabled = conf.Enabled == aghalg.NBTrue
|
||||
if len(v4Conf.RangeStart) == 0 {
|
||||
v4Conf.Enabled = false
|
||||
}
|
||||
@@ -183,7 +184,7 @@ func (s *Server) handleDHCPSetConfigV6(
|
||||
}
|
||||
|
||||
v6Conf := v6JSONToServerConf(conf.V6)
|
||||
v6Conf.Enabled = conf.Enabled == nbTrue
|
||||
v6Conf.Enabled = conf.Enabled == aghalg.NBTrue
|
||||
if len(v6Conf.RangeStart) == 0 {
|
||||
v6Conf.Enabled = false
|
||||
}
|
||||
@@ -206,7 +207,7 @@ func (s *Server) handleDHCPSetConfigV6(
|
||||
|
||||
func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
conf := &dhcpServerConfigJSON{}
|
||||
conf.Enabled = boolToNullBool(s.conf.Enabled)
|
||||
conf.Enabled = aghalg.BoolToNullBool(s.conf.Enabled)
|
||||
conf.InterfaceName = s.conf.InterfaceName
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(conf)
|
||||
@@ -230,7 +231,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if conf.Enabled == nbTrue && !v4Enabled && !v6Enabled {
|
||||
if conf.Enabled == aghalg.NBTrue && !v4Enabled && !v6Enabled {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "dhcpv4 or dhcpv6 configuration must be complete")
|
||||
|
||||
return
|
||||
@@ -243,8 +244,8 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if conf.Enabled != nbNull {
|
||||
s.conf.Enabled = conf.Enabled == nbTrue
|
||||
if conf.Enabled != aghalg.NBNull {
|
||||
s.conf.Enabled = conf.Enabled == aghalg.NBTrue
|
||||
}
|
||||
|
||||
if conf.InterfaceName != "" {
|
||||
@@ -279,11 +280,11 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
type netInterfaceJSON struct {
|
||||
Name string `json:"name"`
|
||||
GatewayIP net.IP `json:"gateway_ip"`
|
||||
HardwareAddr string `json:"hardware_address"`
|
||||
Flags string `json:"flags"`
|
||||
GatewayIP net.IP `json:"gateway_ip"`
|
||||
Addrs4 []net.IP `json:"ipv4_addresses"`
|
||||
Addrs6 []net.IP `json:"ipv6_addresses"`
|
||||
Flags string `json:"flags"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -497,7 +498,6 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
ip4 := l.IP.To4()
|
||||
|
||||
if ip4 == nil {
|
||||
l.IP = l.IP.To16()
|
||||
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// nullBool is a nullable boolean. Use these in JSON requests and responses
|
||||
// instead of pointers to bool.
|
||||
//
|
||||
// TODO(a.garipov): Inspect uses of *bool, move this type into some new package
|
||||
// if we need it somewhere else.
|
||||
type nullBool uint8
|
||||
|
||||
// nullBool values
|
||||
const (
|
||||
nbNull nullBool = iota
|
||||
nbTrue
|
||||
nbFalse
|
||||
)
|
||||
|
||||
// String implements the fmt.Stringer interface for nullBool.
|
||||
func (nb nullBool) String() (s string) {
|
||||
switch nb {
|
||||
case nbNull:
|
||||
return "null"
|
||||
case nbTrue:
|
||||
return "true"
|
||||
case nbFalse:
|
||||
return "false"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("!invalid nullBool %d", uint8(nb))
|
||||
}
|
||||
|
||||
// boolToNullBool converts a bool into a nullBool.
|
||||
func boolToNullBool(cond bool) (nb nullBool) {
|
||||
if cond {
|
||||
return nbTrue
|
||||
}
|
||||
|
||||
return nbFalse
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for *nullBool.
|
||||
func (nb *nullBool) UnmarshalJSON(b []byte) (err error) {
|
||||
if len(b) == 0 || bytes.Equal(b, []byte("null")) {
|
||||
*nb = nbNull
|
||||
} else if bytes.Equal(b, []byte("true")) {
|
||||
*nb = nbTrue
|
||||
} else if bytes.Equal(b, []byte("false")) {
|
||||
*nb = nbFalse
|
||||
} else {
|
||||
return fmt.Errorf("invalid nullBool value %q", b)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNullBool_UnmarshalJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
data []byte
|
||||
want nullBool
|
||||
}{{
|
||||
name: "empty",
|
||||
wantErrMsg: "",
|
||||
data: []byte{},
|
||||
want: nbNull,
|
||||
}, {
|
||||
name: "null",
|
||||
wantErrMsg: "",
|
||||
data: []byte("null"),
|
||||
want: nbNull,
|
||||
}, {
|
||||
name: "true",
|
||||
wantErrMsg: "",
|
||||
data: []byte("true"),
|
||||
want: nbTrue,
|
||||
}, {
|
||||
name: "false",
|
||||
wantErrMsg: "",
|
||||
data: []byte("false"),
|
||||
want: nbFalse,
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErrMsg: `invalid nullBool value "invalid"`,
|
||||
data: []byte("invalid"),
|
||||
want: nbNull,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var got nullBool
|
||||
err := got.UnmarshalJSON(tc.data)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
want := nbTrue
|
||||
var got struct {
|
||||
A nullBool
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(`{"A":true}`), &got)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, want, got.A)
|
||||
})
|
||||
}
|
||||
@@ -18,17 +18,14 @@ import (
|
||||
|
||||
// The aliases for DHCP option types available for explicit declaration.
|
||||
const (
|
||||
hexTyp = "hex"
|
||||
ipTyp = "ip"
|
||||
ipsTyp = "ips"
|
||||
textTyp = "text"
|
||||
typHex = "hex"
|
||||
typIP = "ip"
|
||||
typIPs = "ips"
|
||||
typText = "text"
|
||||
typDel = "del"
|
||||
)
|
||||
|
||||
// parseDHCPOptionHex parses a DHCP option as a hex-encoded string. For
|
||||
// example:
|
||||
//
|
||||
// 252 hex 736f636b733a2f2f70726f78792e6578616d706c652e6f7267
|
||||
//
|
||||
// parseDHCPOptionHex parses a DHCP option as a hex-encoded string.
|
||||
func parseDHCPOptionHex(s string) (val dhcpv4.OptionValue, err error) {
|
||||
var data []byte
|
||||
data, err = hex.DecodeString(s)
|
||||
@@ -39,10 +36,7 @@ func parseDHCPOptionHex(s string) (val dhcpv4.OptionValue, err error) {
|
||||
return dhcpv4.OptionGeneric{Data: data}, nil
|
||||
}
|
||||
|
||||
// parseDHCPOptionIP parses a DHCP option as a single IP address. For example:
|
||||
//
|
||||
// 6 ip 192.168.1.1
|
||||
//
|
||||
// parseDHCPOptionIP parses a DHCP option as a single IP address.
|
||||
func parseDHCPOptionIP(s string) (val dhcpv4.OptionValue, err error) {
|
||||
var ip net.IP
|
||||
// All DHCPv4 options require IPv4, so don't put the 16-byte version.
|
||||
@@ -58,10 +52,7 @@ func parseDHCPOptionIP(s string) (val dhcpv4.OptionValue, err error) {
|
||||
}
|
||||
|
||||
// parseDHCPOptionIPs parses a DHCP option as a comma-separates list of IP
|
||||
// addresses. For example:
|
||||
//
|
||||
// 6 ips 192.168.1.1,192.168.1.2
|
||||
//
|
||||
// addresses.
|
||||
func parseDHCPOptionIPs(s string) (val dhcpv4.OptionValue, err error) {
|
||||
var ips dhcpv4.IPs
|
||||
var ip net.IP
|
||||
@@ -78,23 +69,53 @@ func parseDHCPOptionIPs(s string) (val dhcpv4.OptionValue, err error) {
|
||||
}
|
||||
|
||||
// parseDHCPOptionText parses a DHCP option as a simple UTF-8 encoded
|
||||
// text. For example:
|
||||
//
|
||||
// 252 text http://192.168.1.1/wpad.dat
|
||||
//
|
||||
// text.
|
||||
func parseDHCPOptionText(s string) (val dhcpv4.OptionValue) {
|
||||
return dhcpv4.OptionGeneric{Data: []byte(s)}
|
||||
}
|
||||
|
||||
// parseDHCPOption parses an option. See the documentation of parseDHCPOption*
|
||||
// for more info.
|
||||
// parseDHCPOptionVal parses a DHCP option value considering typ. For the del
|
||||
// option the value string is ignored. The examples of possible value pairs:
|
||||
//
|
||||
// - hex 736f636b733a2f2f70726f78792e6578616d706c652e6f7267
|
||||
// - ip 192.168.1.1
|
||||
// - ips 192.168.1.1,192.168.1.2
|
||||
// - text http://192.168.1.1/wpad.dat
|
||||
// - del
|
||||
func parseDHCPOptionVal(typ, valStr string) (val dhcpv4.OptionValue, err error) {
|
||||
switch typ {
|
||||
case typHex:
|
||||
val, err = parseDHCPOptionHex(valStr)
|
||||
case typIP:
|
||||
val, err = parseDHCPOptionIP(valStr)
|
||||
case typIPs:
|
||||
val, err = parseDHCPOptionIPs(valStr)
|
||||
case typText:
|
||||
val = parseDHCPOptionText(valStr)
|
||||
case typDel:
|
||||
val = dhcpv4.OptionGeneric{Data: nil}
|
||||
default:
|
||||
err = fmt.Errorf("unknown option type %q", typ)
|
||||
}
|
||||
|
||||
return val, err
|
||||
}
|
||||
|
||||
// parseDHCPOption parses an option. See the documentation of
|
||||
// parseDHCPOptionVal for more info.
|
||||
func parseDHCPOption(s string) (opt dhcpv4.Option, err error) {
|
||||
defer func() { err = errors.Annotate(err, "invalid option string %q: %w", s) }()
|
||||
|
||||
s = strings.TrimSpace(s)
|
||||
parts := strings.SplitN(s, " ", 3)
|
||||
if len(parts) < 3 {
|
||||
return opt, errors.Error("need at least three fields")
|
||||
|
||||
var valStr string
|
||||
if pl := len(parts); pl < 3 {
|
||||
if pl < 2 || parts[1] != typDel {
|
||||
return opt, errors.Error("bad option format")
|
||||
}
|
||||
} else {
|
||||
valStr = parts[2]
|
||||
}
|
||||
|
||||
var code64 uint64
|
||||
@@ -103,27 +124,16 @@ func parseDHCPOption(s string) (opt dhcpv4.Option, err error) {
|
||||
return opt, fmt.Errorf("parsing option code: %w", err)
|
||||
}
|
||||
|
||||
var optVal dhcpv4.OptionValue
|
||||
switch typ, val := parts[1], parts[2]; typ {
|
||||
case hexTyp:
|
||||
optVal, err = parseDHCPOptionHex(val)
|
||||
case ipTyp:
|
||||
optVal, err = parseDHCPOptionIP(val)
|
||||
case ipsTyp:
|
||||
optVal, err = parseDHCPOptionIPs(val)
|
||||
case textTyp:
|
||||
optVal = parseDHCPOptionText(val)
|
||||
default:
|
||||
return opt, fmt.Errorf("unknown option type %q", typ)
|
||||
}
|
||||
|
||||
val, err := parseDHCPOptionVal(parts[1], valStr)
|
||||
if err != nil {
|
||||
// Don't wrap an error since it's informative enough as is and there
|
||||
// also the deferred annotation.
|
||||
return opt, err
|
||||
}
|
||||
|
||||
return dhcpv4.Option{
|
||||
Code: dhcpv4.GenericOptionCode(code64),
|
||||
Value: optVal,
|
||||
Value: val,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -139,40 +149,45 @@ func prepareOptions(conf V4ServerConf) (opts dhcpv4.Options) {
|
||||
// See also https://datatracker.ietf.org/doc/html/rfc1122,
|
||||
// https://datatracker.ietf.org/doc/html/rfc1123, and
|
||||
// https://datatracker.ietf.org/doc/html/rfc2132.
|
||||
opts = dhcpv4.Options{
|
||||
opts = dhcpv4.OptionsFromList(
|
||||
// IP-Layer Per Host
|
||||
dhcpv4.OptionNonLocalSourceRouting.Code(): []byte{0},
|
||||
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionNonLocalSourceRouting, []byte{0}),
|
||||
|
||||
// Set the current recommended default time to live for the
|
||||
// Internet Protocol which is 64, see
|
||||
// https://datatracker.ietf.org/doc/html/rfc1700.
|
||||
dhcpv4.OptionDefaultIPTTL.Code(): []byte{64},
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionDefaultIPTTL, []byte{0x40}),
|
||||
|
||||
// IP-Layer Per Interface
|
||||
|
||||
dhcpv4.OptionPerformMaskDiscovery.Code(): []byte{0},
|
||||
dhcpv4.OptionMaskSupplier.Code(): []byte{0},
|
||||
dhcpv4.OptionPerformRouterDiscovery.Code(): []byte{1},
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionPerformMaskDiscovery, []byte{0}),
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionMaskSupplier, []byte{0}),
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionPerformRouterDiscovery, []byte{1}),
|
||||
// The all-routers address is preferred wherever possible, see
|
||||
// https://datatracker.ietf.org/doc/html/rfc1256#section-5.1.
|
||||
dhcpv4.OptionRouterSolicitationAddress.Code(): netutil.IPv4allrouter(),
|
||||
dhcpv4.OptionBroadcastAddress.Code(): netutil.IPv4bcast(),
|
||||
dhcpv4.Option{
|
||||
Code: dhcpv4.OptionRouterSolicitationAddress,
|
||||
Value: dhcpv4.IP(netutil.IPv4allrouter()),
|
||||
},
|
||||
dhcpv4.OptBroadcastAddress(netutil.IPv4bcast()),
|
||||
|
||||
// Link-Layer Per Interface
|
||||
|
||||
dhcpv4.OptionTrailerEncapsulation.Code(): []byte{0},
|
||||
dhcpv4.OptionEthernetEncapsulation.Code(): []byte{0},
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionTrailerEncapsulation, []byte{0}),
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionEthernetEncapsulation, []byte{0}),
|
||||
|
||||
// TCP Per Host
|
||||
|
||||
dhcpv4.OptionTCPKeepaliveInterval.Code(): dhcpv4.Duration(0).ToBytes(),
|
||||
dhcpv4.OptionTCPKeepaliveGarbage.Code(): []byte{0},
|
||||
dhcpv4.Option{
|
||||
Code: dhcpv4.OptionTCPKeepaliveInterval,
|
||||
Value: dhcpv4.Duration(0),
|
||||
},
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionTCPKeepaliveGarbage, []byte{0}),
|
||||
|
||||
// Values From Configuration
|
||||
|
||||
dhcpv4.OptionRouter.Code(): netutil.CloneIP(conf.subnet.IP),
|
||||
dhcpv4.OptionSubnetMask.Code(): dhcpv4.IPMask(conf.subnet.Mask).ToBytes(),
|
||||
}
|
||||
dhcpv4.OptRouter(conf.subnet.IP),
|
||||
dhcpv4.OptSubnetMask(conf.subnet.Mask),
|
||||
)
|
||||
|
||||
// Set values for explicitly configured options.
|
||||
for i, o := range conf.Options {
|
||||
|
||||
@@ -8,103 +8,111 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseOpt(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
in string
|
||||
wantErrMsg string
|
||||
wantOpt dhcpv4.Option
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "hex_success",
|
||||
in: "6 hex c0a80101c0a80102",
|
||||
wantErrMsg: "",
|
||||
wantOpt: dhcpv4.OptDNS(
|
||||
net.IP{0xC0, 0xA8, 0x01, 0x01},
|
||||
net.IP{0xC0, 0xA8, 0x01, 0x02},
|
||||
name: "hex_success",
|
||||
in: "6 hex c0a80101c0a80102",
|
||||
wantOpt: dhcpv4.OptGeneric(
|
||||
dhcpv4.GenericOptionCode(6),
|
||||
[]byte{
|
||||
0xC0, 0xA8, 0x01, 0x01,
|
||||
0xC0, 0xA8, 0x01, 0x02,
|
||||
},
|
||||
),
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "ip_success",
|
||||
in: "6 ip 1.2.3.4",
|
||||
name: "ip_success",
|
||||
in: "6 ip 1.2.3.4",
|
||||
wantOpt: dhcpv4.Option{
|
||||
Code: dhcpv4.GenericOptionCode(6),
|
||||
Value: dhcpv4.IP(net.IP{0x01, 0x02, 0x03, 0x04}),
|
||||
},
|
||||
wantErrMsg: "",
|
||||
wantOpt: dhcpv4.OptDNS(
|
||||
net.IP{0x01, 0x02, 0x03, 0x04},
|
||||
),
|
||||
}, {
|
||||
name: "ip_fail_v6",
|
||||
in: "6 ip ::1234",
|
||||
wantErrMsg: "invalid option string \"6 ip ::1234\": bad ipv4 address \"::1234\"",
|
||||
wantOpt: dhcpv4.Option{},
|
||||
wantErrMsg: "invalid option string \"6 ip ::1234\": bad ipv4 address \"::1234\"",
|
||||
}, {
|
||||
name: "ips_success",
|
||||
in: "6 ips 192.168.1.1,192.168.1.2",
|
||||
name: "ips_success",
|
||||
in: "6 ips 192.168.1.1,192.168.1.2",
|
||||
wantOpt: dhcpv4.Option{
|
||||
Code: dhcpv4.GenericOptionCode(6),
|
||||
Value: dhcpv4.IPs([]net.IP{
|
||||
{0xC0, 0xA8, 0x01, 0x01},
|
||||
{0xC0, 0xA8, 0x01, 0x02},
|
||||
}),
|
||||
},
|
||||
wantErrMsg: "",
|
||||
wantOpt: dhcpv4.OptDNS(
|
||||
net.IP{0xC0, 0xA8, 0x01, 0x01},
|
||||
net.IP{0xC0, 0xA8, 0x01, 0x02},
|
||||
),
|
||||
}, {
|
||||
name: "text_success",
|
||||
in: "252 text http://192.168.1.1/",
|
||||
wantErrMsg: "",
|
||||
name: "text_success",
|
||||
in: "252 text http://192.168.1.1/",
|
||||
wantOpt: dhcpv4.OptGeneric(
|
||||
dhcpv4.GenericOptionCode(252),
|
||||
[]byte("http://192.168.1.1/"),
|
||||
),
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "del_success",
|
||||
in: "61 del",
|
||||
wantOpt: dhcpv4.Option{
|
||||
Code: dhcpv4.GenericOptionCode(dhcpv4.OptionClientIdentifier),
|
||||
Value: dhcpv4.OptionGeneric{Data: nil},
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "bad_parts",
|
||||
in: "6 ip",
|
||||
wantErrMsg: `invalid option string "6 ip": need at least three fields`,
|
||||
wantOpt: dhcpv4.Option{},
|
||||
wantErrMsg: `invalid option string "6 ip": bad option format`,
|
||||
}, {
|
||||
name: "bad_code",
|
||||
in: "256 ip 1.1.1.1",
|
||||
name: "bad_code",
|
||||
in: "256 ip 1.1.1.1",
|
||||
wantOpt: dhcpv4.Option{},
|
||||
wantErrMsg: `invalid option string "256 ip 1.1.1.1": parsing option code: ` +
|
||||
`strconv.ParseUint: parsing "256": value out of range`,
|
||||
wantOpt: dhcpv4.Option{},
|
||||
}, {
|
||||
name: "bad_type",
|
||||
in: "6 bad 1.1.1.1",
|
||||
wantErrMsg: `invalid option string "6 bad 1.1.1.1": unknown option type "bad"`,
|
||||
wantOpt: dhcpv4.Option{},
|
||||
wantErrMsg: `invalid option string "6 bad 1.1.1.1": unknown option type "bad"`,
|
||||
}, {
|
||||
name: "hex_error",
|
||||
in: "6 hex ZZZ",
|
||||
name: "hex_error",
|
||||
in: "6 hex ZZZ",
|
||||
wantOpt: dhcpv4.Option{},
|
||||
wantErrMsg: `invalid option string "6 hex ZZZ": decoding hex: ` +
|
||||
`encoding/hex: invalid byte: U+005A 'Z'`,
|
||||
wantOpt: dhcpv4.Option{},
|
||||
}, {
|
||||
name: "ip_error",
|
||||
in: "6 ip 1.2.3.x",
|
||||
wantErrMsg: "invalid option string \"6 ip 1.2.3.x\": bad ipv4 address \"1.2.3.x\"",
|
||||
wantOpt: dhcpv4.Option{},
|
||||
wantErrMsg: "invalid option string \"6 ip 1.2.3.x\": bad ipv4 address \"1.2.3.x\"",
|
||||
}, {
|
||||
name: "ips_error",
|
||||
in: "6 ips 192.168.1.1,192.168.1.x",
|
||||
name: "ips_error",
|
||||
in: "6 ips 192.168.1.1,192.168.1.x",
|
||||
wantOpt: dhcpv4.Option{},
|
||||
wantErrMsg: "invalid option string \"6 ips 192.168.1.1,192.168.1.x\": " +
|
||||
"parsing ip at index 1: bad ipv4 address \"192.168.1.x\"",
|
||||
wantOpt: dhcpv4.Option{},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
opt, err := parseDHCPOption(tc.in)
|
||||
if tc.wantErrMsg != "" {
|
||||
require.Error(t, err)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantOpt.Code.Code(), opt.Code.Code())
|
||||
assert.Equal(t, tc.wantOpt.Value.ToBytes(), opt.Value.ToBytes())
|
||||
// assert.Equal(t, tc.wantOpt.Code.Code(), opt.Code.Code())
|
||||
// assert.Equal(t, tc.wantOpt.Value.ToBytes(), opt.Value.ToBytes())
|
||||
assert.Equal(t, tc.wantOpt, opt)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -127,51 +135,80 @@ func TestPrepareOptions(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
opts []string
|
||||
checks dhcpv4.Options
|
||||
opts []string
|
||||
}{{
|
||||
name: "all_default",
|
||||
checks: allDefault,
|
||||
opts: nil,
|
||||
}, {
|
||||
name: "configured_ip",
|
||||
opts: []string{
|
||||
fmt.Sprintf("%d ip %s", dhcpv4.OptionBroadcastAddress, oneIP),
|
||||
},
|
||||
checks: dhcpv4.Options{
|
||||
dhcpv4.OptionBroadcastAddress.Code(): oneIP,
|
||||
},
|
||||
opts: []string{
|
||||
fmt.Sprintf("%d ip %s", dhcpv4.OptionBroadcastAddress, oneIP),
|
||||
},
|
||||
}, {
|
||||
name: "configured_ips",
|
||||
opts: []string{
|
||||
fmt.Sprintf("%d ips %s,%s", dhcpv4.OptionDomainNameServer, oneIP, otherIP),
|
||||
},
|
||||
checks: dhcpv4.Options{
|
||||
dhcpv4.OptionDomainNameServer.Code(): append(oneIP, otherIP...),
|
||||
},
|
||||
opts: []string{
|
||||
fmt.Sprintf("%d ips %s,%s", dhcpv4.OptionDomainNameServer, oneIP, otherIP),
|
||||
},
|
||||
}, {
|
||||
name: "configured_bad",
|
||||
name: "configured_bad",
|
||||
checks: allDefault,
|
||||
opts: []string{
|
||||
"20 hex",
|
||||
"23 hex abc",
|
||||
"32 ips 1,2,3,4",
|
||||
"28 256.256.256.256",
|
||||
},
|
||||
checks: allDefault,
|
||||
}, {
|
||||
name: "configured_del",
|
||||
checks: dhcpv4.Options{
|
||||
dhcpv4.OptionBroadcastAddress.Code(): nil,
|
||||
},
|
||||
opts: []string{
|
||||
"28 del",
|
||||
},
|
||||
}, {
|
||||
name: "rewritten_del",
|
||||
checks: dhcpv4.Options{
|
||||
dhcpv4.OptionBroadcastAddress.Code(): []byte{255, 255, 255, 255},
|
||||
},
|
||||
opts: []string{
|
||||
"28 del",
|
||||
"28 ip 255.255.255.255",
|
||||
},
|
||||
}, {
|
||||
name: "configured_and_del",
|
||||
checks: dhcpv4.Options{
|
||||
123: []byte("cba"),
|
||||
},
|
||||
opts: []string{
|
||||
"123 text abc",
|
||||
"123 del",
|
||||
"123 text cba",
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.name == "configured_del" {
|
||||
assert.True(t, true)
|
||||
}
|
||||
opts := prepareOptions(V4ServerConf{
|
||||
// Just to avoid nil pointer dereference.
|
||||
subnet: &net.IPNet{},
|
||||
Options: tc.opts,
|
||||
})
|
||||
for c, v := range tc.checks {
|
||||
optVal := opts.Get(dhcpv4.GenericOptionCode(c))
|
||||
require.NotNil(t, optVal)
|
||||
|
||||
assert.Len(t, optVal, len(v))
|
||||
assert.Equal(t, v, optVal)
|
||||
val := opts.Get(dhcpv4.GenericOptionCode(c))
|
||||
assert.Lenf(t, val, len(v), "Code: %v", c)
|
||||
assert.Equal(t, v, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,9 @@ import (
|
||||
"github.com/go-ping/ping"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
@@ -91,6 +94,9 @@ func (s *v4Server) validHostnameForClient(cliHostname string, ip net.IP) (hostna
|
||||
|
||||
if hostname == "" {
|
||||
hostname = aghnet.GenerateHostname(ip)
|
||||
} else if s.leaseHosts.Has(hostname) {
|
||||
log.Info("dhcpv4: hostname %q already exists", hostname)
|
||||
hostname = aghnet.GenerateHostname(ip)
|
||||
}
|
||||
|
||||
err = netutil.ValidateDomainName(hostname)
|
||||
@@ -250,11 +256,11 @@ func (s *v4Server) rmLeaseByIndex(i int) {
|
||||
// Remove a dynamic lease with the same properties
|
||||
// Return error if a static lease is found
|
||||
func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
for i := 0; i < len(s.leases); i++ {
|
||||
l := s.leases[i]
|
||||
for i, l := range s.leases {
|
||||
isStatic := l.IsStatic()
|
||||
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) {
|
||||
if l.IsStatic() {
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) || l.IP.Equal(lease.IP) {
|
||||
if isStatic {
|
||||
return errors.Error("static lease already exists")
|
||||
}
|
||||
|
||||
@@ -266,20 +272,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
l = s.leases[i]
|
||||
}
|
||||
|
||||
if l.IP.Equal(lease.IP) {
|
||||
if l.IsStatic() {
|
||||
return errors.Error("static lease already exists")
|
||||
}
|
||||
|
||||
s.rmLeaseByIndex(i)
|
||||
if i == len(s.leases) {
|
||||
break
|
||||
}
|
||||
|
||||
l = s.leases[i]
|
||||
}
|
||||
|
||||
if l.Hostname == lease.Hostname {
|
||||
if !isStatic && l.Hostname == lease.Hostname {
|
||||
l.Hostname = ""
|
||||
}
|
||||
}
|
||||
@@ -287,6 +280,10 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrDupHostname is returned by addLease when the added lease has a not empty
|
||||
// non-unique hostname.
|
||||
const ErrDupHostname = errors.Error("hostname is not unique")
|
||||
|
||||
// addLease adds a dynamic or static lease.
|
||||
func (s *v4Server) addLease(l *Lease) (err error) {
|
||||
r := s.conf.ipRange
|
||||
@@ -302,13 +299,17 @@ func (s *v4Server) addLease(l *Lease) (err error) {
|
||||
return fmt.Errorf("lease %s (%s) out of range, not adding", l.IP, l.HWAddr)
|
||||
}
|
||||
|
||||
s.leases = append(s.leases, l)
|
||||
s.leasedOffsets.set(offset, true)
|
||||
|
||||
if l.Hostname != "" {
|
||||
if s.leaseHosts.Has(l.Hostname) {
|
||||
return ErrDupHostname
|
||||
}
|
||||
|
||||
s.leaseHosts.Add(l.Hostname)
|
||||
}
|
||||
|
||||
s.leases = append(s.leases, l)
|
||||
s.leasedOffsets.set(offset, true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -333,12 +334,16 @@ func (s *v4Server) rmLease(lease *Lease) (err error) {
|
||||
return errors.Error("lease not found")
|
||||
}
|
||||
|
||||
// AddStaticLease adds a static lease. It is safe for concurrent use.
|
||||
// AddStaticLease implements the DHCPServer interface for *v4Server. It is safe
|
||||
// for concurrent use.
|
||||
func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: adding static lease: %w") }()
|
||||
|
||||
if ip4 := l.IP.To4(); ip4 == nil {
|
||||
ip := l.IP.To4()
|
||||
if ip == nil {
|
||||
return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
|
||||
} else if gwIP := s.conf.GatewayIP; gwIP.Equal(ip) {
|
||||
return fmt.Errorf("can't assign the gateway IP %s to the lease", gwIP)
|
||||
}
|
||||
|
||||
l.Expiry = time.Unix(leaseExpireStatic, 0)
|
||||
@@ -359,10 +364,11 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
return fmt.Errorf("validating hostname: %w", err)
|
||||
}
|
||||
|
||||
// Don't check for hostname uniqueness, since we try to emulate
|
||||
// dnsmasq here, which means that rmDynamicLease below will
|
||||
// simply empty the hostname of the dynamic lease if there even
|
||||
// is one.
|
||||
// Don't check for hostname uniqueness, since we try to emulate dnsmasq
|
||||
// here, which means that rmDynamicLease below will simply empty the
|
||||
// hostname of the dynamic lease if there even is one. In case a static
|
||||
// lease with the same name already exists, addLease will return an
|
||||
// error and the lease won't be added.
|
||||
|
||||
l.Hostname = hostname
|
||||
}
|
||||
@@ -377,7 +383,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
if err != nil {
|
||||
err = fmt.Errorf(
|
||||
"removing dynamic leases for %s (%s): %w",
|
||||
l.IP,
|
||||
ip,
|
||||
l.HWAddr,
|
||||
err,
|
||||
)
|
||||
@@ -387,7 +393,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
|
||||
err = s.addLease(l)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("adding static lease for %s (%s): %w", l.IP, l.HWAddr, err)
|
||||
err = fmt.Errorf("adding static lease for %s (%s): %w", ip, l.HWAddr, err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -517,11 +523,7 @@ func (s *v4Server) findExpiredLease() int {
|
||||
// reserveLease reserves a lease for a client by its MAC-address. It returns
|
||||
// nil if it couldn't allocate a new lease.
|
||||
func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease, err error) {
|
||||
l = &Lease{
|
||||
HWAddr: make([]byte, len(mac)),
|
||||
}
|
||||
|
||||
copy(l.HWAddr, mac)
|
||||
l = &Lease{HWAddr: slices.Clone(mac)}
|
||||
|
||||
l.IP = s.nextIP()
|
||||
if l.IP == nil {
|
||||
@@ -614,33 +616,25 @@ func (s *v4Server) processDiscover(req, resp *dhcpv4.DHCPv4) (l *Lease, err erro
|
||||
return l, nil
|
||||
}
|
||||
|
||||
type optFQDN struct {
|
||||
name string
|
||||
}
|
||||
// OptionFQDN returns a DHCPv4 option for sending the FQDN to the client
|
||||
// requested another hostname.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc4702.
|
||||
func OptionFQDN(fqdn string) (opt dhcpv4.Option) {
|
||||
optData := []byte{
|
||||
// Set only S and O DHCP client FQDN option flags.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc4702#section-2.1.
|
||||
1<<0 | 1<<1,
|
||||
// The RCODE fields should be set to 0xFF in the server responses.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc4702#section-2.2.
|
||||
0xFF,
|
||||
0xFF,
|
||||
}
|
||||
optData = append(optData, fqdn...)
|
||||
|
||||
func (o *optFQDN) String() string {
|
||||
return "optFQDN"
|
||||
}
|
||||
|
||||
// flags[1]
|
||||
// A-RR[1]
|
||||
// PTR-RR[1]
|
||||
// name[]
|
||||
func (o *optFQDN) ToBytes() []byte {
|
||||
b := make([]byte, 3+len(o.name))
|
||||
i := 0
|
||||
|
||||
b[i] = 0x03 // f_server_overrides | f_server
|
||||
i++
|
||||
|
||||
b[i] = 255 // A-RR
|
||||
i++
|
||||
|
||||
b[i] = 255 // PTR-RR
|
||||
i++
|
||||
|
||||
copy(b[i:], []byte(o.name))
|
||||
return b
|
||||
return dhcpv4.OptGeneric(dhcpv4.OptionFQDN, optData)
|
||||
}
|
||||
|
||||
// checkLease checks if the pair of mac and ip is already leased. The mismatch
|
||||
@@ -673,6 +667,8 @@ func (s *v4Server) checkLease(mac net.HardwareAddr, ip net.IP) (lease *Lease, mi
|
||||
// processRequest is the handler for the DHCP Request request.
|
||||
func (s *v4Server) processRequest(req, resp *dhcpv4.DHCPv4) (lease *Lease, needsReply bool) {
|
||||
mac := req.ClientHWAddr
|
||||
// TODO(e.burkov): The IP address can only be requested in DHCPDISCOVER
|
||||
// message.
|
||||
reqIP := req.RequestedIPAddress()
|
||||
if reqIP == nil {
|
||||
reqIP = req.ClientIPAddr
|
||||
@@ -705,24 +701,17 @@ func (s *v4Server) processRequest(req, resp *dhcpv4.DHCPv4) (lease *Lease, needs
|
||||
if !lease.IsStatic() {
|
||||
cliHostname := req.HostName()
|
||||
hostname := s.validHostnameForClient(cliHostname, reqIP)
|
||||
if hostname != lease.Hostname && s.leaseHosts.Has(hostname) {
|
||||
log.Info("dhcpv4: hostname %q already exists", hostname)
|
||||
lease.Hostname = ""
|
||||
} else {
|
||||
if lease.Hostname != hostname {
|
||||
lease.Hostname = hostname
|
||||
resp.UpdateOption(dhcpv4.OptHostName(hostname))
|
||||
}
|
||||
|
||||
s.commitLease(lease)
|
||||
} else if lease.Hostname != "" {
|
||||
o := &optFQDN{
|
||||
name: lease.Hostname,
|
||||
}
|
||||
fqdn := dhcpv4.Option{
|
||||
Code: dhcpv4.OptionFQDN,
|
||||
Value: o,
|
||||
}
|
||||
|
||||
resp.UpdateOption(fqdn)
|
||||
// TODO(e.burkov): This option is used to update the server's DNS
|
||||
// mapping. The option should only be answered when it has been
|
||||
// requested.
|
||||
resp.UpdateOption(OptionFQDN(lease.Hostname))
|
||||
}
|
||||
|
||||
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck))
|
||||
@@ -845,7 +834,7 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
|
||||
|
||||
// TODO(a.garipov): Refactor this into handlers.
|
||||
var l *Lease
|
||||
switch req.MessageType() {
|
||||
switch mt := req.MessageType(); mt {
|
||||
case dhcpv4.MessageTypeDiscover:
|
||||
l, err = s.processDiscover(req, resp)
|
||||
if err != nil {
|
||||
@@ -886,28 +875,32 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
|
||||
resp.YourIPAddr = netutil.CloneIP(l.IP)
|
||||
}
|
||||
|
||||
// Set IP address lease time for all DHCPOFFER messages and DHCPACK
|
||||
// messages replied for DHCPREQUEST.
|
||||
// Set IP address lease time for all DHCPOFFER messages and DHCPACK messages
|
||||
// replied for DHCPREQUEST.
|
||||
//
|
||||
// TODO(e.burkov): Inspect why this is always set to configured value.
|
||||
resp.UpdateOption(dhcpv4.OptIPAddressLeaseTime(s.conf.leaseTime))
|
||||
|
||||
// Delete options explicitly configured to be removed.
|
||||
for code := range resp.Options {
|
||||
if val, ok := s.options[code]; ok && val == nil {
|
||||
delete(resp.Options, code)
|
||||
}
|
||||
}
|
||||
|
||||
// Update values for each explicitly configured parameter requested by
|
||||
// client.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc2131#section-4.3.1.
|
||||
requested := req.ParameterRequestList()
|
||||
for _, code := range requested {
|
||||
if configured := s.options; configured.Has(code) {
|
||||
resp.UpdateOption(dhcpv4.OptGeneric(code, configured.Get(code)))
|
||||
if val := s.options.Get(code); val != nil {
|
||||
resp.UpdateOption(dhcpv4.Option{
|
||||
Code: code,
|
||||
Value: dhcpv4.OptionGeneric{Data: s.options.Get(code)},
|
||||
})
|
||||
}
|
||||
}
|
||||
// Update the value of Domain Name Server option separately from others if
|
||||
// not assigned yet since its value is set after server's creating.
|
||||
if requested.Has(dhcpv4.OptionDomainNameServer) &&
|
||||
!resp.Options.Has(dhcpv4.OptionDomainNameServer) {
|
||||
resp.UpdateOption(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
@@ -935,6 +928,7 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4
|
||||
resp, err := dhcpv4.NewReplyFromRequest(req)
|
||||
if err != nil {
|
||||
log.Debug("dhcpv4: dhcpv4.New: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1031,6 +1025,11 @@ func (s *v4Server) Start() (err error) {
|
||||
// No available IP addresses which may appear later.
|
||||
return nil
|
||||
}
|
||||
// Update the value of Domain Name Server option separately from others if
|
||||
// not assigned yet since its value is available only at server's start.
|
||||
if !s.options.Has(dhcpv4.OptionDomainNameServer) {
|
||||
s.options.Update(dhcpv4.OptDNS(dnsIPAddrs...))
|
||||
}
|
||||
|
||||
s.conf.dnsIPAddrs = dnsIPAddrs
|
||||
|
||||
|
||||
@@ -4,16 +4,29 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/mdlayher/raw"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultRangeStart = net.IP{192, 168, 10, 100}
|
||||
DefaultRangeEnd = net.IP{192, 168, 10, 200}
|
||||
DefaultGatewayIP = net.IP{192, 168, 10, 1}
|
||||
DefaultSelfIP = net.IP{192, 168, 10, 2}
|
||||
DefaultSubnetMask = net.IP{255, 255, 255, 0}
|
||||
)
|
||||
|
||||
func notify4(flags uint32) {
|
||||
@@ -24,11 +37,12 @@ func notify4(flags uint32) {
|
||||
func defaultV4ServerConf() (conf V4ServerConf) {
|
||||
return V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
RangeStart: DefaultRangeStart,
|
||||
RangeEnd: DefaultRangeEnd,
|
||||
GatewayIP: DefaultGatewayIP,
|
||||
SubnetMask: DefaultSubnetMask,
|
||||
notify: notify4,
|
||||
dnsIPAddrs: []net.IP{DefaultSelfIP},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,44 +58,228 @@ func defaultSrv(t *testing.T) (s DHCPServer) {
|
||||
return s
|
||||
}
|
||||
|
||||
func TestV4_AddRemove_static(t *testing.T) {
|
||||
func TestV4Server_leasing(t *testing.T) {
|
||||
const (
|
||||
staticName = "static-client"
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
staticIP := net.IP{192, 168, 10, 10}
|
||||
anotherIP := DefaultRangeStart
|
||||
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
|
||||
|
||||
s := defaultSrv(t)
|
||||
|
||||
t.Run("add_static", func(t *testing.T) {
|
||||
err := s.AddStaticLease(&Lease{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("same_name", func(t *testing.T) {
|
||||
err = s.AddStaticLease(&Lease{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: anotherIP,
|
||||
})
|
||||
assert.ErrorIs(t, err, ErrDupHostname)
|
||||
})
|
||||
|
||||
t.Run("same_mac", func(t *testing.T) {
|
||||
wantErrMsg := "dhcpv4: adding static lease: removing " +
|
||||
"dynamic leases for " + anotherIP.String() +
|
||||
" (" + staticMAC.String() + "): static lease already exists"
|
||||
|
||||
err = s.AddStaticLease(&Lease{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: staticMAC,
|
||||
IP: anotherIP,
|
||||
})
|
||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||
})
|
||||
|
||||
t.Run("same_ip", func(t *testing.T) {
|
||||
wantErrMsg := "dhcpv4: adding static lease: removing " +
|
||||
"dynamic leases for " + staticIP.String() +
|
||||
" (" + anotherMAC.String() + "): static lease already exists"
|
||||
|
||||
err = s.AddStaticLease(&Lease{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: staticIP,
|
||||
})
|
||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("add_dynamic", func(t *testing.T) {
|
||||
s4, ok := s.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
discoverAnOffer := func(
|
||||
t *testing.T,
|
||||
name string,
|
||||
ip net.IP,
|
||||
mac net.HardwareAddr,
|
||||
) (resp *dhcpv4.DHCPv4) {
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return s.ResetLeases(s.GetLeases(LeasesStatic))
|
||||
})
|
||||
|
||||
req, err := dhcpv4.NewDiscovery(
|
||||
mac,
|
||||
dhcpv4.WithOption(dhcpv4.OptHostName(name)),
|
||||
dhcpv4.WithOption(dhcpv4.OptRequestedIPAddress(ip)),
|
||||
dhcpv4.WithOption(dhcpv4.OptClientIdentifier([]byte{1, 2, 3, 4, 5, 6, 8})),
|
||||
dhcpv4.WithGatewayIP(DefaultGatewayIP),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp = &dhcpv4.DHCPv4{}
|
||||
res := s4.process(req, resp)
|
||||
require.Positive(t, res)
|
||||
require.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
|
||||
resp.ClientHWAddr = mac
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
t.Run("same_name", func(t *testing.T) {
|
||||
resp := discoverAnOffer(t, staticName, anotherIP, anotherMAC)
|
||||
|
||||
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
|
||||
dhcpv4.OptHostName(staticName),
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
res := s4.process(req, resp)
|
||||
require.Positive(t, res)
|
||||
|
||||
assert.Equal(t, aghnet.GenerateHostname(resp.YourIPAddr), resp.HostName())
|
||||
})
|
||||
|
||||
t.Run("same_mac", func(t *testing.T) {
|
||||
resp := discoverAnOffer(t, anotherName, anotherIP, staticMAC)
|
||||
|
||||
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
|
||||
dhcpv4.OptHostName(anotherName),
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
res := s4.process(req, resp)
|
||||
require.Positive(t, res)
|
||||
|
||||
fqdnOptData := resp.Options.Get(dhcpv4.OptionFQDN)
|
||||
require.Len(t, fqdnOptData, 3+len(staticName))
|
||||
assert.Equal(t, []uint8(staticName), fqdnOptData[3:])
|
||||
|
||||
assert.Equal(t, staticIP, resp.YourIPAddr)
|
||||
})
|
||||
|
||||
t.Run("same_ip", func(t *testing.T) {
|
||||
resp := discoverAnOffer(t, anotherName, staticIP, anotherMAC)
|
||||
|
||||
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
|
||||
dhcpv4.OptHostName(anotherName),
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
res := s4.process(req, resp)
|
||||
require.Positive(t, res)
|
||||
|
||||
assert.NotEqual(t, staticIP, resp.YourIPAddr)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
s := defaultSrv(t)
|
||||
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Empty(t, ls)
|
||||
require.Empty(t, ls)
|
||||
|
||||
// Add static lease.
|
||||
l := &Lease{
|
||||
Hostname: "static-1.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
testCases := []struct {
|
||||
lease *Lease
|
||||
name string
|
||||
wantErrMsg string
|
||||
}{{
|
||||
lease: &Lease{
|
||||
Hostname: "success.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
},
|
||||
name: "success",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
lease: &Lease{
|
||||
Hostname: "probably-router.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: DefaultGatewayIP,
|
||||
},
|
||||
name: "with_gateway_ip",
|
||||
wantErrMsg: "dhcpv4: adding static lease: " +
|
||||
"can't assign the gateway IP 192.168.10.1 to the lease",
|
||||
}, {
|
||||
lease: &Lease{
|
||||
Hostname: "ip6.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.ParseIP("ffff::1"),
|
||||
},
|
||||
name: "ipv6",
|
||||
wantErrMsg: `dhcpv4: adding static lease: ` +
|
||||
`invalid ip "ffff::1", only ipv4 is supported`,
|
||||
}, {
|
||||
lease: &Lease{
|
||||
Hostname: "bad-mac.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
},
|
||||
name: "bad_mac",
|
||||
wantErrMsg: `dhcpv4: adding static lease: bad mac address "aa:aa": ` +
|
||||
`bad mac address length 2, allowed: [6 8 20]`,
|
||||
}, {
|
||||
lease: &Lease{
|
||||
Hostname: "bad-lbl-.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
},
|
||||
name: "bad_hostname",
|
||||
wantErrMsg: `dhcpv4: adding static lease: validating hostname: ` +
|
||||
`bad domain name "bad-lbl-.local": ` +
|
||||
`bad domain name label "bad-lbl-": bad domain name label rune '-'`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := s.AddStaticLease(tc.lease)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
if tc.wantErrMsg != "" {
|
||||
return
|
||||
}
|
||||
|
||||
err = s.RemoveStaticLease(&Lease{
|
||||
IP: tc.lease.IP,
|
||||
HWAddr: tc.lease.HWAddr,
|
||||
})
|
||||
diffErrMsg := fmt.Sprintf("dhcpv4: lease for ip %s is different: %+v", tc.lease.IP, tc.lease)
|
||||
testutil.AssertErrorMsg(t, diffErrMsg, err)
|
||||
|
||||
// Remove static lease.
|
||||
err = s.RemoveStaticLease(tc.lease)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
require.Emptyf(t, ls, "after %s", tc.name)
|
||||
}
|
||||
|
||||
err := s.AddStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.AddStaticLease(l)
|
||||
assert.Error(t, err)
|
||||
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
require.Len(t, ls, 1)
|
||||
|
||||
assert.True(t, l.IP.Equal(ls[0].IP))
|
||||
assert.Equal(t, l.HWAddr, ls[0].HWAddr)
|
||||
assert.True(t, ls[0].IsStatic())
|
||||
|
||||
// Try to remove static lease.
|
||||
err = s.RemoveStaticLease(&Lease{
|
||||
IP: net.IP{192, 168, 10, 110},
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Remove static lease.
|
||||
err = s.RemoveStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
assert.Empty(t, ls)
|
||||
}
|
||||
|
||||
func TestV4_AddReplace(t *testing.T) {
|
||||
@@ -147,6 +345,8 @@ func TestV4Server_Process_optionsPriority(t *testing.T) {
|
||||
stringutil.WriteToBuilder(b, ",", ip.String())
|
||||
}
|
||||
conf.Options = []string{b.String()}
|
||||
} else {
|
||||
defer func() { s.options.Update(dhcpv4.OptDNS(defaultIP)) }()
|
||||
}
|
||||
|
||||
ss, err := v4Create(conf)
|
||||
@@ -209,6 +409,7 @@ func TestV4StaticLease_Get(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
|
||||
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
|
||||
s.options.Update(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
|
||||
|
||||
l := &Lease{
|
||||
Hostname: "static-1.local",
|
||||
@@ -297,6 +498,7 @@ func TestV4DynamicLease_Get(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
|
||||
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
|
||||
s.options.Update(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
|
||||
|
||||
var req, resp *dhcpv4.DHCPv4
|
||||
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
@@ -214,7 +214,7 @@ func validateAccessSet(list *accessListJSON) (err error) {
|
||||
}
|
||||
|
||||
merged := allowed.Merge(disallowed)
|
||||
err = merged.Validate(aghalg.StringIsBefore)
|
||||
err = merged.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("items in allowed and disallowed clients intersect: %w", err)
|
||||
}
|
||||
@@ -223,13 +223,13 @@ func validateAccessSet(list *accessListJSON) (err error) {
|
||||
}
|
||||
|
||||
// validateStrUniq returns an informative error if clients are not unique.
|
||||
func validateStrUniq(clients []string) (uc aghalg.UniqChecker, err error) {
|
||||
uc = make(aghalg.UniqChecker, len(clients))
|
||||
func validateStrUniq(clients []string) (uc aghalg.UniqChecker[string], err error) {
|
||||
uc = make(aghalg.UniqChecker[string], len(clients))
|
||||
for _, c := range clients {
|
||||
uc.Add(c)
|
||||
}
|
||||
|
||||
return uc, uc.Validate(aghalg.StringIsBefore)
|
||||
return uc, uc.Validate()
|
||||
}
|
||||
|
||||
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -65,7 +65,7 @@ func clientIDFromClientServerName(
|
||||
return "", err
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
return strings.ToLower(clientID), nil
|
||||
}
|
||||
|
||||
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
|
||||
@@ -104,7 +104,7 @@ func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err e
|
||||
return "", fmt.Errorf("clientid check: %w", err)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
return strings.ToLower(clientID), nil
|
||||
}
|
||||
|
||||
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
|
||||
@@ -112,8 +112,8 @@ type tlsConn interface {
|
||||
ConnectionState() (cs tls.ConnectionState)
|
||||
}
|
||||
|
||||
// quicSession is a narrow interface for quic.Session to simplify testing.
|
||||
type quicSession interface {
|
||||
// quicConnection is a narrow interface for quic.Connection to simplify testing.
|
||||
type quicConnection interface {
|
||||
ConnectionState() (cs quic.ConnectionState)
|
||||
}
|
||||
|
||||
@@ -148,16 +148,16 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
|
||||
|
||||
cliSrvName = tc.ConnectionState().ServerName
|
||||
case proxy.ProtoQUIC:
|
||||
qs, ok := pctx.QUICSession.(quicSession)
|
||||
conn, ok := pctx.QUICConnection.(quicConnection)
|
||||
if !ok {
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx quic session of proto %s is %T, want quic.Session",
|
||||
"proxy ctx quic conn of proto %s is %T, want quic.Connection",
|
||||
proto,
|
||||
pctx.QUICSession,
|
||||
pctx.QUICConnection,
|
||||
)
|
||||
}
|
||||
|
||||
cliSrvName = qs.ConnectionState().TLS.ServerName
|
||||
cliSrvName = conn.ConnectionState().TLS.ServerName
|
||||
}
|
||||
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
|
||||
@@ -29,17 +29,18 @@ func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) {
|
||||
return cs
|
||||
}
|
||||
|
||||
// testQUICSession is a quicSession for tests.
|
||||
type testQUICSession struct {
|
||||
// Session is embedded here simply to make testQUICSession a quic.Session
|
||||
// without actually implementing all methods.
|
||||
quic.Session
|
||||
// testQUICConnection is a quicConnection for tests.
|
||||
type testQUICConnection struct {
|
||||
// Connection is embedded here simply to make testQUICConnection a
|
||||
// quic.Connection without actually implementing all methods.
|
||||
quic.Connection
|
||||
|
||||
serverName string
|
||||
}
|
||||
|
||||
// ConnectionState implements the quicSession interface for testQUICSession.
|
||||
func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
|
||||
// ConnectionState implements the quicConnection interface for
|
||||
// testQUICConnection.
|
||||
func (c testQUICConnection) ConnectionState() (cs quic.ConnectionState) {
|
||||
cs.TLS.ServerName = c.serverName
|
||||
|
||||
return cs
|
||||
@@ -143,6 +144,22 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantErrMsg: `clientid check: client server name "cli.myexample.com" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_case",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "InSeNsItIvE.example.com",
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "quic_case",
|
||||
proto: proxy.ProtoQUIC,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "InSeNsItIvE.example.com",
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
strictSNI: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -163,17 +180,17 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var qs quic.Session
|
||||
var qconn quic.Connection
|
||||
if tc.proto == proxy.ProtoQUIC {
|
||||
qs = testQUICSession{
|
||||
qconn = testQUICConnection{
|
||||
serverName: tc.cliSrvName,
|
||||
}
|
||||
}
|
||||
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICSession: qs,
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICConnection: qconn,
|
||||
}
|
||||
|
||||
clientID, err := srv.clientIDFromDNSContext(pctx)
|
||||
@@ -210,6 +227,11 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||
path: "/dns-query/cli/",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "clientid_case",
|
||||
path: "/dns-query/InSeNsItIvE",
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
}, {
|
||||
name: "bad_url",
|
||||
path: "/foo",
|
||||
|
||||
@@ -5,12 +5,12 @@ import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -122,6 +122,7 @@ type FilteringConfig struct {
|
||||
EnableDNSSEC bool `yaml:"enable_dnssec"` // Set AD flag in outcoming DNS request
|
||||
EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option
|
||||
MaxGoroutines uint32 `yaml:"max_goroutines"` // Max. number of parallel goroutines for processing incoming requests
|
||||
HandleDDR bool `yaml:"handle_ddr"` // Handle DDR requests
|
||||
|
||||
// IpsetList is the ipset configuration that allows AdGuard Home to add
|
||||
// IP addresses of the specified domain names to an ipset list. Syntax:
|
||||
@@ -133,8 +134,9 @@ type FilteringConfig struct {
|
||||
|
||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||
type TLSConfig struct {
|
||||
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
|
||||
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
|
||||
HTTPSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
|
||||
// Reject connection if the client uses server name (in SNI) that doesn't match the certificate
|
||||
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"`
|
||||
@@ -151,7 +153,7 @@ type TLSConfig struct {
|
||||
PrivateKeyData []byte `yaml:"-" json:"-"`
|
||||
|
||||
// ServerName is the hostname of the server. Currently, it is only being
|
||||
// used for ClientID checking.
|
||||
// used for ClientID checking and Discovery of Designated Resolvers (DDR).
|
||||
ServerName string `yaml:"-" json:"-"`
|
||||
|
||||
cert tls.Certificate
|
||||
@@ -191,7 +193,7 @@ type ServerConfig struct {
|
||||
ConfigModified func()
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
|
||||
HTTPRegister aghhttp.RegisterFunc
|
||||
|
||||
// ResolveClients signals if the RDNS should resolve clients' addresses.
|
||||
ResolveClients bool
|
||||
@@ -276,6 +278,11 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
|
||||
return proxyConfig, nil
|
||||
}
|
||||
|
||||
const (
|
||||
defaultSafeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
defaultParentalBlockHost = "family-block.dns.adguard.com"
|
||||
)
|
||||
|
||||
// initDefaultSettings initializes default settings if nothing
|
||||
// is configured
|
||||
func (s *Server) initDefaultSettings() {
|
||||
@@ -287,12 +294,12 @@ func (s *Server) initDefaultSettings() {
|
||||
s.conf.BootstrapDNS = defaultBootstrap
|
||||
}
|
||||
|
||||
if len(s.conf.ParentalBlockHost) == 0 {
|
||||
s.conf.ParentalBlockHost = parentalBlockHost
|
||||
if s.conf.ParentalBlockHost == "" {
|
||||
s.conf.ParentalBlockHost = defaultParentalBlockHost
|
||||
}
|
||||
|
||||
if len(s.conf.SafeBrowsingBlockHost) == 0 {
|
||||
s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost
|
||||
if s.conf.SafeBrowsingBlockHost == "" {
|
||||
s.conf.SafeBrowsingBlockHost = defaultSafeBrowsingBlockHost
|
||||
}
|
||||
|
||||
if s.conf.UDPListenAddrs == nil {
|
||||
|
||||
@@ -76,6 +76,10 @@ const (
|
||||
resultCodeError
|
||||
)
|
||||
|
||||
// ddrHostFQDN is the FQDN used in Discovery of Designated Resolvers (DDR) requests.
|
||||
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
|
||||
const ddrHostFQDN = "_dns.resolver.arpa."
|
||||
|
||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
ctx := &dnsContext{
|
||||
@@ -94,10 +98,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
mods := []modProcessFunc{
|
||||
s.processRecursion,
|
||||
s.processInitial,
|
||||
s.processDDRQuery,
|
||||
s.processDetermineLocal,
|
||||
s.processInternalHosts,
|
||||
s.processDHCPHosts,
|
||||
s.processRestrictLocal,
|
||||
s.processInternalIPAddrs,
|
||||
s.processDHCPAddrs,
|
||||
s.processFilteringBeforeRequest,
|
||||
s.processLocalPTR,
|
||||
s.processUpstream,
|
||||
@@ -135,7 +140,6 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
@@ -226,12 +230,10 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
)
|
||||
}
|
||||
|
||||
lowhost := strings.ToLower(l.Hostname)
|
||||
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
|
||||
ip := netutil.CloneIP(l.IP)
|
||||
|
||||
ipToHost.Set(l.IP, lowhost)
|
||||
|
||||
ip := make(net.IP, 4)
|
||||
copy(ip, l.IP.To4())
|
||||
ipToHost.Set(ip, lowhost)
|
||||
hostToIP[lowhost] = ip
|
||||
}
|
||||
|
||||
@@ -242,6 +244,98 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
s.setTableIPToHost(ipToHost)
|
||||
}
|
||||
|
||||
// processDDRQuery responds to SVCB query for a special use domain name
|
||||
// ‘_dns.resolver.arpa’. The result contains different types of encryption
|
||||
// supported by current user configuration.
|
||||
//
|
||||
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
|
||||
func (s *Server) processDDRQuery(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
question := d.Req.Question[0]
|
||||
|
||||
if !s.conf.HandleDDR {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
if question.Name == ddrHostFQDN {
|
||||
if s.dnsProxy.TLSListenAddr == nil && s.conf.HTTPSListenAddrs == nil &&
|
||||
s.dnsProxy.QUICListenAddr == nil || question.Qtype != dns.TypeSVCB {
|
||||
d.Res = s.makeResponse(d.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
d.Res = s.makeDDRResponse(d.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// makeDDRResponse creates DDR answer according to server configuration. The
|
||||
// contructed SVCB resource records have the priority of 1 for each entry,
|
||||
// similar to examples provided by https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
|
||||
//
|
||||
// TODO(a.meshkov): Consider setting the priority values based on the protocol.
|
||||
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.makeResponse(req)
|
||||
// TODO(e.burkov): Think about storing the FQDN version of the server's
|
||||
// name somewhere.
|
||||
domainName := dns.Fqdn(s.conf.ServerName)
|
||||
|
||||
for _, addr := range s.conf.HTTPSListenAddrs {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"h2"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
|
||||
}
|
||||
|
||||
ans := &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: 1,
|
||||
Target: domainName,
|
||||
Value: values,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
for _, addr := range s.dnsProxy.TLSListenAddr {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"dot"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
}
|
||||
|
||||
ans := &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: 1,
|
||||
Target: domainName,
|
||||
Value: values,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
for _, addr := range s.dnsProxy.QUICListenAddr {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"doq"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
}
|
||||
|
||||
ans := &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: 1,
|
||||
Target: domainName,
|
||||
Value: values,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// processDetermineLocal determines if the client's IP address is from
|
||||
// locally-served network and saves the result into the context.
|
||||
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
@@ -252,7 +346,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
return rc
|
||||
}
|
||||
|
||||
dctx.isLocalClient = s.subnetDetector.IsLocallyServedNetwork(ip)
|
||||
dctx.isLocalClient = s.privateNets.Contains(ip)
|
||||
|
||||
return rc
|
||||
}
|
||||
@@ -280,11 +374,11 @@ func (s *Server) hostToIP(host string) (ip net.IP, ok bool) {
|
||||
return ip, true
|
||||
}
|
||||
|
||||
// processInternalHosts respond to A requests if the target hostname is known to
|
||||
// processDHCPHosts respond to A requests if the target hostname is known to
|
||||
// the server.
|
||||
//
|
||||
// TODO(a.garipov): Adapt to AAAA as well.
|
||||
func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
if !s.dhcpServer.Enabled() {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
@@ -299,11 +393,10 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
reqHost := strings.ToLower(q.Name)
|
||||
reqHost := strings.ToLower(q.Name[:len(q.Name)-1])
|
||||
// TODO(a.garipov): Move everything related to DHCP local domain to the DHCP
|
||||
// server.
|
||||
host := strings.TrimSuffix(reqHost, s.localDomainSuffix)
|
||||
if host == reqHost {
|
||||
if !strings.HasSuffix(reqHost, s.localDomainSuffix) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -316,7 +409,7 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
ip, ok := s.hostToIP(host)
|
||||
ip, ok := s.hostToIP(reqHost)
|
||||
if !ok {
|
||||
// TODO(e.burkov): Inspect special cases when user want to apply some
|
||||
// rules handled by other processors to the hosts with TLD.
|
||||
@@ -373,8 +466,8 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
|
||||
|
||||
// Restrict an access to local addresses for external clients. We also
|
||||
// assume that all the DHCP leases we give are locally-served or at least
|
||||
// don't need to be inaccessible externally.
|
||||
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
// don't need to be accessible externally.
|
||||
if !s.privateNets.Contains(ip) {
|
||||
log.Debug("dns: addr %s is not from locally-served network", ip)
|
||||
|
||||
return resultCodeSuccess
|
||||
@@ -413,7 +506,7 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var v interface{}
|
||||
var v any
|
||||
v, ok = s.tableIPToHost.Get(ip)
|
||||
if !ok {
|
||||
return "", false
|
||||
@@ -430,7 +523,7 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
|
||||
|
||||
// Respond to PTR requests if the target IP is leased by our DHCP server and the
|
||||
// requestor is inside the local network.
|
||||
func (s *Server) processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
|
||||
func (s *Server) processDHCPAddrs(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
if d.Res != nil {
|
||||
return resultCodeSuccess
|
||||
@@ -481,7 +574,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
if !s.privateNets.Contains(ip) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
|
||||
@@ -4,35 +4,212 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
s := &Server{
|
||||
subnetDetector: snd,
|
||||
const (
|
||||
ddrTestDomainName = "dns.example.net"
|
||||
ddrTestFQDN = ddrTestDomainName + "."
|
||||
)
|
||||
|
||||
func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
dohSVCB := &dns.SVCB{
|
||||
Priority: 1,
|
||||
Target: ddrTestFQDN,
|
||||
Value: []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"h2"}},
|
||||
&dns.SVCBPort{Port: 8044},
|
||||
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
|
||||
},
|
||||
}
|
||||
|
||||
dotSVCB := &dns.SVCB{
|
||||
Priority: 1,
|
||||
Target: ddrTestFQDN,
|
||||
Value: []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"dot"}},
|
||||
&dns.SVCBPort{Port: 8043},
|
||||
},
|
||||
}
|
||||
|
||||
doqSVCB := &dns.SVCB{
|
||||
Priority: 1,
|
||||
Target: ddrTestFQDN,
|
||||
Value: []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"doq"}},
|
||||
&dns.SVCBPort{Port: 8042},
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want []*dns.SVCB
|
||||
wantRes resultCode
|
||||
portDoH int
|
||||
portDoT int
|
||||
portDoQ int
|
||||
qtype uint16
|
||||
ddrEnabled bool
|
||||
}{{
|
||||
name: "pass_host",
|
||||
wantRes: resultCodeSuccess,
|
||||
host: "example.net.",
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8043,
|
||||
}, {
|
||||
name: "pass_qtype",
|
||||
wantRes: resultCodeFinish,
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeA,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8043,
|
||||
}, {
|
||||
name: "pass_disabled_tls",
|
||||
wantRes: resultCodeFinish,
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
}, {
|
||||
name: "pass_disabled_ddr",
|
||||
wantRes: resultCodeSuccess,
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: false,
|
||||
portDoH: 8043,
|
||||
}, {
|
||||
name: "dot",
|
||||
wantRes: resultCodeFinish,
|
||||
want: []*dns.SVCB{dotSVCB},
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoT: 8043,
|
||||
}, {
|
||||
name: "doh",
|
||||
wantRes: resultCodeFinish,
|
||||
want: []*dns.SVCB{dohSVCB},
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8044,
|
||||
}, {
|
||||
name: "doq",
|
||||
wantRes: resultCodeFinish,
|
||||
want: []*dns.SVCB{doqSVCB},
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoQ: 8042,
|
||||
}, {
|
||||
name: "dot_doh",
|
||||
wantRes: resultCodeFinish,
|
||||
want: []*dns.SVCB{dotSVCB, dohSVCB},
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoT: 8043,
|
||||
portDoH: 8044,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled)
|
||||
|
||||
req := createTestMessageWithType(tc.host, tc.qtype)
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
},
|
||||
}
|
||||
|
||||
res := s.processDDRQuery(dctx)
|
||||
require.Equal(t, tc.wantRes, res)
|
||||
|
||||
if tc.wantRes != resultCodeFinish {
|
||||
return
|
||||
}
|
||||
|
||||
msg := dctx.proxyCtx.Res
|
||||
require.NotNil(t, msg)
|
||||
|
||||
for _, v := range tc.want {
|
||||
v.Hdr = s.hdr(req, dns.TypeSVCB)
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, tc.want, msg.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
|
||||
t.Helper()
|
||||
|
||||
proxyConf := proxy.Config{}
|
||||
|
||||
if portDoT > 0 {
|
||||
proxyConf.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
|
||||
}
|
||||
|
||||
if portDoQ > 0 {
|
||||
proxyConf.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
|
||||
}
|
||||
|
||||
s = &Server{
|
||||
dnsProxy: &proxy.Proxy{
|
||||
Config: proxyConf,
|
||||
},
|
||||
conf: ServerConfig{
|
||||
FilteringConfig: FilteringConfig{
|
||||
HandleDDR: ddrEnabled,
|
||||
},
|
||||
TLSConfig: TLSConfig{
|
||||
ServerName: ddrTestDomainName,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if portDoH > 0 {
|
||||
s.conf.TLSConfig.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
s := &Server{
|
||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
name string
|
||||
cliIP net.IP
|
||||
want bool
|
||||
}{{
|
||||
want: assert.True,
|
||||
name: "local",
|
||||
cliIP: net.IP{192, 168, 0, 1},
|
||||
want: true,
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "external",
|
||||
cliIP: net.IP{250, 249, 0, 1},
|
||||
want: false,
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "invalid",
|
||||
cliIP: net.IP{1, 2, 3, 4, 5},
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "nil",
|
||||
cliIP: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -47,12 +224,12 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
}
|
||||
s.processDetermineLocal(dctx)
|
||||
|
||||
assert.Equal(t, tc.want, dctx.isLocalClient)
|
||||
tc.want(t, dctx.isLocalClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
|
||||
testCases := []struct {
|
||||
@@ -93,7 +270,7 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
dhcpServer: &testDHCP{},
|
||||
localDomainSuffix: defaultLocalDomainSuffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example": knownIP,
|
||||
"example." + defaultLocalDomainSuffix: knownIP,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -115,7 +292,7 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
isLocalClient: tc.isLocalCli,
|
||||
}
|
||||
|
||||
res := s.processInternalHosts(dctx)
|
||||
res := s.processDHCPHosts(dctx)
|
||||
require.Equal(t, tc.wantRes, res)
|
||||
pctx := dctx.proxyCtx
|
||||
if tc.wantRes == resultCodeFinish {
|
||||
@@ -141,10 +318,10 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessInternalHosts(t *testing.T) {
|
||||
func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
const (
|
||||
examplecom = "example.com"
|
||||
examplelan = "example.lan"
|
||||
examplelan = "example." + defaultLocalDomainSuffix
|
||||
)
|
||||
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
@@ -193,41 +370,41 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
|
||||
}, {
|
||||
name: "success_custom_suffix",
|
||||
host: "example.custom",
|
||||
suffix: ".custom.",
|
||||
suffix: "custom",
|
||||
wantIP: knownIP,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s := &Server{
|
||||
dhcpServer: &testDHCP{},
|
||||
localDomainSuffix: tc.suffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example." + tc.suffix: knownIP,
|
||||
},
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: 1234,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: dns.Fqdn(tc.host),
|
||||
Qtype: tc.qtyp,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
},
|
||||
isLocalClient: true,
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := &Server{
|
||||
dhcpServer: &testDHCP{},
|
||||
localDomainSuffix: tc.suffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example": knownIP,
|
||||
},
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: 1234,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: dns.Fqdn(tc.host),
|
||||
Qtype: tc.qtyp,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
},
|
||||
isLocalClient: true,
|
||||
}
|
||||
|
||||
res := s.processInternalHosts(dctx)
|
||||
res := s.processDHCPHosts(dctx)
|
||||
pctx := dctx.proxyCtx
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
if tc.wantRes == resultCodeFinish {
|
||||
|
||||
@@ -33,11 +33,6 @@ const DefaultTimeout = 10 * time.Second
|
||||
// requests between the BeforeRequestHandler stage and the actual processing.
|
||||
const defaultClientIDCacheCount = 1024
|
||||
|
||||
const (
|
||||
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
parentalBlockHost = "family-block.dns.adguard.com"
|
||||
)
|
||||
|
||||
var defaultDNS = []string{
|
||||
"https://dns10.quad9.net/dns-query",
|
||||
}
|
||||
@@ -66,7 +61,7 @@ type Server struct {
|
||||
dnsFilter *filtering.DNSFilter // DNS filter instance
|
||||
dhcpServer dhcpd.ServerInterface // DHCP server instance (optional)
|
||||
queryLog querylog.QueryLog // Query log instance
|
||||
stats stats.Stats
|
||||
stats stats.Interface
|
||||
access *accessCtx
|
||||
|
||||
// localDomainSuffix is the suffix used to detect internal hosts. It
|
||||
@@ -74,7 +69,7 @@ type Server struct {
|
||||
localDomainSuffix string
|
||||
|
||||
ipset ipsetCtx
|
||||
subnetDetector *aghnet.SubnetDetector
|
||||
privateNets netutil.SubnetSet
|
||||
localResolvers *proxy.Proxy
|
||||
sysResolvers aghnet.SystemResolvers
|
||||
recDetector *recursionDetector
|
||||
@@ -107,28 +102,17 @@ type Server struct {
|
||||
// when no suffix is provided.
|
||||
//
|
||||
// See the documentation for Server.localDomainSuffix.
|
||||
const defaultLocalDomainSuffix = ".lan."
|
||||
const defaultLocalDomainSuffix = "lan"
|
||||
|
||||
// DNSCreateParams are parameters to create a new server.
|
||||
type DNSCreateParams struct {
|
||||
DNSFilter *filtering.DNSFilter
|
||||
Stats stats.Stats
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.ServerInterface
|
||||
SubnetDetector *aghnet.SubnetDetector
|
||||
Anonymizer *aghnet.IPMut
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
// domainNameToSuffix converts a domain name into a local domain suffix.
|
||||
func domainNameToSuffix(tld string) (suffix string) {
|
||||
l := len(tld) + 2
|
||||
b := make([]byte, l)
|
||||
b[0] = '.'
|
||||
copy(b[1:], tld)
|
||||
b[l-1] = '.'
|
||||
|
||||
return string(b)
|
||||
DNSFilter *filtering.DNSFilter
|
||||
Stats stats.Interface
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.ServerInterface
|
||||
PrivateNets netutil.SubnetSet
|
||||
Anonymizer *aghnet.IPMut
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -151,7 +135,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
return nil, fmt.Errorf("local domain: %w", err)
|
||||
}
|
||||
|
||||
localDomainSuffix = domainNameToSuffix(p.LocalDomain)
|
||||
localDomainSuffix = p.LocalDomain
|
||||
}
|
||||
|
||||
if p.Anonymizer == nil {
|
||||
@@ -161,7 +145,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
dnsFilter: p.DNSFilter,
|
||||
stats: p.Stats,
|
||||
queryLog: p.QueryLog,
|
||||
subnetDetector: p.SubnetDetector,
|
||||
privateNets: p.PrivateNets,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||
clientIDCache: cache.New(cache.Config{
|
||||
@@ -173,7 +157,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
|
||||
// TODO(e.burkov): Enable the refresher after the actual implementation
|
||||
// passes the public testing.
|
||||
s.sysResolvers, err = aghnet.NewSystemResolvers(0, nil)
|
||||
s.sysResolvers, err = aghnet.NewSystemResolvers(nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing system resolvers: %w", err)
|
||||
}
|
||||
@@ -314,14 +298,16 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
resolver := s.internalProxy
|
||||
if s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
var resolver *proxy.Proxy
|
||||
if s.privateNets.Contains(ip) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
resolver = s.localResolvers
|
||||
s.recDetector.add(*req)
|
||||
} else {
|
||||
resolver = s.internalProxy
|
||||
}
|
||||
|
||||
if err = resolver.Resolve(ctx); err != nil {
|
||||
|
||||
@@ -17,13 +17,14 @@ import (
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
@@ -69,14 +70,11 @@ func createTestServer(
|
||||
f := filtering.New(filterConf, filters)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var err error
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -770,16 +768,11 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
Data: []byte(rules),
|
||||
}}
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -860,10 +853,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
const hostname = "wmconvirus.narod.ru"
|
||||
|
||||
sbUps := &aghtest.TestBlockUpstream{
|
||||
Hostname: hostname,
|
||||
Block: true,
|
||||
}
|
||||
sbUps := aghtest.NewBlockUpstream(hostname, true)
|
||||
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
||||
|
||||
filterConf := &filtering.Config{
|
||||
@@ -913,15 +903,10 @@ func TestRewrite(t *testing.T) {
|
||||
f := filtering.New(c, nil)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1000,7 +985,7 @@ func TestRewrite(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func publicKey(priv interface{}) interface{} {
|
||||
func publicKey(priv any) any {
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
@@ -1028,36 +1013,33 @@ func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
||||
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
|
||||
|
||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
const localDomain = "lan"
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
||||
DHCPServer: &testDHCP{},
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
||||
DHCPServer: &testDHCP{},
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
LocalDomain: localDomain,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
|
||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.FilteringConfig.ProtectionEnabled = true
|
||||
s.conf.ProtectionEnabled = true
|
||||
|
||||
err = s.Prepare(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(s.Close)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR)
|
||||
|
||||
resp, err := dns.Exchange(req, addr.String())
|
||||
require.NoError(t, err)
|
||||
require.NoErrorf(t, err, "%s", addr)
|
||||
|
||||
require.Len(t, resp.Answer, 1)
|
||||
|
||||
@@ -1066,7 +1048,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
|
||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "myhost.", ptr.Ptr)
|
||||
assert.Equal(t, dns.Fqdn("myhost."+localDomain), ptr.Ptr)
|
||||
}
|
||||
|
||||
func TestPTRResponseFromHosts(t *testing.T) {
|
||||
@@ -1105,16 +1087,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
}, nil)
|
||||
flt.SetEnabled(true)
|
||||
|
||||
var snd *aghnet.SubnetDetector
|
||||
snd, err = aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: flt,
|
||||
SubnetDetector: snd,
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: flt,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1197,25 +1174,48 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_Exchange(t *testing.T) {
|
||||
extUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
|
||||
const (
|
||||
onesHost = "one.one.one.one"
|
||||
localDomainHost = "local.domain"
|
||||
)
|
||||
|
||||
var (
|
||||
onesIP = net.IP{1, 1, 1, 1}
|
||||
localIP = net.IP{192, 168, 1, 1}
|
||||
)
|
||||
|
||||
revExtIPv4, err := netutil.IPToReversedAddr(onesIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
extUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "external.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revExtIPv4, onesHost),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
locUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.168.192.in-addr.arpa.": {"local.domain"},
|
||||
"2.1.168.192.in-addr.arpa.": {},
|
||||
|
||||
revLocIPv4, err := netutil.IPToReversedAddr(localIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
locUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "local.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revLocIPv4, localDomainHost),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
upstreamErr := errors.Error("upstream error")
|
||||
errUpstream := &aghtest.TestErrUpstream{
|
||||
Err: upstreamErr,
|
||||
}
|
||||
nonPtrUpstream := &aghtest.TestBlockUpstream{
|
||||
Hostname: "some-host",
|
||||
Block: true,
|
||||
}
|
||||
|
||||
errUpstream := aghtest.NewErrorUpstream()
|
||||
nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
|
||||
|
||||
srv := NewCustomServer(&proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
@@ -1227,11 +1227,8 @@ func TestServer_Exchange(t *testing.T) {
|
||||
srv.conf.ResolveClients = true
|
||||
srv.conf.UsePrivateRDNS = true
|
||||
|
||||
var err error
|
||||
srv.subnetDetector, err = aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
localIP := net.IP{192, 168, 1, 1}
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
@@ -1240,20 +1237,20 @@ func TestServer_Exchange(t *testing.T) {
|
||||
req net.IP
|
||||
}{{
|
||||
name: "external_good",
|
||||
want: "one.one.one.one",
|
||||
want: onesHost,
|
||||
wantErr: nil,
|
||||
locUpstream: nil,
|
||||
req: net.IP{1, 1, 1, 1},
|
||||
req: onesIP,
|
||||
}, {
|
||||
name: "local_good",
|
||||
want: "local.domain",
|
||||
want: localDomainHost,
|
||||
wantErr: nil,
|
||||
locUpstream: locUpstream,
|
||||
req: localIP,
|
||||
}, {
|
||||
name: "upstream_error",
|
||||
want: "",
|
||||
wantErr: upstreamErr,
|
||||
wantErr: aghtest.ErrUpstream,
|
||||
locUpstream: errUpstream,
|
||||
req: localIP,
|
||||
}, {
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
Preference: 32,
|
||||
}
|
||||
svcbVal := &rules.DNSSVCB{
|
||||
Params: map[string]string{"alpn": "h3"},
|
||||
Params: map[string]string{"alpn": "h3", "dohpath": "/dns-query"},
|
||||
Target: dns.Fqdn(domain),
|
||||
Priority: 32,
|
||||
}
|
||||
@@ -164,10 +164,20 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
|
||||
require.Len(t, d.Res.Answer, 1)
|
||||
ans, ok := d.Res.Answer[0].(*dns.SVCB)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||
assert.Equal(t, svcbVal.Params["alpn"], ans.Value[0].String())
|
||||
require.True(t, ok)
|
||||
require.Len(t, ans.Value, 2)
|
||||
|
||||
assert.ElementsMatch(
|
||||
t,
|
||||
[]dns.SVCBKey{dns.SVCB_ALPN, dns.SVCB_DOHPATH},
|
||||
[]dns.SVCBKey{ans.Value[0].Key(), ans.Value[1].Key()},
|
||||
)
|
||||
assert.ElementsMatch(
|
||||
t,
|
||||
[]string{svcbVal.Params["alpn"], svcbVal.Params["dohpath"]},
|
||||
[]string{ans.Value[0].String(), ans.Value[1].String()},
|
||||
)
|
||||
assert.Equal(t, svcbVal.Target, ans.Target)
|
||||
assert.Equal(t, svcbVal.Priority, ans.Priority)
|
||||
})
|
||||
@@ -186,8 +196,18 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
ans, ok := d.Res.Answer[0].(*dns.HTTPS)
|
||||
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||
assert.Equal(t, svcbVal.Params["alpn"], ans.Value[0].String())
|
||||
require.Len(t, ans.Value, 2)
|
||||
|
||||
assert.ElementsMatch(
|
||||
t,
|
||||
[]dns.SVCBKey{dns.SVCB_ALPN, dns.SVCB_DOHPATH},
|
||||
[]dns.SVCBKey{ans.Value[0].Key(), ans.Value[1].Key()},
|
||||
)
|
||||
assert.ElementsMatch(
|
||||
t,
|
||||
[]string{svcbVal.Params["alpn"], svcbVal.Params["dohpath"]},
|
||||
[]string{ans.Value[0].String(), ans.Value[1].String()},
|
||||
)
|
||||
assert.Equal(t, svcbVal.Target, ans.Target)
|
||||
assert.Equal(t, svcbVal.Priority, ans.Priority)
|
||||
})
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -39,14 +38,10 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -5,12 +5,10 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -18,6 +16,8 @@ import (
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type dnsConfig struct {
|
||||
@@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (err error) {
|
||||
}
|
||||
|
||||
// validate returns an error if any field of req is invalid.
|
||||
func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
|
||||
func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
if req.Upstreams != nil {
|
||||
err = ValidateUpstreams(*req.Upstreams)
|
||||
if err != nil {
|
||||
@@ -176,7 +176,7 @@ func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
|
||||
}
|
||||
|
||||
if req.LocalPTRUpstreams != nil {
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, snd)
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating private upstream servers: %w", err)
|
||||
}
|
||||
@@ -224,7 +224,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = req.validate(s.subnetDetector)
|
||||
err = req.validate(s.privateNets)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -350,17 +350,6 @@ func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// LocalNetChecker is used to check if the IP address belongs to a local
|
||||
// network.
|
||||
type LocalNetChecker interface {
|
||||
// IsLocallyServedNetwork returns true if ip is contained in any of address
|
||||
// registries defined by RFC 6303.
|
||||
IsLocallyServedNetwork(ip net.IP) (ok bool)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ LocalNetChecker = (*aghnet.SubnetDetector)(nil)
|
||||
|
||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
||||
// configuration or nil if it can't be built.
|
||||
//
|
||||
@@ -375,6 +364,21 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
var ups string
|
||||
var domains []string
|
||||
ups, domains, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = validateUpstream(ups, domains)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating upstream %q: %w", u, err)
|
||||
}
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{Bootstrap: []string{}, Timeout: DefaultTimeout},
|
||||
@@ -385,13 +389,6 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
return nil, errors.Error("no default upstreams specified")
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
_, err = validateUpstream(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
@@ -405,25 +402,11 @@ func ValidateUpstreams(upstreams []string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// stringKeysSorted returns the sorted slice of string keys of m.
|
||||
//
|
||||
// TODO(e.burkov): Use generics in Go 1.18. Move into golibs.
|
||||
func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
|
||||
sorted = make([]string, 0, len(m))
|
||||
for s := range m {
|
||||
sorted = append(sorted, s)
|
||||
}
|
||||
|
||||
sort.Strings(sorted)
|
||||
|
||||
return sorted
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. lnc must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err error) {
|
||||
// a locally-served network. privateNets must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||
conf, err := newUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -433,9 +416,11 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
keys := maps.Keys(conf.DomainReservedUpstreams)
|
||||
slices.Sort(keys)
|
||||
|
||||
for _, domain := range stringKeysSorted(conf.DomainReservedUpstreams) {
|
||||
var errs []error
|
||||
for _, domain := range keys {
|
||||
var subnet *net.IPNet
|
||||
subnet, err = netutil.SubnetFromReversedAddr(domain)
|
||||
if err != nil {
|
||||
@@ -444,7 +429,7 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
|
||||
continue
|
||||
}
|
||||
|
||||
if !lnc.IsLocallyServedNetwork(subnet.IP) {
|
||||
if !privateNets.Contains(subnet.IP) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
@@ -461,16 +446,14 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
|
||||
|
||||
var protocols = []string{"udp://", "tcp://", "tls://", "https://", "sdns://", "quic://"}
|
||||
|
||||
func validateUpstream(u string) (useDefault bool, err error) {
|
||||
// Check if the user tries to specify upstream for domain.
|
||||
var isDomainSpec bool
|
||||
u, isDomainSpec, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
return !isDomainSpec, err
|
||||
}
|
||||
|
||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||
// upstream configuration. useDefault is true if the upstream is
|
||||
// domain-specific and is configured to point at the default upstream server
|
||||
// which is validated separately. The upstream is considered domain-specific
|
||||
// only if domains is at least not nil.
|
||||
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
||||
// The special server address '#' means that default server must be used.
|
||||
if useDefault = !isDomainSpec; u == "#" && isDomainSpec {
|
||||
if useDefault = u == "#" && domains != nil; useDefault {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
@@ -497,12 +480,14 @@ func validateUpstream(u string) (useDefault bool, err error) {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
// separateUpstream returns the upstream without the specified domains.
|
||||
// isDomainSpec is true when the upstream is domains-specific.
|
||||
func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, err error) {
|
||||
// separateUpstream returns the upstream and the specified domains. domains is
|
||||
// nil when the upstream is not domains-specific. Otherwise it may also be
|
||||
// empty.
|
||||
func separateUpstream(upstreamStr string) (ups string, domains []string, err error) {
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return upstreamStr, false, nil
|
||||
return upstreamStr, nil, nil
|
||||
}
|
||||
|
||||
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
|
||||
|
||||
parts := strings.Split(upstreamStr[2:], "/]")
|
||||
@@ -510,39 +495,46 @@ func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, e
|
||||
case 2:
|
||||
// Go on.
|
||||
case 1:
|
||||
return "", false, errors.Error("missing separator")
|
||||
return "", nil, errors.Error("missing separator")
|
||||
default:
|
||||
return "", true, errors.Error("duplicated separator")
|
||||
return "", []string{}, errors.Error("duplicated separator")
|
||||
}
|
||||
|
||||
var domains string
|
||||
domains, upstream = parts[0], parts[1]
|
||||
for i, host := range strings.Split(domains, "/") {
|
||||
for i, host := range strings.Split(parts[0], "/") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
err = netutil.ValidateDomainName(host)
|
||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
||||
if err != nil {
|
||||
return "", true, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
return "", domains, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
domains = append(domains, host)
|
||||
}
|
||||
|
||||
return upstream, true, nil
|
||||
return parts[1], domains, nil
|
||||
}
|
||||
|
||||
// excFunc is a signature of function to check if upstream exchanges correctly.
|
||||
type excFunc func(u upstream.Upstream) (err error)
|
||||
// healthCheckFunc is a signature of function to check if upstream exchanges
|
||||
// properly.
|
||||
type healthCheckFunc func(u upstream.Upstream) (err error)
|
||||
|
||||
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
|
||||
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// testTLD is the special-use fully-qualified domain name for testing the
|
||||
// DNS server reachability.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
|
||||
const testTLD = "test."
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: "google-public-dns-a.google.com.",
|
||||
Name: testTLD,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
@@ -552,12 +544,8 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
reply, err = u.Exchange(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
}
|
||||
|
||||
if len(reply.Answer) != 1 {
|
||||
return fmt.Errorf("wrong response")
|
||||
} else if a, ok := reply.Answer[0].(*dns.A); !ok || !a.A.Equal(net.IP{8, 8, 8, 8}) {
|
||||
return fmt.Errorf("wrong response")
|
||||
} else if len(reply.Answer) != 0 {
|
||||
return errors.Error("wrong response")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -565,14 +553,22 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
|
||||
// checkPrivateUpstreamExc checks if the upstream for resolving private
|
||||
// addresses exchanges correctly.
|
||||
//
|
||||
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
|
||||
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
|
||||
// address resolution.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
|
||||
const inAddrArpaTLD = "in-addr.arpa."
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: "1.0.0.127.in-addr.arpa.",
|
||||
Name: inAddrArpaTLD,
|
||||
Qtype: dns.TypePTR,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
@@ -585,46 +581,66 @@ func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFunc) (err error) {
|
||||
if IsCommentOrEmpty(input) {
|
||||
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
||||
// the tested upstream domain-specific and therefore consider its errors
|
||||
// non-critical.
|
||||
//
|
||||
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
|
||||
// warnings (non-critical errors) is desired.
|
||||
type domainSpecificTestError struct {
|
||||
error
|
||||
}
|
||||
|
||||
// checkDNS checks the upstream server defined by upstreamConfigStr using
|
||||
// healthCheck for actually exchange messages. It uses bootstrap to resolve the
|
||||
// upstream's address.
|
||||
func checkDNS(
|
||||
upstreamConfigStr string,
|
||||
bootstrap []string,
|
||||
timeout time.Duration,
|
||||
healthCheck healthCheckFunc,
|
||||
) (err error) {
|
||||
if IsCommentOrEmpty(upstreamConfigStr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Separate upstream from domains list.
|
||||
var useDefault bool
|
||||
if useDefault, err = validateUpstream(input); err != nil {
|
||||
upstreamAddr, domains, err := separateUpstream(upstreamConfigStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
// No need to check this DNS server.
|
||||
if !useDefault {
|
||||
useDefault, err := validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
} else if useDefault {
|
||||
return nil
|
||||
}
|
||||
|
||||
if input, _, err = separateUpstream(input); err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
if len(bootstrap) == 0 {
|
||||
bootstrap = defaultBootstrap
|
||||
}
|
||||
|
||||
log.Debug("checking if upstream %s works", input)
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
var u upstream.Upstream
|
||||
u, err = upstream.AddressToUpstream(input, &upstream.Options{
|
||||
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: timeout,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to choose upstream for %q: %w", input, err)
|
||||
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
|
||||
}
|
||||
|
||||
if err = ef(u); err != nil {
|
||||
return fmt.Errorf("upstream %q fails to exchange: %w", input, err)
|
||||
if err = healthCheck(u); err != nil {
|
||||
err = fmt.Errorf("upstream %q fails to exchange: %w", upstreamAddr, err)
|
||||
if domains != nil {
|
||||
return domainSpecificTestError{error: err}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("upstream %s is ok", input)
|
||||
log.Debug("dnsforward: upstream %q is ok", upstreamAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -647,6 +663,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
log.Info("%v", err)
|
||||
result[host] = err.Error()
|
||||
if _, ok := err.(domainSpecificTestError); ok {
|
||||
result[host] = fmt.Sprintf("WARNING: %s", result[host])
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
@@ -662,6 +681,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
// above, we rewriting the error for it. These cases should be
|
||||
// handled properly instead.
|
||||
result[host] = err.Error()
|
||||
if _, ok := err.(domainSpecificTestError); ok {
|
||||
result[host] = fmt.Sprintf("WARNING: %s", result[host])
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -33,7 +34,7 @@ func (fsr *fakeSystemResolvers) Get() (rs []string) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadTestData(t *testing.T, casesFileName string, cases interface{}) {
|
||||
func loadTestData(t *testing.T, casesFileName string, cases any) {
|
||||
t.Helper()
|
||||
|
||||
var f *os.File
|
||||
@@ -184,7 +185,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "upstream_dns_bad",
|
||||
wantSet: `validating upstream servers: bad ipport address "!!!": ` +
|
||||
wantSet: `validating upstream servers: ` +
|
||||
`validating upstream "!!!": bad ipport address "!!!": ` +
|
||||
`address !!!: missing port in address`,
|
||||
}, {
|
||||
name: "bootstraps_bad",
|
||||
@@ -255,112 +257,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstream(t *testing.T) {
|
||||
testCases := []struct {
|
||||
wantDef assert.BoolAssertionFunc
|
||||
name string
|
||||
upstream string
|
||||
wantErr string
|
||||
}{{
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "1.2.3.4.5",
|
||||
wantErr: `bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "123.3.7m",
|
||||
wantErr: `bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "htttps://google.com/dns-query",
|
||||
wantErr: `wrong protocol`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "[/host.com]tls://dns.adguard.com",
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "[host.ru]#",
|
||||
wantErr: `bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "tls://1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "https://dns.adguard.com/dns-query",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "default_udp_host",
|
||||
upstream: "udp://dns.google",
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "default_udp_ip",
|
||||
upstream: "udp://8.8.8.8",
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host.com/]1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[//]tls://1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/www.host.com/]#",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host.com/google.com/]8.8.8.8",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "idna",
|
||||
upstream: "[/пример.рф/]8.8.8.8",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "bad_domain",
|
||||
upstream: "[/!/]8.8.8.8",
|
||||
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
|
||||
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
defaultUpstream, err := validateUpstream(tc.upstream)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
tc.wantDef(t, defaultUpstream)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreams(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -375,7 +271,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
wantErr: ``,
|
||||
set: []string{"# comment"},
|
||||
}, {
|
||||
name: "valid_no_default",
|
||||
name: "no_default",
|
||||
wantErr: `no default upstreams specified`,
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
@@ -385,7 +281,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
},
|
||||
}, {
|
||||
name: "valid_with_default",
|
||||
name: "with_default",
|
||||
wantErr: ``,
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
@@ -397,8 +293,46 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `cannot prepare the upstream dhcp://fake.dns ([]): unsupported url scheme: dhcp`,
|
||||
wantErr: `validating upstream "dhcp://fake.dns": wrong protocol`,
|
||||
set: []string{"dhcp://fake.dns"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "1.2.3.4.5": bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
|
||||
set: []string{"1.2.3.4.5"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "123.3.7m": bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
|
||||
set: []string{"123.3.7m"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
|
||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "[host.ru]#": bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
|
||||
set: []string{"[host.ru]#"},
|
||||
}, {
|
||||
name: "valid_default",
|
||||
wantErr: ``,
|
||||
set: []string{
|
||||
"1.1.1.1",
|
||||
"tls://1.1.1.1",
|
||||
"https://dns.adguard.com/dns-query",
|
||||
"sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"udp://dns.google",
|
||||
"udp://8.8.8.8",
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"[/пример.рф/]8.8.8.8",
|
||||
},
|
||||
}, {
|
||||
name: "bad_domain",
|
||||
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
|
||||
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
|
||||
set: []string{"[/!/]8.8.8.8"},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -410,8 +344,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -452,7 +385,7 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
set := []string{"192.168.0.1", tc.u}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = ValidateUpstreamsPrivate(set, snd)
|
||||
err := ValidateUpstreamsPrivate(set, ss)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ func TestRecursionDetector_Suspect(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
msg dns.Msg
|
||||
want bool
|
||||
want int
|
||||
}{{
|
||||
name: "simple",
|
||||
msg: dns.Msg{
|
||||
@@ -95,24 +95,18 @@ func TestRecursionDetector_Suspect(t *testing.T) {
|
||||
Qtype: dns.TypeA,
|
||||
}},
|
||||
},
|
||||
want: true,
|
||||
want: 1,
|
||||
}, {
|
||||
name: "unencumbered",
|
||||
msg: dns.Msg{},
|
||||
want: false,
|
||||
want: 0,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(rd.clear)
|
||||
|
||||
rd.add(tc.msg)
|
||||
|
||||
if tc.want {
|
||||
assert.Equal(t, 1, rd.recentRequests.Stats().Count)
|
||||
} else {
|
||||
assert.Zero(t, rd.recentRequests.Stats().Count)
|
||||
}
|
||||
assert.Equal(t, tc.want, rd.recentRequests.Stats().Count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,9 +64,9 @@ func (s *Server) logQuery(
|
||||
Answer: pctx.Res,
|
||||
OrigAnswer: dctx.origResp,
|
||||
Result: dctx.result,
|
||||
Elapsed: elapsed,
|
||||
ClientID: dctx.clientID,
|
||||
ClientIP: ip,
|
||||
Elapsed: elapsed,
|
||||
AuthenticatedData: dctx.responseAD,
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func (l *testQueryLog) Add(p *querylog.AddParams) {
|
||||
type testStats struct {
|
||||
// Stats is embedded here simply to make testStats a stats.Stats without
|
||||
// actually implementing all methods.
|
||||
stats.Stats
|
||||
stats.Interface
|
||||
|
||||
lastEntry stats.Entry
|
||||
}
|
||||
|
||||
@@ -32,12 +32,16 @@ func (s *Server) genAnswerHTTPS(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTT
|
||||
// github.com/miekg/dns module.
|
||||
var strToSVCBKey = map[string]dns.SVCBKey{
|
||||
"alpn": dns.SVCB_ALPN,
|
||||
"echconfig": dns.SVCB_ECHCONFIG,
|
||||
"ech": dns.SVCB_ECHCONFIG,
|
||||
"ipv4hint": dns.SVCB_IPV4HINT,
|
||||
"ipv6hint": dns.SVCB_IPV6HINT,
|
||||
"mandatory": dns.SVCB_MANDATORY,
|
||||
"no-default-alpn": dns.SVCB_NO_DEFAULT_ALPN,
|
||||
"port": dns.SVCB_PORT,
|
||||
|
||||
// TODO(a.garipov): This is the previous name for the parameter that has
|
||||
// since been changed. Remove this in v0.109.0.
|
||||
"echconfig": dns.SVCB_ECHCONFIG,
|
||||
}
|
||||
|
||||
// svcbKeyHandler is a handler for one SVCB parameter key.
|
||||
@@ -51,10 +55,10 @@ var svcbKeyHandlers = map[string]svcbKeyHandler{
|
||||
}
|
||||
},
|
||||
|
||||
"echconfig": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
"ech": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
ech, err := base64.StdEncoding.DecodeString(valStr)
|
||||
if err != nil {
|
||||
log.Debug("can't parse svcb/https echconfig: %s; ignoring", err)
|
||||
log.Debug("can't parse svcb/https ech: %s; ignoring", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -119,6 +123,32 @@ var svcbKeyHandlers = map[string]svcbKeyHandler{
|
||||
Port: uint16(port64),
|
||||
}
|
||||
},
|
||||
|
||||
// TODO(a.garipov): This is the previous name for the parameter that has
|
||||
// since been changed. Remove this in v0.109.0.
|
||||
"echconfig": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
log.Info(
|
||||
`warning: svcb/https record parameter name "echconfig" is deprecated; ` +
|
||||
`use "ech" instead`,
|
||||
)
|
||||
|
||||
ech, err := base64.StdEncoding.DecodeString(valStr)
|
||||
if err != nil {
|
||||
log.Debug("can't parse svcb/https ech: %s; ignoring", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return &dns.SVCBECHConfig{
|
||||
ECH: ech,
|
||||
}
|
||||
},
|
||||
|
||||
"dohpath": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
return &dns.SVCBDoHPath{
|
||||
Template: valStr,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// genAnswerSVCB returns a properly initialized SVCB resource record.
|
||||
|
||||
@@ -87,14 +87,18 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
svcb: dnssvcb("alpn", "h3"),
|
||||
want: wantsvcb(&dns.SVCBAlpn{Alpn: []string{"h3"}}),
|
||||
name: "alpn",
|
||||
}, {
|
||||
svcb: dnssvcb("ech", "AAAA"),
|
||||
want: wantsvcb(&dns.SVCBECHConfig{ECH: []byte{0, 0, 0}}),
|
||||
name: "ech",
|
||||
}, {
|
||||
svcb: dnssvcb("echconfig", "AAAA"),
|
||||
want: wantsvcb(&dns.SVCBECHConfig{ECH: []byte{0, 0, 0}}),
|
||||
name: "echconfig",
|
||||
name: "ech_deprecated",
|
||||
}, {
|
||||
svcb: dnssvcb("echconfig", "%BAD%"),
|
||||
want: wantsvcb(nil),
|
||||
name: "echconfig_invalid",
|
||||
name: "ech_invalid",
|
||||
}, {
|
||||
svcb: dnssvcb("ipv4hint", "127.0.0.1"),
|
||||
want: wantsvcb(&dns.SVCBIPv4Hint{Hint: []net.IP{ip4}}),
|
||||
@@ -123,6 +127,10 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
svcb: dnssvcb("no-default-alpn", ""),
|
||||
want: wantsvcb(&dns.SVCBNoDefaultAlpn{}),
|
||||
name: "no_default_alpn",
|
||||
}, {
|
||||
svcb: dnssvcb("dohpath", "/dns-query"),
|
||||
want: wantsvcb(&dns.SVCBDoHPath{Template: "/dns-query"}),
|
||||
name: "dohpath",
|
||||
}, {
|
||||
svcb: dnssvcb("port", "8080"),
|
||||
want: wantsvcb(&dns.SVCBPort{Port: 8080}),
|
||||
|
||||
@@ -7,87 +7,181 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
var serviceRules map[string][]*rules.NetworkRule // service name -> filtering rules
|
||||
|
||||
// svc represents a single blocked service.
|
||||
type svc struct {
|
||||
name string
|
||||
rules []string
|
||||
}
|
||||
|
||||
// servicesData contains raw blocked service data.
|
||||
//
|
||||
// Keep in sync with:
|
||||
// client/src/helpers/constants.js
|
||||
// client/src/components/ui/Icons.js
|
||||
var serviceRulesArray = []svc{
|
||||
{"whatsapp", []string{"||whatsapp.net^", "||whatsapp.com^"}},
|
||||
{"facebook", []string{
|
||||
// - client/src/helpers/constants.js
|
||||
// - client/src/components/ui/Icons.js
|
||||
var servicesData = []svc{{
|
||||
name: "whatsapp",
|
||||
rules: []string{
|
||||
"||wa.me^",
|
||||
"||whatsapp.com^",
|
||||
"||whatsapp.net^",
|
||||
},
|
||||
}, {
|
||||
name: "facebook",
|
||||
rules: []string{
|
||||
"||facebook.com^",
|
||||
"||facebook.net^",
|
||||
"||fbcdn.net^",
|
||||
"||accountkit.com^",
|
||||
"||fb.me^",
|
||||
"||fb.com^",
|
||||
"||fb.gg^",
|
||||
"||fbsbx.com^",
|
||||
"||fbwat.ch^",
|
||||
"||messenger.com^",
|
||||
"||facebookcorewwwi.onion^",
|
||||
"||fbcdn.com^",
|
||||
"||fb.watch^",
|
||||
}},
|
||||
{"twitter", []string{"||twitter.com^", "||twttr.com^", "||t.co^", "||twimg.com^"}},
|
||||
{"youtube", []string{
|
||||
"||youtube.com^",
|
||||
"||ytimg.com^",
|
||||
"||youtu.be^",
|
||||
},
|
||||
}, {
|
||||
name: "twitter",
|
||||
rules: []string{
|
||||
"||t.co^",
|
||||
"||twimg.com^",
|
||||
"||twitter.com^",
|
||||
"||twttr.com^",
|
||||
},
|
||||
}, {
|
||||
name: "youtube",
|
||||
rules: []string{
|
||||
"||googlevideo.com^",
|
||||
"||youtubei.googleapis.com^",
|
||||
"||youtube-nocookie.com^",
|
||||
"||wide-youtube.l.google.com^",
|
||||
"||youtu.be^",
|
||||
"||youtube",
|
||||
}},
|
||||
{"twitch", []string{"||twitch.tv^", "||ttvnw.net^", "||jtvnw.net^", "||twitchcdn.net^"}},
|
||||
{"netflix", []string{"||nflxext.com^", "||netflix.com^", "||nflximg.net^", "||nflxvideo.net^", "||nflxso.net^"}},
|
||||
{"instagram", []string{"||instagram.com^", "||cdninstagram.com^"}},
|
||||
{"snapchat", []string{
|
||||
"||youtube-nocookie.com^",
|
||||
"||youtube.com^",
|
||||
"||youtubei.googleapis.com^",
|
||||
"||youtubekids.com^",
|
||||
"||ytimg.com^",
|
||||
},
|
||||
}, {
|
||||
name: "twitch",
|
||||
rules: []string{
|
||||
"||jtvnw.net^",
|
||||
"||ttvnw.net^",
|
||||
"||twitch.tv^",
|
||||
"||twitchcdn.net^",
|
||||
},
|
||||
}, {
|
||||
name: "netflix",
|
||||
rules: []string{
|
||||
"||nflxext.com^",
|
||||
"||netflix.com^",
|
||||
"||nflximg.net^",
|
||||
"||nflxvideo.net^",
|
||||
"||nflxso.net^",
|
||||
},
|
||||
}, {
|
||||
name: "instagram",
|
||||
rules: []string{"||instagram.com^", "||cdninstagram.com^"},
|
||||
}, {
|
||||
name: "snapchat",
|
||||
rules: []string{
|
||||
"||snapchat.com^",
|
||||
"||sc-cdn.net^",
|
||||
"||snap-dev.net^",
|
||||
"||snapkit.co",
|
||||
"||snapads.com^",
|
||||
"||impala-media-production.s3.amazonaws.com^",
|
||||
}},
|
||||
{"discord", []string{"||discord.gg^", "||discordapp.net^", "||discordapp.com^", "||discord.com^", "||discord.media^"}},
|
||||
{"ok", []string{"||ok.ru^"}},
|
||||
{"skype", []string{"||skype.com^", "||skypeassets.com^"}},
|
||||
{"vk", []string{"||vk.com^", "||userapi.com^", "||vk-cdn.net^", "||vkuservideo.net^"}},
|
||||
{"origin", []string{"||origin.com^", "||signin.ea.com^", "||accounts.ea.com^"}},
|
||||
{"steam", []string{
|
||||
},
|
||||
}, {
|
||||
name: "discord",
|
||||
rules: []string{
|
||||
"||discord.gg^",
|
||||
"||discordapp.net^",
|
||||
"||discordapp.com^",
|
||||
"||discord.com^",
|
||||
"||discord.gift",
|
||||
"||discord.media^",
|
||||
},
|
||||
}, {
|
||||
name: "ok",
|
||||
rules: []string{"||ok.ru^"},
|
||||
}, {
|
||||
name: "skype",
|
||||
rules: []string{
|
||||
"||edge-skype-com.s-0001.s-msedge.net^",
|
||||
"||skype-edf.akadns.net^",
|
||||
"||skype.com^",
|
||||
"||skypeassets.com^",
|
||||
"||skypedata.akadns.net^",
|
||||
},
|
||||
}, {
|
||||
name: "vk",
|
||||
rules: []string{
|
||||
"||userapi.com^",
|
||||
"||vk-cdn.net^",
|
||||
"||vk.com^",
|
||||
"||vkuservideo.net^",
|
||||
},
|
||||
}, {
|
||||
name: "origin",
|
||||
rules: []string{
|
||||
"||accounts.ea.com^",
|
||||
"||origin.com^",
|
||||
"||signin.ea.com^",
|
||||
},
|
||||
}, {
|
||||
name: "steam",
|
||||
rules: []string{
|
||||
"||steam.com^",
|
||||
"||steampowered.com^",
|
||||
"||steamcommunity.com^",
|
||||
"||steamstatic.com^",
|
||||
"||steamstore-a.akamaihd.net^",
|
||||
"||steamcdn-a.akamaihd.net^",
|
||||
}},
|
||||
{"epic_games", []string{"||epicgames.com^", "||easyanticheat.net^", "||easy.ac^", "||eac-cdn.com^"}},
|
||||
{"reddit", []string{"||reddit.com^", "||redditstatic.com^", "||redditmedia.com^", "||redd.it^"}},
|
||||
{"mail_ru", []string{"||mail.ru^"}},
|
||||
{"cloudflare", []string{
|
||||
"||cloudflare.com^",
|
||||
"||cloudflare-dns.com^",
|
||||
"||cloudflare.net^",
|
||||
"||cloudflareinsights.com^",
|
||||
"||cloudflarestream.com^",
|
||||
"||cloudflareresolve.com^",
|
||||
"||cloudflareclient.com^",
|
||||
"||cloudflarebolt.com^",
|
||||
"||cloudflarestatus.com^",
|
||||
"||cloudflare.cn^",
|
||||
"||one.one^",
|
||||
"||warp.plus^",
|
||||
},
|
||||
}, {
|
||||
name: "epic_games",
|
||||
rules: []string{"||epicgames.com^", "||easyanticheat.net^", "||easy.ac^", "||eac-cdn.com^"},
|
||||
}, {
|
||||
name: "reddit",
|
||||
rules: []string{"||reddit.com^", "||redditstatic.com^", "||redditmedia.com^", "||redd.it^"},
|
||||
}, {
|
||||
name: "mail_ru",
|
||||
rules: []string{"||mail.ru^"},
|
||||
}, {
|
||||
name: "cloudflare",
|
||||
rules: []string{
|
||||
"||1.1.1.1^",
|
||||
"||argotunnel.com^",
|
||||
"||cloudflare-dns.com^",
|
||||
"||cloudflare-ipfs.com^",
|
||||
"||cloudflare-quic.com^",
|
||||
"||cloudflare.cn^",
|
||||
"||cloudflare.com^",
|
||||
"||cloudflare.net^",
|
||||
"||cloudflareapps.com^",
|
||||
"||cloudflarebolt.com^",
|
||||
"||cloudflareclient.com^",
|
||||
"||cloudflareinsights.com^",
|
||||
"||cloudflareresolve.com^",
|
||||
"||cloudflarestatus.com^",
|
||||
"||cloudflarestream.com^",
|
||||
"||cloudflarewarp.com^",
|
||||
"||dns4torpnlfs2ifuz2s2yf3fc7rdmsbhm6rw75euj35pac6ap25zgqad.onion^",
|
||||
}},
|
||||
{"amazon", []string{
|
||||
"||one.one^",
|
||||
"||pages.dev^",
|
||||
"||trycloudflare.com^",
|
||||
"||videodelivery.net^",
|
||||
"||warp.plus^",
|
||||
"||workers.dev^",
|
||||
},
|
||||
}, {
|
||||
name: "amazon",
|
||||
rules: []string{
|
||||
"||amazon.com^",
|
||||
"||media-amazon.com^",
|
||||
"||primevideo.com^",
|
||||
@@ -111,11 +205,14 @@ var serviceRulesArray = []svc{
|
||||
"||amazon.com.br^",
|
||||
"||amazon.co.jp^",
|
||||
"||amazon.com.mx^",
|
||||
"||amazon.com.tr^",
|
||||
"||amazon.co.uk^",
|
||||
"||createspace.com^",
|
||||
"||aws",
|
||||
}},
|
||||
{"ebay", []string{
|
||||
},
|
||||
}, {
|
||||
name: "ebay",
|
||||
rules: []string{
|
||||
"||ebay.com^",
|
||||
"||ebayimg.com^",
|
||||
"||ebaystatic.com^",
|
||||
@@ -141,8 +238,10 @@ var serviceRulesArray = []svc{
|
||||
"||ebay.com.my^",
|
||||
"||ebay.com.sg^",
|
||||
"||ebay.co.uk^",
|
||||
}},
|
||||
{"tiktok", []string{
|
||||
},
|
||||
}, {
|
||||
name: "tiktok",
|
||||
rules: []string{
|
||||
"||tiktok.com^",
|
||||
"||tiktokcdn.com^",
|
||||
"||musical.ly^",
|
||||
@@ -156,65 +255,95 @@ var serviceRulesArray = []svc{
|
||||
"||toutiaocloud.net^",
|
||||
"||bdurl.com^",
|
||||
"||bytecdn.cn^",
|
||||
"||bytedapm.com^",
|
||||
"||byteimg.com^",
|
||||
"||byteoversea.com^",
|
||||
"||ixigua.com^",
|
||||
"||muscdn.com^",
|
||||
"||bytedance.map.fastly.net^",
|
||||
"||douyin.com^",
|
||||
"||tiktokv.com^",
|
||||
}},
|
||||
{"vimeo", []string{
|
||||
"||toutiaovod.com^",
|
||||
"||douyincdn.com^",
|
||||
},
|
||||
}, {
|
||||
name: "vimeo",
|
||||
rules: []string{
|
||||
"*vod-adaptive.akamaized.net^",
|
||||
"||vimeo.com^",
|
||||
"||vimeocdn.com^",
|
||||
"*vod-adaptive.akamaized.net^",
|
||||
}},
|
||||
{"pinterest", []string{
|
||||
"||pinterest.*^",
|
||||
},
|
||||
}, {
|
||||
name: "pinterest",
|
||||
rules: []string{
|
||||
"||pinimg.com^",
|
||||
}},
|
||||
{"imgur", []string{
|
||||
"||imgur.com^",
|
||||
}},
|
||||
{"dailymotion", []string{
|
||||
"||pinterest.*^",
|
||||
},
|
||||
}, {
|
||||
name: "imgur",
|
||||
rules: []string{"||imgur.com^"},
|
||||
}, {
|
||||
name: "dailymotion",
|
||||
rules: []string{
|
||||
"||dailymotion.com^",
|
||||
"||dm-event.net^",
|
||||
"||dmcdn.net^",
|
||||
}},
|
||||
{"qq", []string{
|
||||
// block qq.com and subdomains excluding WeChat domains
|
||||
"||qq.com^$denyallow=wx*.qq.com|weixin.qq.com",
|
||||
},
|
||||
}, {
|
||||
name: "qq",
|
||||
rules: []string{
|
||||
// Block qq.com and subdomains excluding WeChat's domains.
|
||||
"||qq.com^$denyallow=wx.qq.com|weixin.qq.com",
|
||||
"||qqzaixian.com^",
|
||||
}},
|
||||
{"wechat", []string{
|
||||
"||qq-video.cdn-go.cn^",
|
||||
"||url.cn^",
|
||||
},
|
||||
}, {
|
||||
name: "wechat",
|
||||
rules: []string{
|
||||
"||wechat.com^",
|
||||
"||weixin.qq.com.cn^",
|
||||
"||weixin.qq.com^",
|
||||
"||weixinbridge.com^",
|
||||
"||wx.qq.com^",
|
||||
}},
|
||||
{"viber", []string{
|
||||
"||viber.com^",
|
||||
}},
|
||||
{"weibo", []string{
|
||||
},
|
||||
}, {
|
||||
name: "viber",
|
||||
rules: []string{"||viber.com^"},
|
||||
}, {
|
||||
name: "weibo",
|
||||
rules: []string{
|
||||
"||weibo.cn^",
|
||||
"||weibo.com^",
|
||||
}},
|
||||
{"9gag", []string{
|
||||
"||weibocdn.com^",
|
||||
},
|
||||
}, {
|
||||
name: "9gag",
|
||||
rules: []string{
|
||||
"||9cache.com^",
|
||||
"||9gag.com^",
|
||||
}},
|
||||
{"telegram", []string{
|
||||
},
|
||||
}, {
|
||||
name: "telegram",
|
||||
rules: []string{
|
||||
"||t.me^",
|
||||
"||telegram.me^",
|
||||
"||telegram.org^",
|
||||
}},
|
||||
{"disneyplus", []string{
|
||||
},
|
||||
}, {
|
||||
name: "disneyplus",
|
||||
rules: []string{
|
||||
"||disney-plus.net^",
|
||||
"||disneyplus.com^",
|
||||
"||disney.playback.edge.bamgrid.com^",
|
||||
"||media.dssott.com^",
|
||||
}},
|
||||
{"hulu", []string{
|
||||
"||hulu.com^",
|
||||
}},
|
||||
{"spotify", []string{
|
||||
},
|
||||
}, {
|
||||
name: "hulu",
|
||||
rules: []string{"||hulu.com^"},
|
||||
}, {
|
||||
name: "spotify",
|
||||
rules: []string{
|
||||
"/_spotify-connect._tcp.local/",
|
||||
"||spotify.com^",
|
||||
"||scdn.co^",
|
||||
@@ -226,29 +355,59 @@ var serviceRulesArray = []svc{
|
||||
"||audio4-ak-spotify-com.akamaized.net^",
|
||||
"||heads-ak-spotify-com.akamaized.net^",
|
||||
"||heads4-ak-spotify-com.akamaized.net^",
|
||||
}},
|
||||
{"tinder", []string{
|
||||
},
|
||||
}, {
|
||||
name: "tinder",
|
||||
rules: []string{
|
||||
"||gotinder.com^",
|
||||
"||tinder.com^",
|
||||
"||tindersparks.com^",
|
||||
}},
|
||||
}
|
||||
},
|
||||
}, {
|
||||
name: "bilibili",
|
||||
rules: []string{
|
||||
"||biliapi.net^",
|
||||
"||bilibili.com^",
|
||||
"||biligame.com^",
|
||||
"||bilivideo.cn^",
|
||||
"||bilivideo.com^",
|
||||
"||dreamcast.hk^",
|
||||
"||hdslb.com^",
|
||||
},
|
||||
}}
|
||||
|
||||
// convert array to map
|
||||
// serviceRules maps a service ID to its filtering rules.
|
||||
var serviceRules map[string][]*rules.NetworkRule
|
||||
|
||||
// serviceIDs contains service IDs sorted alphabetically.
|
||||
var serviceIDs []string
|
||||
|
||||
// initBlockedServices initializes package-level blocked service data.
|
||||
func initBlockedServices() {
|
||||
serviceRules = make(map[string][]*rules.NetworkRule)
|
||||
for _, s := range serviceRulesArray {
|
||||
netRules := []*rules.NetworkRule{}
|
||||
l := len(servicesData)
|
||||
serviceIDs = make([]string, l)
|
||||
serviceRules = make(map[string][]*rules.NetworkRule, l)
|
||||
|
||||
for i, s := range servicesData {
|
||||
netRules := make([]*rules.NetworkRule, 0, len(s.rules))
|
||||
for _, text := range s.rules {
|
||||
rule, err := rules.NewNetworkRule(text, BlockedSvcsListID)
|
||||
if err != nil {
|
||||
log.Error("rules.NewNetworkRule: %s rule: %s", err, text)
|
||||
log.Error("parsing blocked service %q rule %q: %s", s.name, text, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
netRules = append(netRules, rule)
|
||||
}
|
||||
|
||||
serviceIDs[i] = s.name
|
||||
serviceRules[s.name] = netRules
|
||||
}
|
||||
|
||||
slices.Sort(serviceIDs)
|
||||
|
||||
log.Debug("filtering: initialized %d services", l)
|
||||
}
|
||||
|
||||
// BlockedSvcKnown - return TRUE if a blocked service name is known
|
||||
@@ -280,6 +439,16 @@ func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string, global
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleBlockedServicesAvailableServices(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(serviceIDs)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding available services: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
|
||||
d.confLock.RLock()
|
||||
list := d.Config.BlockedServices
|
||||
@@ -288,7 +457,7 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(list)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding services: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -314,6 +483,7 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
// registerBlockedServicesHandlers - register HTTP handlers
|
||||
func (d *DNSFilter) registerBlockedServicesHandlers() {
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
|
||||
}
|
||||
|
||||
@@ -61,22 +61,22 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
|
||||
testCasesA := []struct {
|
||||
name string
|
||||
want []interface{}
|
||||
want []any
|
||||
rcode int
|
||||
dtyp uint16
|
||||
}{{
|
||||
name: "a-record",
|
||||
rcode: dns.RcodeSuccess,
|
||||
want: []interface{}{ipv4p1},
|
||||
want: []any{ipv4p1},
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "aaaa-record",
|
||||
want: []interface{}{ipv6p1},
|
||||
want: []any{ipv6p1},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "txt-record",
|
||||
want: []interface{}{"hello-world"},
|
||||
want: []any{"hello-world"},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeTXT,
|
||||
}, {
|
||||
@@ -86,22 +86,22 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
dtyp: 0,
|
||||
}, {
|
||||
name: "a-records",
|
||||
want: []interface{}{ipv4p1, ipv4p2},
|
||||
want: []any{ipv4p1, ipv4p2},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "aaaa-records",
|
||||
want: []interface{}{ipv6p1, ipv6p2},
|
||||
want: []any{ipv6p1, ipv6p2},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "disable-one",
|
||||
want: []interface{}{ipv4p2},
|
||||
want: []any{ipv4p2},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "disable-cname",
|
||||
want: []interface{}{ipv4p1},
|
||||
want: []any{ipv4p1},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
@@ -14,6 +13,7 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
@@ -94,7 +94,7 @@ type Config struct {
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
|
||||
// CustomResolver is the resolver used by DNSFilter.
|
||||
CustomResolver Resolver `yaml:"-"`
|
||||
@@ -296,9 +296,11 @@ func cloneRewrites(entries []*LegacyRewrite) (clone []*LegacyRewrite) {
|
||||
return clone
|
||||
}
|
||||
|
||||
// SetFilters - set new filters (synchronously or asynchronously)
|
||||
// When filters are set asynchronously, the old filters continue working until the new filters are ready.
|
||||
// In this case the caller must ensure that the old filter files are intact.
|
||||
// SetFilters sets new filters, synchronously or asynchronously. When filters
|
||||
// are set asynchronously, the old filters continue working until the new
|
||||
// filters are ready.
|
||||
//
|
||||
// In this case the caller must ensure that the old filter files are intact.
|
||||
func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool) error {
|
||||
if async {
|
||||
params := filtersInitializerParams{
|
||||
@@ -471,7 +473,7 @@ func (d *DNSFilter) matchSysHosts(
|
||||
return res, nil
|
||||
}
|
||||
|
||||
dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{
|
||||
dnsres, _ := d.EtcHosts.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
SortedClientTags: setts.ClientTags,
|
||||
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
||||
@@ -802,7 +804,7 @@ func (d *DNSFilter) matchHost(
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
ureq := urlfilter.DNSRequest{
|
||||
ureq := &urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
SortedClientTags: setts.ClientTags,
|
||||
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
||||
|
||||
@@ -21,6 +21,11 @@ func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
const (
|
||||
sbBlocked = "wmconvirus.narod.ru"
|
||||
pcBlocked = "pornhub.com"
|
||||
)
|
||||
|
||||
var setts = Settings{
|
||||
ProtectionEnabled: true,
|
||||
}
|
||||
@@ -173,43 +178,37 @@ func TestSafeBrowsing(t *testing.T) {
|
||||
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
})
|
||||
d.checkMatch(t, matching)
|
||||
|
||||
require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
d.checkMatch(t, sbBlocked)
|
||||
|
||||
d.checkMatch(t, "test."+matching)
|
||||
require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
|
||||
|
||||
d.checkMatch(t, "test."+sbBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "pornhub.com")
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
|
||||
// Cached result.
|
||||
d.safeBrowsingServer = "127.0.0.1"
|
||||
d.checkMatch(t, matching)
|
||||
d.checkMatchEmpty(t, "pornhub.com")
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||
}
|
||||
|
||||
func TestParallelSB(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
t.Run("group", func(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
d.checkMatch(t, matching)
|
||||
d.checkMatch(t, "test."+matching)
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatch(t, "test."+sbBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "pornhub.com")
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -382,23 +381,19 @@ func TestParentalControl(t *testing.T) {
|
||||
|
||||
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "pornhub.com"
|
||||
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.checkMatch(t, matching)
|
||||
require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
|
||||
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
|
||||
d.checkMatch(t, pcBlocked)
|
||||
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
|
||||
|
||||
d.checkMatch(t, "www."+matching)
|
||||
d.checkMatch(t, "www."+pcBlocked)
|
||||
d.checkMatchEmpty(t, "www.yandex.ru")
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "api.jquery.com")
|
||||
|
||||
// Test cached result.
|
||||
d.parentalServer = "127.0.0.1"
|
||||
d.checkMatch(t, matching)
|
||||
d.checkMatch(t, pcBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
}
|
||||
|
||||
@@ -445,7 +440,7 @@ func TestMatching(t *testing.T) {
|
||||
}, {
|
||||
name: "sanity",
|
||||
rules: "||doubleclick.net^",
|
||||
host: "wmconvirus.narod.ru",
|
||||
host: sbBlocked,
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
@@ -765,14 +760,9 @@ func TestClientSettings(t *testing.T) {
|
||||
}},
|
||||
)
|
||||
t.Cleanup(d.Close)
|
||||
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: "pornhub.com",
|
||||
Block: true,
|
||||
})
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: "wmconvirus.narod.ru",
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
@@ -787,12 +777,12 @@ func TestClientSettings(t *testing.T) {
|
||||
wantReason: FilteredBlockList,
|
||||
}, {
|
||||
name: "parental",
|
||||
host: "pornhub.com",
|
||||
host: pcBlocked,
|
||||
before: true,
|
||||
wantReason: FilteredParental,
|
||||
}, {
|
||||
name: "safebrowsing",
|
||||
host: "wmconvirus.narod.ru",
|
||||
host: sbBlocked,
|
||||
before: false,
|
||||
wantReason: FilteredSafeBrowsing,
|
||||
}, {
|
||||
@@ -836,33 +826,29 @@ func TestClientSettings(t *testing.T) {
|
||||
func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
blocked := "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: blocked,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
blocked := "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: blocked,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -130,10 +130,9 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
|
||||
//
|
||||
// The sorting priority:
|
||||
//
|
||||
// A and AAAA > CNAME
|
||||
// wildcard > exact
|
||||
// lower level wildcard > higher level wildcard
|
||||
//
|
||||
// 1. A and AAAA > CNAME
|
||||
// 2. wildcard > exact
|
||||
// 3. lower level wildcard > higher level wildcard
|
||||
type rewritesSorted []*LegacyRewrite
|
||||
|
||||
// Len implements the sort.Interface interface for legacyRewritesSorted.
|
||||
|
||||
@@ -24,10 +24,11 @@ import (
|
||||
|
||||
// Safe browsing and parental control methods.
|
||||
|
||||
// TODO(a.garipov): Make configurable.
|
||||
const (
|
||||
dnsTimeout = 3 * time.Second
|
||||
defaultSafebrowsingServer = `https://dns-family.adguard.com/dns-query`
|
||||
defaultParentalServer = `https://dns-family.adguard.com/dns-query`
|
||||
defaultSafebrowsingServer = `https://family.adguard-dns.com/dns-query`
|
||||
defaultParentalServer = `https://family.adguard-dns.com/dns-query`
|
||||
sbTXTSuffix = `sb.dns.adguard.com.`
|
||||
pcTXTSuffix = `pc.dns.adguard.com.`
|
||||
)
|
||||
@@ -313,7 +314,7 @@ func (d *DNSFilter) checkSafeBrowsing(
|
||||
|
||||
if log.GetLevel() >= log.DEBUG {
|
||||
timer := log.StartTimer()
|
||||
defer timer.LogElapsed("SafeBrowsing lookup for %s", host)
|
||||
defer timer.LogElapsed("safebrowsing lookup for %q", host)
|
||||
}
|
||||
|
||||
sctx := &sbCtx{
|
||||
@@ -347,7 +348,7 @@ func (d *DNSFilter) checkParental(
|
||||
|
||||
if log.GetLevel() >= log.DEBUG {
|
||||
timer := log.StartTimer()
|
||||
defer timer.LogElapsed("Parental lookup for %s", host)
|
||||
defer timer.LogElapsed("parental lookup for %q", host)
|
||||
}
|
||||
|
||||
sctx := &sbCtx{
|
||||
|
||||
@@ -74,21 +74,20 @@ func TestSafeBrowsingCache(t *testing.T) {
|
||||
c.hashToHost[hash] = "sub.host.com"
|
||||
assert.Equal(t, -1, c.getCached())
|
||||
|
||||
// match "sub.host.com" from cache,
|
||||
// but another hash for "nonexisting.com" is not in cache
|
||||
// which means that we must get data from server for it
|
||||
// Match "sub.host.com" from cache. Another hash for "host.example" is not
|
||||
// in the cache, so get data for it from the server.
|
||||
c.hashToHost = make(map[[32]byte]string)
|
||||
hash = sha256.Sum256([]byte("sub.host.com"))
|
||||
c.hashToHost[hash] = "sub.host.com"
|
||||
hash = sha256.Sum256([]byte("nonexisting.com"))
|
||||
c.hashToHost[hash] = "nonexisting.com"
|
||||
hash = sha256.Sum256([]byte("host.example"))
|
||||
c.hashToHost[hash] = "host.example"
|
||||
assert.Empty(t, c.getCached())
|
||||
|
||||
hash = sha256.Sum256([]byte("sub.host.com"))
|
||||
_, ok := c.hashToHost[hash]
|
||||
assert.False(t, ok)
|
||||
|
||||
hash = sha256.Sum256([]byte("nonexisting.com"))
|
||||
hash = sha256.Sum256([]byte("host.example"))
|
||||
_, ok = c.hashToHost[hash]
|
||||
assert.True(t, ok)
|
||||
|
||||
@@ -111,8 +110,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
ups := &aghtest.TestErrUpstream{}
|
||||
|
||||
ups := aghtest.NewErrorUpstream()
|
||||
d.SetSafeBrowsingUpstream(ups)
|
||||
d.SetParentalUpstream(ups)
|
||||
|
||||
@@ -170,10 +168,16 @@ func TestSBPC(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
// Prepare the upstream.
|
||||
ups := &aghtest.TestBlockUpstream{
|
||||
Hostname: hostname,
|
||||
Block: tc.block,
|
||||
ups := aghtest.NewBlockUpstream(hostname, tc.block)
|
||||
|
||||
var numReq int
|
||||
onExchange := ups.OnExchange
|
||||
ups.OnExchange = func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
numReq++
|
||||
|
||||
return onExchange(req)
|
||||
}
|
||||
|
||||
d.SetSafeBrowsingUpstream(ups)
|
||||
d.SetParentalUpstream(ups)
|
||||
|
||||
@@ -196,7 +200,7 @@ func TestSBPC(t *testing.T) {
|
||||
assert.Equal(t, hits, tc.testCache.Stats().Hit)
|
||||
|
||||
// There was one request to an upstream.
|
||||
assert.Equal(t, 1, ups.RequestsCount())
|
||||
assert.Equal(t, 1, numReq)
|
||||
|
||||
// Now make the same request to check the cache was used.
|
||||
res, err = tc.testFunc(hostname, dns.TypeA, setts)
|
||||
@@ -214,7 +218,7 @@ func TestSBPC(t *testing.T) {
|
||||
assert.Equal(t, hits+1, tc.testCache.Stats().Hit)
|
||||
|
||||
// Check that there were no additional requests.
|
||||
assert.Equal(t, 1, ups.RequestsCount())
|
||||
assert.Equal(t, 1, numReq)
|
||||
})
|
||||
|
||||
purgeCaches(d)
|
||||
|
||||
@@ -2,9 +2,11 @@ package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -53,12 +55,49 @@ type clientSource uint
|
||||
// Client sources. The order determines the priority.
|
||||
const (
|
||||
ClientSourceWHOIS clientSource = iota
|
||||
ClientSourceRDNS
|
||||
ClientSourceARP
|
||||
ClientSourceRDNS
|
||||
ClientSourceDHCP
|
||||
ClientSourceHostsFile
|
||||
)
|
||||
|
||||
var _ fmt.Stringer = clientSource(0)
|
||||
|
||||
// String returns a human-readable name of cs.
|
||||
func (cs clientSource) String() (s string) {
|
||||
switch cs {
|
||||
case ClientSourceWHOIS:
|
||||
return "WHOIS"
|
||||
case ClientSourceARP:
|
||||
return "ARP"
|
||||
case ClientSourceRDNS:
|
||||
return "rDNS"
|
||||
case ClientSourceDHCP:
|
||||
return "DHCP"
|
||||
case ClientSourceHostsFile:
|
||||
return "etc/hosts"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var _ encoding.TextMarshaler = clientSource(0)
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler for the clientSource.
|
||||
func (cs clientSource) MarshalText() (text []byte, err error) {
|
||||
return []byte(cs.String()), nil
|
||||
}
|
||||
|
||||
// clientSourceConf is used to configure where the runtime clients will be
|
||||
// obtained from.
|
||||
type clientSourcesConf struct {
|
||||
WHOIS bool `yaml:"whois"`
|
||||
ARP bool `yaml:"arp"`
|
||||
RDNS bool `yaml:"rdns"`
|
||||
DHCP bool `yaml:"dhcp"`
|
||||
HostsFile bool `yaml:"hosts"`
|
||||
}
|
||||
|
||||
// RuntimeClient information
|
||||
type RuntimeClient struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo
|
||||
@@ -134,14 +173,14 @@ func (clients *clientsContainer) Init(
|
||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||
}
|
||||
|
||||
go clients.handleHostsUpdates()
|
||||
if clients.etcHosts != nil {
|
||||
go clients.handleHostsUpdates()
|
||||
}
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) handleHostsUpdates() {
|
||||
if clients.etcHosts != nil {
|
||||
for upd := range clients.etcHosts.Upd() {
|
||||
clients.addFromHostsFile(upd)
|
||||
}
|
||||
for upd := range clients.etcHosts.Upd() {
|
||||
clients.addFromHostsFile(upd)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +197,9 @@ func (clients *clientsContainer) Start() {
|
||||
|
||||
// Reload reloads runtime clients.
|
||||
func (clients *clientsContainer) Reload() {
|
||||
clients.addFromSystemARP()
|
||||
if clients.arpdb != nil {
|
||||
clients.addFromSystemARP()
|
||||
}
|
||||
}
|
||||
|
||||
type clientObject struct {
|
||||
@@ -257,6 +298,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) periodicUpdate() {
|
||||
defer log.OnPanic("clients container")
|
||||
|
||||
for {
|
||||
clients.Reload()
|
||||
time.Sleep(clientsUpdatePeriod)
|
||||
@@ -382,6 +425,7 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
|
||||
c.Tags = stringutil.CloneSlice(c.Tags)
|
||||
c.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
|
||||
c.Upstreams = stringutil.CloneSlice(c.Upstreams)
|
||||
|
||||
return c, true
|
||||
}
|
||||
|
||||
@@ -478,7 +522,7 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
// findRuntimeClientLocked finds a runtime client by their IP address. For
|
||||
// internal use only.
|
||||
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
|
||||
var v interface{}
|
||||
var v any
|
||||
v, ok = clients.ipToRC.Get(ip)
|
||||
if !ok {
|
||||
return nil, false
|
||||
@@ -532,7 +576,7 @@ func (clients *clientsContainer) check(c *Client) (err error) {
|
||||
} else if mac, err = net.ParseMAC(id); err == nil {
|
||||
c.IDs[i] = mac.String()
|
||||
} else if err = dnsforward.ValidateClientID(id); err == nil {
|
||||
c.IDs[i] = id
|
||||
c.IDs[i] = strings.ToLower(id)
|
||||
} else {
|
||||
return fmt.Errorf("invalid clientid at index %d: %q", i, id)
|
||||
}
|
||||
@@ -723,20 +767,18 @@ func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSourc
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
ok = clients.addHostLocked(ip, host, src)
|
||||
|
||||
return ok, nil
|
||||
return clients.addHostLocked(ip, host, src), nil
|
||||
}
|
||||
|
||||
// addHostLocked adds a new IP-hostname pairing. For internal use only.
|
||||
func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) {
|
||||
var rc *RuntimeClient
|
||||
rc, ok = clients.findRuntimeClientLocked(ip)
|
||||
rc, ok := clients.findRuntimeClientLocked(ip)
|
||||
if ok {
|
||||
if rc.Source > src {
|
||||
return false
|
||||
}
|
||||
|
||||
rc.Host = host
|
||||
rc.Source = src
|
||||
} else {
|
||||
rc = &RuntimeClient{
|
||||
@@ -756,7 +798,7 @@ func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clien
|
||||
// rmHostsBySrc removes all entries that match the specified source.
|
||||
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||
n := 0
|
||||
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
clients.ipToRC.Range(func(ip net.IP, v any) (cont bool) {
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
|
||||
@@ -784,26 +826,21 @@ func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) {
|
||||
clients.rmHostsBySrc(ClientSourceHostsFile)
|
||||
|
||||
n := 0
|
||||
hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
hosts, ok := v.(*stringutil.Set)
|
||||
hosts.Range(func(ip net.IP, v any) (cont bool) {
|
||||
rec, ok := v.(*aghnet.HostsRecord)
|
||||
if !ok {
|
||||
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
hosts.Range(func(name string) (cont bool) {
|
||||
if clients.addHostLocked(ip, name, ClientSourceHostsFile) {
|
||||
n++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
clients.addHostLocked(ip, rec.Canonical, ClientSourceHostsFile)
|
||||
n++
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
log.Debug("clients: added %d client aliases from system hosts-file", n)
|
||||
log.Debug("clients: added %d client aliases from system hosts file", n)
|
||||
}
|
||||
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
@@ -831,7 +868,7 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
|
||||
added := 0
|
||||
for _, n := range ns {
|
||||
if clients.addHostLocked(n.IP, "", ClientSourceARP) {
|
||||
if clients.addHostLocked(n.IP, n.Name, ClientSourceARP) {
|
||||
added++
|
||||
}
|
||||
}
|
||||
@@ -842,7 +879,7 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
// updateFromDHCP adds the clients that have a non-empty hostname from the DHCP
|
||||
// server.
|
||||
func (clients *clientsContainer) updateFromDHCP(add bool) {
|
||||
if clients.dhcpServer == nil {
|
||||
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -47,9 +47,9 @@ type clientJSON struct {
|
||||
type runtimeClientJSON struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||
|
||||
Name string `json:"name"`
|
||||
Source string `json:"source"`
|
||||
IP net.IP `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
Source clientSource `json:"source"`
|
||||
IP net.IP `json:"ip"`
|
||||
}
|
||||
|
||||
type clientListJSON struct {
|
||||
@@ -70,7 +70,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
data.Clients = append(data.Clients, cj)
|
||||
}
|
||||
|
||||
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
clients.ipToRC.Range(func(ip net.IP, v any) (cont bool) {
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||
@@ -81,20 +81,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
cj := runtimeClientJSON{
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
|
||||
Name: rc.Host,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
cj.Source = "etc/hosts"
|
||||
switch rc.Source {
|
||||
case ClientSourceDHCP:
|
||||
cj.Source = "DHCP"
|
||||
case ClientSourceRDNS:
|
||||
cj.Source = "rDNS"
|
||||
case ClientSourceARP:
|
||||
cj.Source = "ARP"
|
||||
case ClientSourceWHOIS:
|
||||
cj.Source = "WHOIS"
|
||||
Name: rc.Host,
|
||||
Source: rc.Source,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||
@@ -107,13 +96,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
e := json.NewEncoder(w).Encode(data)
|
||||
if e != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Failed to encode to json: %v",
|
||||
e,
|
||||
)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "failed to encode to json: %v", e)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -279,9 +262,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
|
||||
rc, ok := clients.FindRuntimeClient(ip)
|
||||
if !ok {
|
||||
// It is still possible that the IP used to be in the runtime
|
||||
// clients list, but then the server was reloaded. So, check
|
||||
// the DNS server's blocked IP list.
|
||||
// It is still possible that the IP used to be in the runtime clients
|
||||
// list, but then the server was reloaded. So, check the DNS server's
|
||||
// blocked IP list.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@@ -19,7 +20,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -27,15 +28,36 @@ const (
|
||||
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
|
||||
)
|
||||
|
||||
// logSettings
|
||||
// logSettings are the logging settings part of the configuration file.
|
||||
//
|
||||
// TODO(a.garipov): Put them into a separate object.
|
||||
type logSettings struct {
|
||||
LogCompress bool `yaml:"log_compress"` // Compress determines if the rotated log files should be compressed using gzip (default: false)
|
||||
LogLocalTime bool `yaml:"log_localtime"` // If the time used for formatting the timestamps in is the computer's local time (default: false [UTC])
|
||||
LogMaxBackups int `yaml:"log_max_backups"` // Maximum number of old log files to retain (MaxAge may still cause them to get deleted)
|
||||
LogMaxSize int `yaml:"log_max_size"` // Maximum size in megabytes of the log file before it gets rotated (default 100 MB)
|
||||
LogMaxAge int `yaml:"log_max_age"` // MaxAge is the maximum number of days to retain old log files
|
||||
LogFile string `yaml:"log_file"` // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
|
||||
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
|
||||
// File is the path to the log file. If empty, logs are written to stdout.
|
||||
// If "syslog", logs are written to syslog.
|
||||
File string `yaml:"log_file"`
|
||||
|
||||
// MaxBackups is the maximum number of old log files to retain.
|
||||
//
|
||||
// NOTE: MaxAge may still cause them to get deleted.
|
||||
MaxBackups int `yaml:"log_max_backups"`
|
||||
|
||||
// MaxSize is the maximum size of the log file before it gets rotated, in
|
||||
// megabytes. The default value is 100 MB.
|
||||
MaxSize int `yaml:"log_max_size"`
|
||||
|
||||
// MaxAge is the maximum duration for retaining old log files, in days.
|
||||
MaxAge int `yaml:"log_max_age"`
|
||||
|
||||
// Compress determines, if the rotated log files should be compressed using
|
||||
// gzip.
|
||||
Compress bool `yaml:"log_compress"`
|
||||
|
||||
// LocalTime determines, if the time used for formatting the timestamps in
|
||||
// is the computer's local time.
|
||||
LocalTime bool `yaml:"log_localtime"`
|
||||
|
||||
// Verbose determines, if verbose (aka debug) logging is enabled.
|
||||
Verbose bool `yaml:"verbose"`
|
||||
}
|
||||
|
||||
// osConfig contains OS-related configuration.
|
||||
@@ -51,6 +73,13 @@ type osConfig struct {
|
||||
RlimitNoFile uint64 `yaml:"rlimit_nofile"`
|
||||
}
|
||||
|
||||
type clientsConfig struct {
|
||||
// Sources defines the set of sources to fetch the runtime clients from.
|
||||
Sources *clientSourcesConf `yaml:"runtime_sources"`
|
||||
// Persistent are the configured clients.
|
||||
Persistent []*clientObject `yaml:"persistent"`
|
||||
}
|
||||
|
||||
// configuration is loaded from YAML
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type configuration struct {
|
||||
@@ -88,7 +117,7 @@ type configuration struct {
|
||||
// Clients contains the YAML representations of the persistent clients.
|
||||
// This field is only used for reading and writing persistent client data.
|
||||
// Keep this field sorted to ensure consistent ordering.
|
||||
Clients []*clientObject `yaml:"clients"`
|
||||
Clients *clientsConfig `yaml:"clients"`
|
||||
|
||||
logSettings `yaml:",inline"`
|
||||
|
||||
@@ -123,8 +152,9 @@ type dnsConfig struct {
|
||||
// UpstreamTimeout is the timeout for querying upstream servers.
|
||||
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
|
||||
|
||||
// ResolveClients enables and disables resolving clients with RDNS.
|
||||
ResolveClients bool `yaml:"resolve_clients"`
|
||||
// PrivateNets is the set of IP networks for which the private reverse DNS
|
||||
// resolver should be used.
|
||||
PrivateNets []string `yaml:"private_networks"`
|
||||
|
||||
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
|
||||
// locally-served networks should be resolved via private PTR resolvers.
|
||||
@@ -179,6 +209,7 @@ var config = &configuration{
|
||||
Ratelimit: 20,
|
||||
RefuseAny: true,
|
||||
AllServers: false,
|
||||
HandleDDR: true,
|
||||
FastestTimeout: timeutil.Duration{
|
||||
Duration: fastip.DefaultPingWaitTimeout,
|
||||
},
|
||||
@@ -194,7 +225,6 @@ var config = &configuration{
|
||||
FilteringEnabled: true, // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours: 24,
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||
ResolveClients: true,
|
||||
UsePrivateRDNS: true,
|
||||
},
|
||||
TLS: tlsConfigSettings{
|
||||
@@ -205,12 +235,21 @@ var config = &configuration{
|
||||
DHCP: &dhcpd.ServerConfig{
|
||||
LocalDomainName: "lan",
|
||||
},
|
||||
Clients: &clientsConfig{
|
||||
Sources: &clientSourcesConf{
|
||||
WHOIS: true,
|
||||
ARP: true,
|
||||
RDNS: true,
|
||||
DHCP: true,
|
||||
HostsFile: true,
|
||||
},
|
||||
},
|
||||
logSettings: logSettings{
|
||||
LogCompress: false,
|
||||
LogLocalTime: false,
|
||||
LogMaxBackups: 0,
|
||||
LogMaxSize: 100,
|
||||
LogMaxAge: 3,
|
||||
Compress: false,
|
||||
LocalTime: false,
|
||||
MaxBackups: 0,
|
||||
MaxSize: 100,
|
||||
MaxAge: 3,
|
||||
},
|
||||
OSConfig: &osConfig{},
|
||||
SchemaVersion: currentSchemaVersion,
|
||||
@@ -285,25 +324,28 @@ func parseConfig() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
config.BindPort,
|
||||
config.BetaBindPort,
|
||||
config.DNS.Port,
|
||||
)
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(config.BindPort), tcpPort(config.BetaBindPort))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
|
||||
if config.TLS.Enabled {
|
||||
addPorts(
|
||||
uc,
|
||||
config.TLS.PortHTTPS,
|
||||
config.TLS.PortDNSOverTLS,
|
||||
config.TLS.PortDNSOverQUIC,
|
||||
config.TLS.PortDNSCrypt,
|
||||
tcpPorts,
|
||||
tcpPort(config.TLS.PortHTTPS),
|
||||
tcpPort(config.TLS.PortDNSOverTLS),
|
||||
tcpPort(config.TLS.PortDNSCrypt),
|
||||
)
|
||||
|
||||
// TODO(e.burkov): Consider adding a udpPort with the same value when
|
||||
// we add support for HTTP/3 for web admin interface.
|
||||
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
|
||||
}
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return fmt.Errorf("validating ports: %w", err)
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
} else if err = udpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||
@@ -317,8 +359,14 @@ func parseConfig() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addPorts is a helper for ports validation. It skips zero ports.
|
||||
func addPorts(uc aghalg.UniqChecker, ports ...int) {
|
||||
// udpPort is the port number for UDP protocol.
|
||||
type udpPort int
|
||||
|
||||
// tcpPort is the port number for TCP protocol.
|
||||
type tcpPort int
|
||||
|
||||
// addPorts is a helper for ports validation that skips zero ports.
|
||||
func addPorts[T tcpPort | udpPort](uc aghalg.UniqChecker[T], ports ...T) {
|
||||
for _, p := range ports {
|
||||
if p != 0 {
|
||||
uc.Add(p)
|
||||
@@ -340,13 +388,14 @@ func readConfigFile() (fileData []byte, err error) {
|
||||
}
|
||||
|
||||
// Saves configuration to the YAML file and also saves the user filter contents to a file
|
||||
func (c *configuration) write() error {
|
||||
func (c *configuration) write() (err error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if Context.auth != nil {
|
||||
config.Users = Context.auth.GetUsers()
|
||||
}
|
||||
|
||||
if Context.tls != nil {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
@@ -380,9 +429,7 @@ func (c *configuration) write() error {
|
||||
s.WriteDiskConfig(&c)
|
||||
dns := &config.DNS
|
||||
dns.FilteringConfig = c
|
||||
dns.LocalPTRResolvers,
|
||||
dns.ResolveClients,
|
||||
dns.UsePrivateRDNS = s.RDNSSettings()
|
||||
dns.LocalPTRResolvers, config.Clients.Sources.RDNS, dns.UsePrivateRDNS = s.RDNSSettings()
|
||||
}
|
||||
|
||||
if Context.dhcpServer != nil {
|
||||
@@ -391,22 +438,23 @@ func (c *configuration) write() error {
|
||||
config.DHCP = c
|
||||
}
|
||||
|
||||
config.Clients = Context.clients.forConfig()
|
||||
config.Clients.Persistent = Context.clients.forConfig()
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
log.Debug("Writing YAML file: %s", configFile)
|
||||
yamlText, err := yaml.Marshal(&config)
|
||||
if err != nil {
|
||||
log.Error("Couldn't generate YAML file: %s", err)
|
||||
log.Debug("writing config file %q", configFile)
|
||||
|
||||
return err
|
||||
buf := &bytes.Buffer{}
|
||||
enc := yaml.NewEncoder(buf)
|
||||
enc.SetIndent(2)
|
||||
|
||||
err = enc.Encode(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating config file: %w", err)
|
||||
}
|
||||
|
||||
err = maybe.WriteFile(configFile, yamlText, 0o644)
|
||||
err = maybe.WriteFile(configFile, buf.Bytes(), 0o644)
|
||||
if err != nil {
|
||||
log.Error("Couldn't save YAML config: %s", err)
|
||||
|
||||
return err
|
||||
return fmt.Errorf("writing config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -189,7 +189,7 @@ func registerControlHandlers() {
|
||||
RegisterAuthHandlers()
|
||||
}
|
||||
|
||||
func httpRegister(method, url string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
func httpRegister(method, url string, handler http.HandlerFunc) {
|
||||
if method == "" {
|
||||
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||
Context.mux.HandleFunc(url, postInstall(handler))
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -57,8 +58,8 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
|
||||
err = validateFilterURL(fj.URL)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("invalid url: %s", err)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
err = fmt.Errorf("invalid url: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -178,16 +179,16 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
}
|
||||
|
||||
type filterURLJSON struct {
|
||||
type filterURLReqData struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type filterURLReq struct {
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
Data filterURLJSON `json:"data"`
|
||||
Data *filterURLReqData `json:"data"`
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -199,10 +200,17 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
if fj.Data == nil {
|
||||
err = errors.Error("data cannot be null")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = validateFilterURL(fj.Data.URL)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("invalid url: %s", err)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
err = fmt.Errorf("invalid url: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -223,11 +231,8 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
restart := false
|
||||
if (status & statusEnabledChanged) != 0 {
|
||||
// we must add or remove filter rules
|
||||
restart = true
|
||||
}
|
||||
|
||||
restart := (status & statusEnabledChanged) != 0
|
||||
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
|
||||
// download new filter and apply its rules
|
||||
flags := filterRefreshBlocklists
|
||||
@@ -242,6 +247,7 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
||||
restart = true
|
||||
}
|
||||
}
|
||||
|
||||
if restart {
|
||||
enableFilters(true)
|
||||
}
|
||||
@@ -311,20 +317,20 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
type filterJSON struct {
|
||||
ID int64 `json:"id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name"`
|
||||
LastUpdated string `json:"last_updated,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
RulesCount uint32 `json:"rules_count"`
|
||||
LastUpdated string `json:"last_updated"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type filteringConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Interval uint32 `json:"interval"` // in hours
|
||||
Filters []filterJSON `json:"filters"`
|
||||
WhitelistFilters []filterJSON `json:"whitelist_filters"`
|
||||
UserRules []string `json:"user_rules"`
|
||||
Interval uint32 `json:"interval"` // in hours
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
func filterToJSON(f filter) filterJSON {
|
||||
@@ -402,16 +408,12 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
type checkHostRespRule struct {
|
||||
FilterListID int64 `json:"filter_list_id"`
|
||||
Text string `json:"text"`
|
||||
FilterListID int64 `json:"filter_list_id"`
|
||||
}
|
||||
|
||||
type checkHostResp struct {
|
||||
Reason string `json:"reason"`
|
||||
// FilterID is the ID of the rule's filter list.
|
||||
//
|
||||
// Deprecated: Use Rules[*].FilterListID.
|
||||
FilterID int64 `json:"filter_id"`
|
||||
|
||||
// Rule is the text of the matched rule.
|
||||
//
|
||||
@@ -426,6 +428,11 @@ type checkHostResp struct {
|
||||
// for Rewrite:
|
||||
CanonName string `json:"cname"` // CNAME value
|
||||
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
|
||||
|
||||
// FilterID is the ID of the rule's filter list.
|
||||
//
|
||||
// Deprecated: Use Rules[*].FilterListID.
|
||||
FilterID int64 `json:"filter_id"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -105,19 +105,22 @@ type checkConfResp struct {
|
||||
|
||||
// validateWeb returns error is the web part if the initial configuration can't
|
||||
// be set.
|
||||
func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
|
||||
func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := req.Web.Port
|
||||
addPorts(uc, config.BetaBindPort, port)
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
// Avoid duplicating the error into the status of DNS.
|
||||
uc[port] = 1
|
||||
portInt := req.Web.Port
|
||||
port := tcpPort(portInt)
|
||||
addPorts(tcpPorts, tcpPort(config.BetaBindPort), port)
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
// Reset the value for the port to 1 to make sure that validateDNS
|
||||
// doesn't throw the same error, unless the same TCP port is set there
|
||||
// as well.
|
||||
tcpPorts[port] = 1
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
switch port {
|
||||
switch portInt {
|
||||
case 0, config.BindPort:
|
||||
return nil
|
||||
default:
|
||||
@@ -125,21 +128,18 @@ func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
|
||||
// unbound after install.
|
||||
}
|
||||
|
||||
return aghnet.CheckPort("tcp", req.Web.IP, port)
|
||||
return aghnet.CheckPort("tcp", req.Web.IP, portInt)
|
||||
}
|
||||
|
||||
// validateDNS returns error if the DNS part of the initial configuration can't
|
||||
// be set. canAutofix is true if the port can be unbound by AdGuard Home
|
||||
// automatically.
|
||||
func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, err error) {
|
||||
func (req *checkConfReq) validateDNS(
|
||||
tcpPorts aghalg.UniqChecker[tcpPort],
|
||||
) (canAutofix bool, err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := req.DNS.Port
|
||||
addPorts(uc, port)
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch port {
|
||||
case 0:
|
||||
return false, nil
|
||||
@@ -148,6 +148,11 @@ func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, er
|
||||
// by AdGuard Home for web interface.
|
||||
default:
|
||||
// Check TCP as well.
|
||||
addPorts(tcpPorts, tcpPort(port))
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("tcp", req.DNS.IP, port)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -185,13 +190,12 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
resp := &checkConfResp{}
|
||||
uc := aghalg.UniqChecker{}
|
||||
|
||||
if err = req.validateWeb(uc); err != nil {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
if err = req.validateWeb(tcpPorts); err != nil {
|
||||
resp.Web.Status = err.Error()
|
||||
}
|
||||
|
||||
if resp.DNS.CanAutofix, err = req.validateDNS(uc); err != nil {
|
||||
if resp.DNS.CanAutofix, err = req.validateDNS(tcpPorts); err != nil {
|
||||
resp.DNS.Status = err.Error()
|
||||
} else if !req.DNS.IP.IsUnspecified() {
|
||||
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
|
||||
@@ -212,7 +216,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
|
||||
func handleStaticIP(ip net.IP, set bool) staticIPJSON {
|
||||
resp := staticIPJSON{}
|
||||
|
||||
interfaceName := aghnet.GetInterfaceByIP(ip)
|
||||
interfaceName := aghnet.InterfaceByIP(ip)
|
||||
resp.Static = "no"
|
||||
|
||||
if len(interfaceName) == 0 {
|
||||
|
||||
@@ -3,14 +3,15 @@ package home
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
@@ -27,12 +28,16 @@ type temporaryError interface {
|
||||
|
||||
// Get the latest available version from the Internet
|
||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
resp := &versionResponse{}
|
||||
if Context.disableUpdate {
|
||||
// w.Header().Set("Content-Type", "application/json")
|
||||
resp.Disabled = true
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
// TODO(e.burkov): Add error handling and deal with headers.
|
||||
err := json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -44,30 +49,48 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ContentLength != 0 {
|
||||
err = json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "parsing request: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = requestVersionInfo(resp, req.Recheck)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
aghhttp.Error(r, w, http.StatusBadGateway, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = resp.setAllowedToAutoUpdate()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// requestVersionInfo sets the VersionInfo field of resp if it can reach the
|
||||
// update server.
|
||||
func requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
|
||||
for i := 0; i != 3; i++ {
|
||||
func() {
|
||||
Context.controlLock.Lock()
|
||||
defer Context.controlLock.Unlock()
|
||||
|
||||
resp.VersionInfo, err = Context.updater.VersionInfo(req.Recheck)
|
||||
}()
|
||||
|
||||
resp.VersionInfo, err = Context.updater.VersionInfo(recheck)
|
||||
if err != nil {
|
||||
var terr temporaryError
|
||||
if errors.As(err, &terr) && terr.Temporary() {
|
||||
// Temporary network error. This case may happen while
|
||||
// we're restarting our DNS server. Log and sleep for
|
||||
// some time.
|
||||
// Temporary network error. This case may happen while we're
|
||||
// restarting our DNS server. Log and sleep for some time.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/934.
|
||||
d := time.Duration(i) * time.Second
|
||||
log.Info("temp net error: %q; sleeping for %s and retrying", err, d)
|
||||
log.Info("update: temp net error: %q; sleeping for %s and retrying", err, d)
|
||||
time.Sleep(d)
|
||||
|
||||
continue
|
||||
@@ -76,29 +99,14 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
vcu := Context.updater.VersionCheckURL()
|
||||
// TODO(a.garipov): Figure out the purpose of %T verb.
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadGateway,
|
||||
"Couldn't get version check json from %s: %T %s\n",
|
||||
vcu,
|
||||
err,
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
return fmt.Errorf("getting version info from %s: %s", vcu, err)
|
||||
}
|
||||
|
||||
resp.confirmAutoUpdate()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleUpdate performs an update to the latest available version procedure.
|
||||
@@ -109,7 +117,18 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err := Context.updater.Update()
|
||||
// Retain the current absolute path of the executable, since the updater is
|
||||
// likely to change the position current one to the backup directory.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4735.
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "getting path: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = Context.updater.Update()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
@@ -121,85 +140,88 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// The background context is used because the underlying functions wrap
|
||||
// it with timeout and shut down the server, which handles current
|
||||
// request. It also should be done in a separate goroutine due to the
|
||||
// same reason.
|
||||
go func() {
|
||||
finishUpdate(context.Background())
|
||||
}()
|
||||
// The background context is used because the underlying functions wrap it
|
||||
// with timeout and shut down the server, which handles current request. It
|
||||
// also should be done in a separate goroutine for the same reason.
|
||||
go finishUpdate(context.Background(), execPath)
|
||||
}
|
||||
|
||||
// versionResponse is the response for /control/version.json endpoint.
|
||||
type versionResponse struct {
|
||||
Disabled bool `json:"disabled"`
|
||||
updater.VersionInfo
|
||||
Disabled bool `json:"disabled"`
|
||||
}
|
||||
|
||||
// confirmAutoUpdate checks the real possibility of auto update.
|
||||
func (vr *versionResponse) confirmAutoUpdate() {
|
||||
if vr.CanAutoUpdate != nil && *vr.CanAutoUpdate {
|
||||
canUpdate := true
|
||||
|
||||
var tlsConf *tlsConfigSettings
|
||||
if runtime.GOOS != "windows" {
|
||||
tlsConf = &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
}
|
||||
|
||||
if tlsConf != nil &&
|
||||
((tlsConf.Enabled && (tlsConf.PortHTTPS < 1024 ||
|
||||
tlsConf.PortDNSOverTLS < 1024 ||
|
||||
tlsConf.PortDNSOverQUIC < 1024)) ||
|
||||
config.BindPort < 1024 ||
|
||||
config.DNS.Port < 1024) {
|
||||
canUpdate, _ = aghnet.CanBindPrivilegedPorts()
|
||||
}
|
||||
vr.CanAutoUpdate = &canUpdate
|
||||
// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
|
||||
// allowed to perform an automatic update by the OS.
|
||||
func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
||||
if vr.CanAutoUpdate != aghalg.NBTrue {
|
||||
return nil
|
||||
}
|
||||
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
|
||||
canUpdate := true
|
||||
if tlsConfUsesPrivilegedPorts(tlsConf) || config.BindPort < 1024 || config.DNS.Port < 1024 {
|
||||
canUpdate, err = aghnet.CanBindPrivilegedPorts()
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking ability to bind privileged ports: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
vr.CanAutoUpdate = aghalg.BoolToNullBool(canUpdate)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
|
||||
// indicates that privileged ports are used.
|
||||
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
|
||||
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
|
||||
}
|
||||
|
||||
// finishUpdate completes an update procedure.
|
||||
func finishUpdate(ctx context.Context) {
|
||||
log.Info("Stopping all tasks")
|
||||
func finishUpdate(ctx context.Context, execPath string) {
|
||||
var err error
|
||||
|
||||
log.Info("stopping all tasks")
|
||||
|
||||
cleanup(ctx)
|
||||
cleanupAlways()
|
||||
|
||||
exeName := "AdGuardHome"
|
||||
if runtime.GOOS == "windows" {
|
||||
exeName = "AdGuardHome.exe"
|
||||
}
|
||||
curBinName := filepath.Join(Context.workDir, exeName)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
if Context.runningAsService {
|
||||
// Note:
|
||||
// we can't restart the service via "kardianos/service" package - it kills the process first
|
||||
// we can't start a new instance - Windows doesn't allow it
|
||||
// NOTE: We can't restart the service via "kardianos/service"
|
||||
// package, because it kills the process first we can't start a new
|
||||
// instance, because Windows doesn't allow it.
|
||||
//
|
||||
// TODO(a.garipov): Recheck the claim above.
|
||||
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
|
||||
err := cmd.Start()
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("exec.Command() failed: %s", err)
|
||||
log.Fatalf("restarting: stopping: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
cmd := exec.Command(curBinName, os.Args[1:]...)
|
||||
log.Info("Restarting: %v", cmd.Args)
|
||||
cmd := exec.Command(execPath, os.Args[1:]...)
|
||||
log.Info("restarting: %q %q", execPath, os.Args[1:])
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err := cmd.Start()
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("exec.Command() failed: %s", err)
|
||||
log.Fatalf("restarting:: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
} else {
|
||||
log.Info("Restarting: %v", os.Args)
|
||||
err := syscall.Exec(curBinName, os.Args, os.Environ())
|
||||
if err != nil {
|
||||
log.Fatalf("syscall.Exec() failed: %s", err)
|
||||
}
|
||||
// Unreachable code
|
||||
}
|
||||
|
||||
log.Info("restarting: %q %q", execPath, os.Args[1:])
|
||||
err = syscall.Exec(execPath, os.Args, os.Environ())
|
||||
if err != nil {
|
||||
log.Fatalf("restarting: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Default ports.
|
||||
@@ -25,7 +25,7 @@ const (
|
||||
defaultPortDNS = 53
|
||||
defaultPortHTTP = 80
|
||||
defaultPortHTTPS = 443
|
||||
defaultPortQUIC = 784
|
||||
defaultPortQUIC = 853
|
||||
defaultPortTLS = 853
|
||||
)
|
||||
|
||||
@@ -58,6 +58,7 @@ func initDNSServer() (err error) {
|
||||
}
|
||||
|
||||
conf := querylog.Config{
|
||||
Anonymizer: anonymizer,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
FindClient: Context.clients.findMultiple,
|
||||
@@ -67,7 +68,6 @@ func initDNSServer() (err error) {
|
||||
Enabled: config.DNS.QueryLogEnabled,
|
||||
FileEnabled: config.DNS.QueryLogFileEnabled,
|
||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||
Anonymizer: anonymizer,
|
||||
}
|
||||
Context.queryLog = querylog.New(conf)
|
||||
|
||||
@@ -77,13 +77,36 @@ func initDNSServer() (err error) {
|
||||
filterConf.HTTPRegister = httpRegister
|
||||
Context.dnsFilter = filtering.New(&filterConf, nil)
|
||||
|
||||
var privateNets netutil.SubnetSet
|
||||
switch len(config.DNS.PrivateNets) {
|
||||
case 0:
|
||||
// Use an optimized locally-served matcher.
|
||||
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
case 1:
|
||||
var n *net.IPNet
|
||||
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
privateNets = n
|
||||
default:
|
||||
var nets []*net.IPNet
|
||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
privateNets = netutil.SliceSubnetSet(nets)
|
||||
}
|
||||
|
||||
p := dnsforward.DNSCreateParams{
|
||||
DNSFilter: Context.dnsFilter,
|
||||
Stats: Context.stats,
|
||||
QueryLog: Context.queryLog,
|
||||
SubnetDetector: Context.subnetDetector,
|
||||
Anonymizer: anonymizer,
|
||||
LocalDomain: config.DHCP.LocalDomainName,
|
||||
DNSFilter: Context.dnsFilter,
|
||||
Stats: Context.stats,
|
||||
QueryLog: Context.queryLog,
|
||||
PrivateNets: privateNets,
|
||||
Anonymizer: anonymizer,
|
||||
LocalDomain: config.DHCP.LocalDomainName,
|
||||
}
|
||||
if Context.dhcpServer != nil {
|
||||
p.DHCPServer = Context.dhcpServer
|
||||
@@ -112,8 +135,13 @@ func initDNSServer() (err error) {
|
||||
return fmt.Errorf("dnsServer.Prepare: %w", err)
|
||||
}
|
||||
|
||||
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
|
||||
Context.whois = initWHOIS(&Context.clients)
|
||||
if config.Clients.Sources.RDNS {
|
||||
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
|
||||
}
|
||||
|
||||
if config.Clients.Sources.WHOIS {
|
||||
Context.whois = initWHOIS(&Context.clients)
|
||||
}
|
||||
|
||||
Context.filters.Init()
|
||||
return nil
|
||||
@@ -130,10 +158,11 @@ func onDNSRequest(pctx *proxy.DNSContext) {
|
||||
return
|
||||
}
|
||||
|
||||
if config.DNS.ResolveClients && !ip.IsLoopback() {
|
||||
srcs := config.Clients.Sources
|
||||
if srcs.RDNS && !ip.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
if !Context.subnetDetector.IsSpecialNetwork(ip) {
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
}
|
||||
@@ -192,6 +221,10 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
newConf.TLSConfig = tlsConf.TLSConfig
|
||||
newConf.TLSConfig.ServerName = tlsConf.ServerName
|
||||
|
||||
if tlsConf.PortHTTPS != 0 {
|
||||
newConf.HTTPSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortHTTPS)
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
newConf.TLSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortDNSOverTLS)
|
||||
}
|
||||
@@ -216,7 +249,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
newConf.FilterHandler = applyAdditionalFiltering
|
||||
newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams
|
||||
|
||||
newConf.ResolveClients = dnsConf.ResolveClients
|
||||
newConf.ResolveClients = config.Clients.Sources.RDNS
|
||||
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
|
||||
newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers
|
||||
newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration
|
||||
@@ -301,24 +334,28 @@ func getDNSEncryption() (de dnsEncryption) {
|
||||
|
||||
// applyAdditionalFiltering adds additional client information and settings if
|
||||
// the client has them.
|
||||
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *filtering.Settings) {
|
||||
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||
|
||||
if clientAddr == nil {
|
||||
log.Debug("looking up settings for client with ip %s and clientid %q", clientIP, clientID)
|
||||
|
||||
if clientIP == nil {
|
||||
return
|
||||
}
|
||||
|
||||
setts.ClientIP = clientAddr
|
||||
setts.ClientIP = clientIP
|
||||
|
||||
c, ok := Context.clients.Find(clientID)
|
||||
if !ok {
|
||||
c, ok = Context.clients.Find(clientAddr.String())
|
||||
c, ok = Context.clients.Find(clientIP.String())
|
||||
if !ok {
|
||||
log.Debug("client with ip %s and clientid %q not found", clientIP, clientID)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("using settings for client %s with ip %s and clientid %q", c.Name, clientAddr, clientID)
|
||||
log.Debug("using settings for client %q with ip %s and clientid %q", c.Name, clientIP, clientID)
|
||||
|
||||
if c.UseOwnBlockedServices {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
|
||||
@@ -359,11 +396,16 @@ func startDNSServer() error {
|
||||
Context.queryLog.Start()
|
||||
|
||||
const topClientsNumber = 100 // the number of clients to get
|
||||
for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) {
|
||||
if config.DNS.ResolveClients && !ip.IsLoopback() {
|
||||
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
srcs := config.Clients.Sources
|
||||
if srcs.RDNS && !ip.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
if !Context.subnetDetector.IsSpecialNetwork(ip) {
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
}
|
||||
@@ -413,7 +455,12 @@ func closeDNSServer() {
|
||||
}
|
||||
|
||||
if Context.stats != nil {
|
||||
Context.stats.Close()
|
||||
err := Context.stats.Close()
|
||||
if err != nil {
|
||||
log.Debug("closing stats: %s", err)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Find out if it's safe.
|
||||
Context.stats = nil
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ type homeContext struct {
|
||||
// --
|
||||
|
||||
clients clientsContainer // per-client-settings module
|
||||
stats stats.Stats // statistics module
|
||||
stats stats.Interface // statistics module
|
||||
queryLog querylog.QueryLog // query log module
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
@@ -66,8 +66,6 @@ type homeContext struct {
|
||||
|
||||
updater *updater.Updater
|
||||
|
||||
subnetDetector *aghnet.SubnetDetector
|
||||
|
||||
// mux is our custom http.ServeMux.
|
||||
mux *http.ServeMux
|
||||
|
||||
@@ -175,6 +173,11 @@ func setupContext(args options) {
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if !args.noEtcHosts && config.Clients.Sources.HostsFile {
|
||||
err = setupHostsContainer()
|
||||
fatalOnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
Context.mux = http.NewServeMux()
|
||||
@@ -287,41 +290,35 @@ func setupConfig(args options) (err error) {
|
||||
ConfName: config.getConfigFilename(),
|
||||
})
|
||||
|
||||
if !args.noEtcHosts {
|
||||
if err = setupHostsContainer(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var arpdb aghnet.ARPDB
|
||||
arpdb, err = aghnet.NewARPDB()
|
||||
if err != nil {
|
||||
log.Info("warning: creating arpdb: %s; using stub", err)
|
||||
|
||||
arpdb = aghnet.EmptyARPDB{}
|
||||
if config.Clients.Sources.ARP {
|
||||
arpdb = aghnet.NewARPDB()
|
||||
}
|
||||
|
||||
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts, arpdb)
|
||||
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb)
|
||||
|
||||
if args.bindPort != 0 {
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
args.bindPort,
|
||||
config.BetaBindPort,
|
||||
config.DNS.Port,
|
||||
)
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(args.bindPort), tcpPort(config.BetaBindPort))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
|
||||
if config.TLS.Enabled {
|
||||
addPorts(
|
||||
uc,
|
||||
config.TLS.PortHTTPS,
|
||||
config.TLS.PortDNSOverTLS,
|
||||
config.TLS.PortDNSOverQUIC,
|
||||
config.TLS.PortDNSCrypt,
|
||||
tcpPorts,
|
||||
tcpPort(config.TLS.PortHTTPS),
|
||||
tcpPort(config.TLS.PortDNSOverTLS),
|
||||
tcpPort(config.TLS.PortDNSCrypt),
|
||||
)
|
||||
|
||||
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
|
||||
}
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return fmt.Errorf("validating ports: %w", err)
|
||||
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
} else if err = udpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
config.BindPort = args.bindPort
|
||||
@@ -398,9 +395,6 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
// configure log level and output
|
||||
configureLogger(args)
|
||||
|
||||
// Go memory hacks
|
||||
memoryUsage(args)
|
||||
|
||||
// Print the first message after logger is configured.
|
||||
log.Println(version.Full())
|
||||
log.Debug("current working directory is %s", Context.workDir)
|
||||
@@ -477,9 +471,6 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
Context.web, err = initWeb(args, clientBuildFS)
|
||||
fatalOnError(err)
|
||||
|
||||
Context.subnetDetector, err = aghnet.NewSubnetDetector()
|
||||
fatalOnError(err)
|
||||
|
||||
if !Context.firstRun {
|
||||
err = initDNSServer()
|
||||
fatalOnError(err)
|
||||
@@ -529,27 +520,15 @@ func StartMods() error {
|
||||
func checkPermissions() {
|
||||
log.Info("Checking if AdGuard Home has necessary permissions")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// On Windows we need to have admin rights to run properly
|
||||
|
||||
admin, _ := aghos.HaveAdminRights()
|
||||
if admin {
|
||||
return
|
||||
}
|
||||
|
||||
if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil {
|
||||
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
|
||||
}
|
||||
|
||||
// We should check if AdGuard Home is able to bind to port 53
|
||||
ok, err := aghnet.CanBindPort(53)
|
||||
|
||||
if ok {
|
||||
log.Info("AdGuard Home can bind to port 53")
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
msg := `Permission check failed.
|
||||
err := aghnet.CheckPort("tcp", net.IP{127, 0, 0, 1}, defaultPortDNS)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
log.Fatal(`Permission check failed.
|
||||
|
||||
AdGuard Home is not allowed to bind to privileged ports (for instance, port 53).
|
||||
Please note, that this is crucial for a server to be able to use privileged ports.
|
||||
@@ -557,16 +536,17 @@ Please note, that this is crucial for a server to be able to use privileged port
|
||||
You have two options:
|
||||
1. Run AdGuard Home with root privileges
|
||||
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
|
||||
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`
|
||||
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`)
|
||||
}
|
||||
|
||||
log.Fatal(msg)
|
||||
log.Info(
|
||||
"AdGuard failed to bind to port 53: %s\n\n"+
|
||||
"Please note, that this is crucial for a DNS server to be able to use that port.",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf(`AdGuard failed to bind to port 53 due to %v
|
||||
|
||||
Please note, that this is crucial for a DNS server to be able to use that port.`, err)
|
||||
|
||||
log.Info(msg)
|
||||
log.Info("AdGuard Home can bind to port 53")
|
||||
}
|
||||
|
||||
// Write PID to a file
|
||||
@@ -622,17 +602,17 @@ func configureLogger(args options) {
|
||||
ls.Verbose = true
|
||||
}
|
||||
if args.logFile != "" {
|
||||
ls.LogFile = args.logFile
|
||||
} else if config.LogFile != "" {
|
||||
ls.LogFile = config.LogFile
|
||||
ls.File = args.logFile
|
||||
} else if config.File != "" {
|
||||
ls.File = config.File
|
||||
}
|
||||
|
||||
// Handle default log settings overrides
|
||||
ls.LogCompress = config.LogCompress
|
||||
ls.LogLocalTime = config.LogLocalTime
|
||||
ls.LogMaxBackups = config.LogMaxBackups
|
||||
ls.LogMaxSize = config.LogMaxSize
|
||||
ls.LogMaxAge = config.LogMaxAge
|
||||
ls.Compress = config.Compress
|
||||
ls.LocalTime = config.LocalTime
|
||||
ls.MaxBackups = config.MaxBackups
|
||||
ls.MaxSize = config.MaxSize
|
||||
ls.MaxAge = config.MaxAge
|
||||
|
||||
// log.SetLevel(log.INFO) - default
|
||||
if ls.Verbose {
|
||||
@@ -643,27 +623,27 @@ func configureLogger(args options) {
|
||||
// happen pretty quickly.
|
||||
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
|
||||
|
||||
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
|
||||
if args.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if nothing
|
||||
// else is configured. Otherwise, we'll simply lose the log output.
|
||||
ls.LogFile = configSyslog
|
||||
ls.File = configSyslog
|
||||
}
|
||||
|
||||
// logs are written to stdout (default)
|
||||
if ls.LogFile == "" {
|
||||
if ls.File == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if ls.LogFile == configSyslog {
|
||||
if ls.File == configSyslog {
|
||||
// Use syslog where it is possible and eventlog on Windows
|
||||
err := aghos.ConfigureSyslog(serviceName)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot initialize syslog: %s", err)
|
||||
}
|
||||
} else {
|
||||
logFilePath := filepath.Join(Context.workDir, ls.LogFile)
|
||||
if filepath.IsAbs(ls.LogFile) {
|
||||
logFilePath = ls.LogFile
|
||||
logFilePath := filepath.Join(Context.workDir, ls.File)
|
||||
if filepath.IsAbs(ls.File) {
|
||||
logFilePath = ls.File
|
||||
}
|
||||
|
||||
_, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
|
||||
@@ -673,11 +653,11 @@ func configureLogger(args options) {
|
||||
|
||||
log.SetOutput(&lumberjack.Logger{
|
||||
Filename: logFilePath,
|
||||
Compress: ls.LogCompress, // disabled by default
|
||||
LocalTime: ls.LogLocalTime,
|
||||
MaxBackups: ls.LogMaxBackups,
|
||||
MaxSize: ls.LogMaxSize, // megabytes
|
||||
MaxAge: ls.LogMaxAge, // days
|
||||
Compress: ls.Compress, // disabled by default
|
||||
LocalTime: ls.LocalTime,
|
||||
MaxBackups: ls.MaxBackups,
|
||||
MaxSize: ls.MaxSize, // megabytes
|
||||
MaxAge: ls.MaxAge, // days
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
// TODO(a.garipov): Get rid of a global or generate from .twosky.json.
|
||||
var allowedLanguages = stringutil.NewSet(
|
||||
"ar",
|
||||
"be",
|
||||
"bg",
|
||||
"cs",
|
||||
@@ -50,7 +51,7 @@ var allowedLanguages = stringutil.NewSet(
|
||||
"zh-tw",
|
||||
)
|
||||
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
log.Printf("config.Language is %s", config.Language)
|
||||
_, err := fmt.Fprintf(w, "%s\n", config.Language)
|
||||
@@ -58,6 +59,7 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
msg := fmt.Sprintf("Unable to write response json: %s", err)
|
||||
log.Println(msg)
|
||||
http.Error(w, msg, http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -69,6 +71,7 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
msg := fmt.Sprintf("failed to read request body: %s", err)
|
||||
log.Println(msg)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// memoryUsage implements a couple of not really beautiful hacks which purpose is to
|
||||
// make OS reclaim the memory freed by AdGuard Home as soon as possible.
|
||||
// See this for the details on the performance hits & gains:
|
||||
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/2044#issuecomment-687042211
|
||||
func memoryUsage(args options) {
|
||||
if args.disableMemoryOptimization {
|
||||
log.Info("Memory optimization is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Makes Go allocate heap at a slower pace
|
||||
// By default we keep it at 50%
|
||||
debug.SetGCPercent(50)
|
||||
|
||||
// madvdontneed: setting madvdontneed=1 will use MADV_DONTNEED
|
||||
// instead of MADV_FREE on Linux when returning memory to the
|
||||
// kernel. This is less efficient, but causes RSS numbers to drop
|
||||
// more quickly.
|
||||
_ = os.Setenv("GODEBUG", "madvdontneed=1")
|
||||
|
||||
// periodically call "debug.FreeOSMemory" so
|
||||
// that the OS could reclaim the free memory
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
log.Debug("free os memory")
|
||||
debug.FreeOSMemory()
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"github.com/google/uuid"
|
||||
"howett.net/plist"
|
||||
)
|
||||
|
||||
@@ -47,9 +47,9 @@ type payloadContent struct {
|
||||
|
||||
PayloadType string
|
||||
PayloadIdentifier string
|
||||
PayloadUUID string
|
||||
PayloadDisplayName string
|
||||
PayloadDescription string
|
||||
PayloadUUID uuid.UUID
|
||||
PayloadVersion int
|
||||
}
|
||||
|
||||
@@ -63,18 +63,14 @@ const dnsSettingsPayloadType = "com.apple.dnsSettings.managed"
|
||||
type mobileConfig struct {
|
||||
PayloadDescription string
|
||||
PayloadDisplayName string
|
||||
PayloadIdentifier string
|
||||
PayloadType string
|
||||
PayloadUUID string
|
||||
PayloadContent []*payloadContent
|
||||
PayloadIdentifier uuid.UUID
|
||||
PayloadUUID uuid.UUID
|
||||
PayloadVersion int
|
||||
PayloadRemovalDisallowed bool
|
||||
}
|
||||
|
||||
func genUUIDv4() string {
|
||||
return uuid.NewV4().String()
|
||||
}
|
||||
|
||||
const (
|
||||
dnsProtoHTTPS = "HTTPS"
|
||||
dnsProtoTLS = "TLS"
|
||||
@@ -104,23 +100,23 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
|
||||
return nil, fmt.Errorf("bad dns protocol %q", proto)
|
||||
}
|
||||
|
||||
payloadID := fmt.Sprintf("%s.%s", dnsSettingsPayloadType, genUUIDv4())
|
||||
payloadID := fmt.Sprintf("%s.%s", dnsSettingsPayloadType, uuid.New())
|
||||
data := &mobileConfig{
|
||||
PayloadDescription: "Adds AdGuard Home to macOS Big Sur " +
|
||||
"and iOS 14 or newer systems",
|
||||
PayloadDescription: "Adds AdGuard Home to macOS Big Sur and iOS 14 or newer systems",
|
||||
PayloadDisplayName: dspName,
|
||||
PayloadIdentifier: genUUIDv4(),
|
||||
PayloadType: "Configuration",
|
||||
PayloadUUID: genUUIDv4(),
|
||||
PayloadContent: []*payloadContent{{
|
||||
DNSSettings: d,
|
||||
|
||||
PayloadType: dnsSettingsPayloadType,
|
||||
PayloadIdentifier: payloadID,
|
||||
PayloadUUID: genUUIDv4(),
|
||||
PayloadDisplayName: dspName,
|
||||
PayloadDescription: "Configures device to use AdGuard Home",
|
||||
PayloadUUID: uuid.New(),
|
||||
PayloadVersion: 1,
|
||||
DNSSettings: d,
|
||||
}},
|
||||
PayloadIdentifier: uuid.New(),
|
||||
PayloadUUID: uuid.New(),
|
||||
PayloadVersion: 1,
|
||||
PayloadRemovalDisallowed: false,
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// options passed from command-line arguments
|
||||
@@ -27,10 +28,6 @@ type options struct {
|
||||
// runningAsService flag is set to true when options are passed from the service runner
|
||||
runningAsService bool
|
||||
|
||||
// disableMemoryOptimization - disables memory optimization hacks
|
||||
// see memoryUsage() function for the details
|
||||
disableMemoryOptimization bool
|
||||
|
||||
glinetMode bool // Activate GL-Inet compatibility mode
|
||||
|
||||
// noEtcHosts flag should be provided when /etc/hosts file shouldn't be
|
||||
@@ -178,10 +175,14 @@ var noCheckUpdateArg = arg{
|
||||
}
|
||||
|
||||
var disableMemoryOptimizationArg = arg{
|
||||
"Disable memory optimization.",
|
||||
"Deprecated. Disable memory optimization.",
|
||||
"no-mem-optimization", "",
|
||||
nil, func(o options) (options, error) { o.disableMemoryOptimization = true; return o, nil }, nil,
|
||||
func(o options) []string { return boolSliceOrNil(o.disableMemoryOptimization) },
|
||||
nil, nil, func(_ options, _ string) (f effect, err error) {
|
||||
log.Info("warning: using --no-mem-optimization flag has no effect and is deprecated")
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
func(o options) []string { return nil },
|
||||
}
|
||||
|
||||
var verboseArg = arg{
|
||||
@@ -229,13 +230,19 @@ var helpArg = arg{
|
||||
}
|
||||
|
||||
var noEtcHostsArg = arg{
|
||||
description: "Do not use the OS-provided hosts.",
|
||||
description: "Deprecated. Do not use the OS-provided hosts.",
|
||||
longName: "no-etc-hosts",
|
||||
shortName: "",
|
||||
updateWithValue: nil,
|
||||
updateNoValue: func(o options) (options, error) { o.noEtcHosts = true; return o, nil },
|
||||
effect: nil,
|
||||
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) },
|
||||
effect: func(_ options, _ string) (f effect, err error) {
|
||||
log.Info(
|
||||
"warning: --no-etc-hosts flag is deprecated and will be removed in the future versions",
|
||||
)
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) },
|
||||
}
|
||||
|
||||
var localFrontendArg = arg{
|
||||
|
||||
@@ -101,9 +101,13 @@ func TestParseDisableUpdate(t *testing.T) {
|
||||
assert.True(t, testParseOK(t, "--no-check-update").disableUpdate, "--no-check-update is disable update")
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Remove after v0.108.0.
|
||||
func TestParseDisableMemoryOptimization(t *testing.T) {
|
||||
assert.False(t, testParseOK(t).disableMemoryOptimization, "empty is not disable update")
|
||||
assert.True(t, testParseOK(t, "--no-mem-optimization").disableMemoryOptimization, "--no-mem-optimization is disable update")
|
||||
o, eff, err := parse("", []string{"--no-mem-optimization"})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, eff)
|
||||
assert.Zero(t, o)
|
||||
}
|
||||
|
||||
func TestParseService(t *testing.T) {
|
||||
@@ -127,8 +131,6 @@ func TestParseUnknown(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSerialize(t *testing.T) {
|
||||
const reportFmt = "expected %s but got %s"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
opts options
|
||||
@@ -173,19 +175,14 @@ func TestSerialize(t *testing.T) {
|
||||
name: "glinet_mode",
|
||||
opts: options{glinetMode: true},
|
||||
ss: []string{"--glinet"},
|
||||
}, {
|
||||
name: "disable_mem_opt",
|
||||
opts: options{disableMemoryOptimization: true},
|
||||
ss: []string{"--no-mem-optimization"},
|
||||
}, {
|
||||
name: "multiple",
|
||||
opts: options{
|
||||
serviceControlAction: "run",
|
||||
configFilename: "config",
|
||||
workDir: "work",
|
||||
pidFile: "pid",
|
||||
disableUpdate: true,
|
||||
disableMemoryOptimization: true,
|
||||
serviceControlAction: "run",
|
||||
configFilename: "config",
|
||||
workDir: "work",
|
||||
pidFile: "pid",
|
||||
disableUpdate: true,
|
||||
},
|
||||
ss: []string{
|
||||
"-c", "config",
|
||||
@@ -193,18 +190,13 @@ func TestSerialize(t *testing.T) {
|
||||
"-s", "run",
|
||||
"--pidfile", "pid",
|
||||
"--no-check-update",
|
||||
"--no-mem-optimization",
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := serialize(tc.opts)
|
||||
require.Lenf(t, result, len(tc.ss), reportFmt, tc.ss, result)
|
||||
|
||||
for i, r := range result {
|
||||
assert.Equalf(t, tc.ss[i], r, reportFmt, tc.ss, result)
|
||||
}
|
||||
assert.ElementsMatch(t, tc.ss, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,18 +16,17 @@ type RDNS struct {
|
||||
exchanger dnsforward.RDNSExchanger
|
||||
clients *clientsContainer
|
||||
|
||||
// usePrivate is used to store the state of current private RDNS
|
||||
// resolving settings and to react to it's changes.
|
||||
// usePrivate is used to store the state of current private RDNS resolving
|
||||
// settings and to react to it's changes.
|
||||
usePrivate uint32
|
||||
|
||||
// ipCh used to pass client's IP to rDNS workerLoop.
|
||||
ipCh chan net.IP
|
||||
|
||||
// ipCache caches the IP addresses to be resolved by rDNS. The resolved
|
||||
// address stays here while it's inside clients. After leaving clients
|
||||
// the address will be resolved once again. If the address couldn't be
|
||||
// resolved, cache prevents further attempts to resolve it for some
|
||||
// time.
|
||||
// address stays here while it's inside clients. After leaving clients the
|
||||
// address will be resolved once again. If the address couldn't be
|
||||
// resolved, cache prevents further attempts to resolve it for some time.
|
||||
ipCache cache.Cache
|
||||
}
|
||||
|
||||
@@ -125,14 +124,12 @@ func (r *RDNS) workerLoop() {
|
||||
log.Debug("rdns: resolving %q: %s", ip, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
} else if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't handle any errors since AddHost doesn't return non-nil
|
||||
// errors for now.
|
||||
// Don't handle any errors since AddHost doesn't return non-nil errors
|
||||
// for now.
|
||||
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user