Merge branch 'master' into 4403-internal-proxy

This commit is contained in:
Ainar Garipov
2022-08-29 15:27:44 +03:00
241 changed files with 15591 additions and 5972 deletions

View File

@@ -1,32 +1,45 @@
// Package aghalg contains common generic algorithms and data structures.
//
// TODO(a.garipov): Update to use type parameters in Go 1.18.
// TODO(a.garipov): Move parts of this into golibs.
package aghalg
import (
"fmt"
"sort"
"golang.org/x/exp/constraints"
"golang.org/x/exp/slices"
)
// comparable is an alias for interface{}. Values passed as arguments of this
// type alias must be comparable.
//
// TODO(a.garipov): Remove in Go 1.18.
type comparable = interface{}
// Coalesce returns the first non-zero value. It is named after the function
// COALESCE in SQL. If values or all its elements are empty, it returns a zero
// value.
func Coalesce[T comparable](values ...T) (res T) {
var zero T
for _, v := range values {
if v != zero {
return v
}
}
return zero
}
// UniqChecker allows validating uniqueness of comparable items.
type UniqChecker map[comparable]int64
//
// TODO(a.garipov): The Ordered constraint is only really necessary in Validate.
// Consider ways of making this constraint comparable instead.
type UniqChecker[T constraints.Ordered] map[T]int64
// Add adds a value to the validator. v must not be nil.
func (uc UniqChecker) Add(elems ...comparable) {
func (uc UniqChecker[T]) Add(elems ...T) {
for _, e := range elems {
uc[e]++
}
}
// Merge returns a checker containing data from both uc and other.
func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
merged = make(UniqChecker, len(uc)+len(other))
func (uc UniqChecker[T]) Merge(other UniqChecker[T]) (merged UniqChecker[T]) {
merged = make(UniqChecker[T], len(uc)+len(other))
for elem, num := range uc {
merged[elem] += num
}
@@ -39,10 +52,8 @@ func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
}
// Validate returns an error enumerating all elements that aren't unique.
// isBefore is an optional sorting function to make the error message
// deterministic.
func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err error) {
var dup []comparable
func (uc UniqChecker[T]) Validate() (err error) {
var dup []T
for elem, num := range uc {
if num > 1 {
dup = append(dup, elem)
@@ -53,23 +64,7 @@ func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err
return nil
}
if isBefore != nil {
sort.Slice(dup, func(i, j int) (less bool) {
return isBefore(dup[i], dup[j])
})
}
slices.Sort(dup)
return fmt.Errorf("duplicated values: %v", dup)
}
// IntIsBefore is a helper sort function for UniqChecker.Validate.
// a and b must be of type int.
func IntIsBefore(a, b comparable) (less bool) {
return a.(int) < b.(int)
}
// StringIsBefore is a helper sort function for UniqChecker.Validate.
// a and b must be of type string.
func StringIsBefore(a, b comparable) (less bool) {
return a.(string) < b.(string)
}

View File

@@ -0,0 +1,67 @@
package aghalg
import (
"bytes"
"encoding/json"
"fmt"
)
// NullBool is a nullable boolean. Use these in JSON requests and responses
// instead of pointers to bool.
type NullBool uint8
// NullBool values
const (
NBNull NullBool = iota
NBTrue
NBFalse
)
// String implements the fmt.Stringer interface for NullBool.
func (nb NullBool) String() (s string) {
switch nb {
case NBNull:
return "null"
case NBTrue:
return "true"
case NBFalse:
return "false"
}
return fmt.Sprintf("!invalid NullBool %d", uint8(nb))
}
// BoolToNullBool converts a bool into a NullBool.
func BoolToNullBool(cond bool) (nb NullBool) {
if cond {
return NBTrue
}
return NBFalse
}
// type check
var _ json.Marshaler = NBNull
// MarshalJSON implements the json.Marshaler interface for NullBool.
func (nb NullBool) MarshalJSON() (b []byte, err error) {
return []byte(nb.String()), nil
}
// type check
var _ json.Unmarshaler = (*NullBool)(nil)
// UnmarshalJSON implements the json.Unmarshaler interface for *NullBool.
func (nb *NullBool) UnmarshalJSON(b []byte) (err error) {
if len(b) == 0 || bytes.Equal(b, []byte("null")) {
*nb = NBNull
} else if bytes.Equal(b, []byte("true")) {
*nb = NBTrue
} else if bytes.Equal(b, []byte("false")) {
*nb = NBFalse
} else {
return fmt.Errorf("unmarshalling json data into aghalg.NullBool: bad value %q", b)
}
return nil
}

View File

@@ -0,0 +1,113 @@
package aghalg_test
import (
"encoding/json"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNullBool_MarshalJSON(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
want []byte
in aghalg.NullBool
}{{
name: "null",
wantErrMsg: "",
want: []byte("null"),
in: aghalg.NBNull,
}, {
name: "true",
wantErrMsg: "",
want: []byte("true"),
in: aghalg.NBTrue,
}, {
name: "false",
wantErrMsg: "",
want: []byte("false"),
in: aghalg.NBFalse,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := tc.in.MarshalJSON()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
t.Run("json", func(t *testing.T) {
in := &struct {
A aghalg.NullBool
}{
A: aghalg.NBTrue,
}
got, err := json.Marshal(in)
require.NoError(t, err)
assert.Equal(t, []byte(`{"A":true}`), got)
})
}
func TestNullBool_UnmarshalJSON(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
data []byte
want aghalg.NullBool
}{{
name: "empty",
wantErrMsg: "",
data: []byte{},
want: aghalg.NBNull,
}, {
name: "null",
wantErrMsg: "",
data: []byte("null"),
want: aghalg.NBNull,
}, {
name: "true",
wantErrMsg: "",
data: []byte("true"),
want: aghalg.NBTrue,
}, {
name: "false",
wantErrMsg: "",
data: []byte("false"),
want: aghalg.NBFalse,
}, {
name: "invalid",
wantErrMsg: `unmarshalling json data into aghalg.NullBool: bad value "invalid"`,
data: []byte("invalid"),
want: aghalg.NBNull,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var got aghalg.NullBool
err := got.UnmarshalJSON(tc.data)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
t.Run("json", func(t *testing.T) {
want := aghalg.NBTrue
var got struct {
A aghalg.NullBool
}
err := json.Unmarshal([]byte(`{"A":true}`), &got)
require.NoError(t, err)
assert.Equal(t, want, got.A)
})
}

View File

@@ -9,6 +9,12 @@ import (
"github.com/AdguardTeam/golibs/log"
)
// RegisterFunc is the function that sets the handler to handle the URL for the
// method.
//
// TODO(e.burkov, a.garipov): Get rid of it.
type RegisterFunc func(method, url string, handler http.HandlerFunc)
// OK responds with word OK.
func OK(w http.ResponseWriter) {
if _, err := io.WriteString(w, "OK\n"); err != nil {
@@ -17,7 +23,7 @@ func OK(w http.ResponseWriter) {
}
// Error writes formatted message to w and also logs it.
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...any) {
text := fmt.Sprintf(format, args...)
log.Error("%s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)

View File

@@ -2,13 +2,11 @@ package aghnet
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
)
@@ -27,15 +25,8 @@ type ARPDB interface {
}
// NewARPDB returns the ARPDB properly initialized for the OS.
func NewARPDB() (arp ARPDB, err error) {
arp = newARPDB()
err = arp.Refresh()
if err != nil {
return nil, fmt.Errorf("arpdb initial refresh: %w", err)
}
return arp, nil
func NewARPDB() (arp ARPDB) {
return newARPDB()
}
// Empty ARPDB implementation
@@ -123,50 +114,33 @@ func (ns *neighs) reset(with []Neighbor) {
// of Neighbors.
type parseNeighsFunc func(sc *bufio.Scanner, lenHint int) (ns []Neighbor)
// runCmdFunc is the function that runs some command and returns its output
// wrapped to be a io.Reader.
type runCmdFunc func() (r io.Reader, err error)
// cmdARPDB is the implementation of the ARPDB that uses command line to
// retrieve data.
type cmdARPDB struct {
parse parseNeighsFunc
runcmd runCmdFunc
ns *neighs
parse parseNeighsFunc
ns *neighs
cmd string
args []string
}
// type check
var _ ARPDB = (*cmdARPDB)(nil)
// runCmd runs the cmd with it's args and returns the result wrapped to be an
// io.Reader. The error is returned either if the exit code retured by command
// not equals 0 or the execution itself failed.
func runCmd(cmd string, args ...string) (r io.Reader, err error) {
var code int
var out string
code, out, err = aghos.RunCommand(cmd, args...)
if err != nil {
return nil, err
} else if code != 0 {
return nil, fmt.Errorf("unexpected exit code %d", code)
}
return strings.NewReader(out), nil
}
// Refresh implements the ARPDB interface for *cmdARPDB.
func (arp *cmdARPDB) Refresh() (err error) {
defer func() { err = errors.Annotate(err, "cmd arpdb: %w") }()
var r io.Reader
r, err = arp.runcmd()
code, out, err := aghosRunCommand(arp.cmd, arp.args...)
if err != nil {
return fmt.Errorf("running command: %w", err)
} else if code != 0 {
return fmt.Errorf("running command: unexpected exit code %d", code)
}
sc := bufio.NewScanner(r)
sc := bufio.NewScanner(bytes.NewReader(out))
ns := arp.parse(sc, arp.ns.len())
if err = sc.Err(); err != nil {
// TODO(e.burkov): This error seems unreachable. Investigate.
return fmt.Errorf("scanning the output: %w", err)
}
@@ -187,8 +161,7 @@ func (arp *cmdARPDB) Neighbors() (ns []Neighbor) {
type arpdbs struct {
// arps is the set of ARPDB implementations to range through.
arps []ARPDB
// last is the last succeeded ARPDB index.
last int
neighs
}
// newARPDBs returns a properly initialized *arpdbs. It begins refreshing from
@@ -196,7 +169,10 @@ type arpdbs struct {
func newARPDBs(arps ...ARPDB) (arp *arpdbs) {
return &arpdbs{
arps: arps,
last: 0,
neighs: neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}
}
@@ -206,20 +182,18 @@ var _ ARPDB = (*arpdbs)(nil)
// Refresh implements the ARPDB interface for *arpdbs.
func (arp *arpdbs) Refresh() (err error) {
var errs []error
l := len(arp.arps)
// Start from the last succeeded implementation.
for i := 0; i < l; i++ {
cur := (arp.last + i) % l
err = arp.arps[cur].Refresh()
if err == nil {
// The succeeded implementation found so update the last succeeded
// index.
arp.last = cur
return nil
for _, a := range arp.arps {
err = a.Refresh()
if err != nil {
errs = append(errs, err)
continue
}
errs = append(errs, err)
arp.reset(a.Neighbors())
return nil
}
if len(errs) > 0 {
@@ -230,10 +204,8 @@ func (arp *arpdbs) Refresh() (err error) {
}
// Neighbors implements the ARPDB interface for *arpdbs.
//
// TODO(e.burkov): Think of a way to avoid cloning the slice twice.
func (arp *arpdbs) Neighbors() (ns []Neighbor) {
if l := len(arp.arps); l > 0 && arp.last < l {
return arp.arps[arp.last].Neighbors()
}
return nil
return arp.clone()
}

View File

@@ -13,18 +13,24 @@ import (
"github.com/AdguardTeam/golibs/netutil"
)
func newARPDB() *cmdARPDB {
func newARPDB() (arp *cmdARPDB) {
return &cmdARPDB{
parse: parseArpA,
runcmd: rcArpA,
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
cmd: "arp",
// Use -n flag to avoid resolving the hostnames of the neighbors. By
// default ARP attempts to resolve the hostnames via DNS. See man 8
// arp.
//
// See also https://github.com/AdguardTeam/AdGuardHome/issues/3157.
args: []string{"-a", "-n"},
}
}
// parseArpA parses the output of the "arp -a" command on macOS and FreeBSD.
// parseArpA parses the output of the "arp -a -n" command on macOS and FreeBSD.
// The expected input format:
//
// host.name (192.168.0.1) at ff:ff:ff:ff:ff:ff on en0 ifscope [ethernet]

View File

@@ -8,8 +8,12 @@ import (
)
const arpAOutput = `
invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet]
invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet]
invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet]
hostname.one (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet]
hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [ethernet]
? (::1234) at aa:bb:cc:dd:ee:ff on ej0 expires in 1918 seconds [ethernet]
`
var wantNeighs = []Neighbor{{
@@ -20,4 +24,8 @@ var wantNeighs = []Neighbor{{
Name: "hostname.two",
IP: net.ParseIP("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}, {
Name: "",
IP: net.ParseIP("::1234"),
MAC: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
}}

View File

@@ -6,7 +6,6 @@ package aghnet
import (
"bufio"
"fmt"
"io"
"io/fs"
"net"
"strings"
@@ -34,11 +33,30 @@ func newARPDB() (arp *arpdbs) {
return newARPDBs(
// Try /proc/net/arp first.
&fsysARPDB{ns: ns, fsys: aghos.RootDirFS(), filename: "proc/net/arp"},
// Try "arp -a" then.
&cmdARPDB{parse: parseF, runcmd: rcArpA, ns: ns},
// Try "ip neigh" finally.
&cmdARPDB{parse: parseIPNeigh, runcmd: rcIPNeigh, ns: ns},
&fsysARPDB{
ns: ns,
fsys: rootDirFS,
filename: "proc/net/arp",
},
// Then, try "arp -a -n".
&cmdARPDB{
parse: parseF,
ns: ns,
cmd: "arp",
// Use -n flag to avoid resolving the hostnames of the neighbors.
// By default ARP attempts to resolve the hostnames via DNS. See
// man 8 arp.
//
// See also https://github.com/AdguardTeam/AdGuardHome/issues/3157.
args: []string{"-a", "-n"},
},
// Finally, try "ip neigh".
&cmdARPDB{
parse: parseIPNeigh,
ns: ns,
cmd: "ip",
args: []string{"neigh"},
},
)
}
@@ -96,11 +114,11 @@ func (arp *fsysARPDB) Neighbors() (ns []Neighbor) {
return arp.ns.clone()
}
// parseArpAWrt parses the output of the "arp -a" command on OpenWrt. The
// parseArpAWrt parses the output of the "arp -a -n" command on OpenWrt. The
// expected input format:
//
// IP address HW type Flags HW address Mask Device
// 192.168.11.98 0x1 0x2 5a:92:df:a9:7e:28 * wan
// IP address HW type Flags HW address Mask Device
// 192.168.11.98 0x1 0x2 5a:92:df:a9:7e:28 * wan
//
func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
if !sc.Scan() {
@@ -140,8 +158,8 @@ func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
return ns
}
// parseArpA parses the output of the "arp -a" command on Linux. The expected
// input format:
// parseArpA parses the output of the "arp -a -n" command on Linux. The
// expected input format:
//
// hostname (192.168.1.1) at ab:cd:ef:ab:cd:ef [ether] on enp0s3
//
@@ -187,11 +205,6 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
return ns
}
// rcIPNeigh runs "ip neigh".
func rcIPNeigh() (r io.Reader, err error) {
return runCmd("ip", "neigh")
}
// parseIPNeigh parses the output of the "ip neigh" command on Linux. The
// expected input format:
//

View File

@@ -4,11 +4,10 @@
package aghnet
import (
"io"
"net"
"strings"
"sync"
"testing"
"testing/fstest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -16,14 +15,21 @@ import (
const arpAOutputWrt = `
IP address HW type Flags HW address Mask Device
1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan
192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan
::ffff:ffff 0x1 0x2 ef:cd:ab:ef:cd:ab * wan`
const arpAOutput = `
invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet]
invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet]
invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet]
? (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet]
? (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 100 seconds [ethernet]`
const ipNeighOutput = `
1.2.3.4.5 dev enp0s3 lladdr aa:bb:cc:dd:ee:ff DELAY
1.2.3.4 dev enp0s3 lladdr 12:34:56:78:910 DELAY
192.168.1.2 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef DELAY
::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE`
@@ -36,6 +42,8 @@ var wantNeighs = []Neighbor{{
}}
func TestFSysARPDB(t *testing.T) {
require.NoError(t, fstest.TestFS(testdata, "proc_net_arp"))
a := &fsysARPDB{
ns: &neighs{
mu: &sync.RWMutex{},
@@ -52,33 +60,43 @@ func TestFSysARPDB(t *testing.T) {
assert.Equal(t, wantNeighs, ns)
}
func TestCmdARPDB_arpawrt(t *testing.T) {
a := &cmdARPDB{
parse: parseArpAWrt,
runcmd: func() (r io.Reader, err error) { return strings.NewReader(arpAOutputWrt), nil },
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
func TestCmdARPDB_linux(t *testing.T) {
sh := mapShell{
"arp -a": {err: nil, out: arpAOutputWrt, code: 0},
"ip neigh": {err: nil, out: ipNeighOutput, code: 0},
}
substShell(t, sh.RunCmd)
err := a.Refresh()
require.NoError(t, err)
t.Run("wrt", func(t *testing.T) {
a := &cmdARPDB{
parse: parseArpAWrt,
cmd: "arp",
args: []string{"-a"},
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}
assert.Equal(t, wantNeighs, a.Neighbors())
}
func TestCmdARPDB_ipneigh(t *testing.T) {
a := &cmdARPDB{
parse: parseIPNeigh,
runcmd: func() (r io.Reader, err error) { return strings.NewReader(ipNeighOutput), nil },
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}
err := a.Refresh()
require.NoError(t, err)
assert.Equal(t, wantNeighs, a.Neighbors())
err := a.Refresh()
require.NoError(t, err)
assert.Equal(t, wantNeighs, a.Neighbors())
})
t.Run("ip_neigh", func(t *testing.T) {
a := &cmdARPDB{
parse: parseIPNeigh,
cmd: "ip",
args: []string{"neigh"},
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}
err := a.Refresh()
require.NoError(t, err)
assert.Equal(t, wantNeighs, a.Neighbors())
})
}

View File

@@ -12,19 +12,25 @@ import (
"github.com/AdguardTeam/golibs/log"
)
func newARPDB() *cmdARPDB {
func newARPDB() (arp *cmdARPDB) {
return &cmdARPDB{
runcmd: rcArpA,
parse: parseArpA,
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
cmd: "arp",
// Use -n flag to avoid resolving the hostnames of the neighbors. By
// default ARP attempts to resolve the hostnames via DNS. See man 8
// arp.
//
// See also https://github.com/AdguardTeam/AdGuardHome/issues/3157.
args: []string{"-a", "-n"},
}
}
// parseArpA parses the output of the "arp -a" command on OpenBSD. The expected
// input format:
// parseArpA parses the output of the "arp -a -n" command on OpenBSD. The
// expected input format:
//
// Host Ethernet Address Netif Expire Flags
// 192.168.1.1 ab:cd:ef:ab:cd:ef em0 19m59s

View File

@@ -9,6 +9,8 @@ import (
const arpAOutput = `
Host Ethernet Address Netif Expire Flags
1.2.3.4.5 aa:bb:cc:dd:ee:ff em0 permanent
1.2.3.4 12:34:56:78:910 em0 permanent
192.168.1.2 ab:cd:ef:ab:cd:ef em0 19m56s
::ffff:ffff ef:cd:ab:ef:cd:ab em0 permanent l
`

View File

@@ -1,9 +1,7 @@
package aghnet
import (
"io"
"net"
"strings"
"sync"
"testing"
@@ -13,6 +11,13 @@ import (
"github.com/stretchr/testify/require"
)
func TestNewARPDB(t *testing.T) {
var a ARPDB
require.NotPanics(t, func() { a = NewARPDB() })
assert.NotNil(t, a)
}
// TestARPDB is the mock implementation of ARPDB to use in tests.
type TestARPDB struct {
OnRefresh func() (err error)
@@ -125,11 +130,11 @@ func TestARPDBS(t *testing.T) {
assert.Equal(t, 1, succRefrCount)
assert.NotEmpty(t, a.Neighbors())
// Only the last succeeded ARPDB should be used.
// Unstable ARPDB should refresh successfully again.
err = a.Refresh()
require.NoError(t, err)
assert.Equal(t, 2, succRefrCount)
assert.Equal(t, 1, succRefrCount)
assert.NotEmpty(t, a.Neighbors())
})
@@ -143,6 +148,7 @@ func TestARPDBS(t *testing.T) {
func TestCmdARPDB_arpa(t *testing.T) {
a := &cmdARPDB{
cmd: "cmd",
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
@@ -151,7 +157,8 @@ func TestCmdARPDB_arpa(t *testing.T) {
}
t.Run("arp_a", func(t *testing.T) {
a.runcmd = func() (r io.Reader, err error) { return strings.NewReader(arpAOutput), nil }
sh := theOnlyCmd("cmd", 0, arpAOutput, nil)
substShell(t, sh.RunCmd)
err := a.Refresh()
require.NoError(t, err)
@@ -160,9 +167,50 @@ func TestCmdARPDB_arpa(t *testing.T) {
})
t.Run("runcmd_error", func(t *testing.T) {
a.runcmd = func() (r io.Reader, err error) { return nil, errors.Error("can't run") }
sh := theOnlyCmd("cmd", 0, "", errors.Error("can't run"))
substShell(t, sh.RunCmd)
err := a.Refresh()
testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err)
})
t.Run("bad_code", func(t *testing.T) {
sh := theOnlyCmd("cmd", 1, "", nil)
substShell(t, sh.RunCmd)
err := a.Refresh()
testutil.AssertErrorMsg(t, "cmd arpdb: running command: unexpected exit code 1", err)
})
t.Run("empty", func(t *testing.T) {
sh := theOnlyCmd("cmd", 0, "", nil)
substShell(t, sh.RunCmd)
err := a.Refresh()
require.NoError(t, err)
assert.Empty(t, a.Neighbors())
})
}
func TestEmptyARPDB(t *testing.T) {
a := EmptyARPDB{}
t.Run("refresh", func(t *testing.T) {
var err error
require.NotPanics(t, func() {
err = a.Refresh()
})
assert.NoError(t, err)
})
t.Run("neighbors", func(t *testing.T) {
var ns []Neighbor
require.NotPanics(t, func() {
ns = a.Neighbors()
})
assert.Empty(t, ns)
})
}

View File

@@ -1,13 +0,0 @@
//go:build !windows
// +build !windows
package aghnet
import (
"io"
)
// rcArpA runs "arp -a".
func rcArpA() (r io.Reader, err error) {
return runCmd("arp", "-a")
}

View File

@@ -5,28 +5,23 @@ package aghnet
import (
"bufio"
"io"
"net"
"strings"
"sync"
)
func newARPDB() *cmdARPDB {
func newARPDB() (arp *cmdARPDB) {
return &cmdARPDB{
runcmd: rcArpA,
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
parse: parseArpA,
cmd: "arp",
args: []string{"/a"},
}
}
// rcArpA runs "arp /a".
func rcArpA() (r io.Reader, err error) {
return runCmd("arp", "/a")
}
// parseArpA parses the output of the "arp /a" command on Windows. The expected
// input format (the first line is empty):
//

View File

@@ -156,7 +156,7 @@ func tryConn4(req *dhcpv4.DHCPv4, c net.PacketConn, iface *net.Interface) (ok, n
b := make([]byte, 1500)
n, _, err := c.ReadFrom(b)
if err != nil {
if isTimeout(err) {
if errors.Is(err, os.ErrDeadlineExceeded) {
log.Debug("dhcpv4: didn't receive dhcp response")
return false, false, nil
@@ -176,20 +176,21 @@ func tryConn4(req *dhcpv4.DHCPv4, c net.PacketConn, iface *net.Interface) (ok, n
log.Debug("dhcpv4: received message from server: %s", response.Summary())
if !(response.OpCode == dhcpv4.OpcodeBootReply &&
response.HWType == iana.HWTypeEthernet &&
bytes.Equal(response.ClientHWAddr, iface.HardwareAddr) &&
bytes.Equal(response.TransactionID[:], req.TransactionID[:]) &&
response.Options.Has(dhcpv4.OptionDHCPMessageType)) {
log.Debug("dhcpv4: received message from server doesn't match our request")
switch {
case
response.OpCode != dhcpv4.OpcodeBootReply,
response.HWType != iana.HWTypeEthernet,
!bytes.Equal(response.ClientHWAddr, iface.HardwareAddr),
response.TransactionID != req.TransactionID,
!response.Options.Has(dhcpv4.OptionDHCPMessageType):
log.Debug("dhcpv4: received response doesn't match the request")
return false, true, nil
default:
log.Tracef("dhcpv4: the packet is from an active dhcp server")
return true, false, nil
}
log.Tracef("dhcpv4: the packet is from an active dhcp server")
return true, false, nil
}
// checkOtherDHCPv6 sends a DHCP request to the specified network interface, and
@@ -275,7 +276,7 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error)
n, _, err := c.ReadFrom(b)
if err != nil {
if isTimeout(err) {
if errors.Is(err, os.ErrDeadlineExceeded) {
log.Debug("dhcpv6: didn't receive dhcp response")
return false, false, nil
@@ -318,15 +319,3 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error)
return true, false, nil
}
// isTimeout returns true if err is an operation timeout error from net package.
//
// TODO(e.burkov): Consider moving into netutil.
func isTimeout(err error) (ok bool) {
var operr *net.OpError
if errors.As(err, &operr) {
return operr.Timeout()
}
return false
}

View File

@@ -11,7 +11,7 @@ const (
ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010")
)
// generateIPv4Hostname generates the hostname for specific IP version.
// generateIPv4Hostname generates the hostname by IP address version 4.
func generateIPv4Hostname(ipv4 net.IP) (hostname string) {
hnData := make([]byte, 0, ipv4HostnameMaxLen)
for i, part := range ipv4 {
@@ -24,7 +24,7 @@ func generateIPv4Hostname(ipv4 net.IP) (hostname string) {
return string(hnData)
}
// generateIPv6Hostname generates the hostname for specific IP version.
// generateIPv6Hostname generates the hostname by IP address version 6.
func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
hnData := make([]byte, 0, ipv6HostnameMaxLen)
for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ {
@@ -51,12 +51,11 @@ func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
//
// ff80-f076-0000-0000-0000-0000-0000-0010
//
// ip must be either an IPv4 or an IPv6.
func GenerateHostname(ip net.IP) (hostname string) {
if ipv4 := ip.To4(); ipv4 != nil {
return generateIPv4Hostname(ipv4)
} else if ipv6 := ip.To16(); ipv6 != nil {
return generateIPv6Hostname(ipv6)
}
return ""
return generateIPv6Hostname(ip)
}

View File

@@ -8,41 +8,57 @@ import (
)
func TestGenerateHostName(t *testing.T) {
testCases := []struct {
name string
want string
ip net.IP
}{{
name: "good_ipv4",
want: "127-0-0-1",
ip: net.IP{127, 0, 0, 1},
}, {
name: "bad_ipv4",
want: "",
ip: net.IP{127, 0, 0, 1, 0},
}, {
name: "good_ipv6",
want: "fe00-0000-0000-0000-0000-0000-0000-0001",
ip: net.ParseIP("fe00::1"),
}, {
name: "bad_ipv6",
want: "",
ip: net.IP{
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff,
},
}, {
name: "nil",
want: "",
ip: nil,
}}
t.Run("valid", func(t *testing.T) {
testCases := []struct {
name string
want string
ip net.IP
}{{
name: "good_ipv4",
want: "127-0-0-1",
ip: net.IP{127, 0, 0, 1},
}, {
name: "good_ipv6",
want: "fe00-0000-0000-0000-0000-0000-0000-0001",
ip: net.ParseIP("fe00::1"),
}, {
name: "4to6",
want: "1-2-3-4",
ip: net.ParseIP("::ffff:1.2.3.4"),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
hostname := GenerateHostname(tc.ip)
assert.Equal(t, tc.want, hostname)
})
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
hostname := GenerateHostname(tc.ip)
assert.Equal(t, tc.want, hostname)
})
}
})
t.Run("invalid", func(t *testing.T) {
testCases := []struct {
name string
ip net.IP
}{{
name: "bad_ipv4",
ip: net.IP{127, 0, 0, 1, 0},
}, {
name: "bad_ipv6",
ip: net.IP{
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff,
},
}, {
name: "nil",
ip: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Panics(t, func() { GenerateHostname(tc.ip) })
})
}
})
}

View File

@@ -51,7 +51,7 @@ type requestMatcher struct {
//
// It's safe for concurrent use.
func (rm *requestMatcher) MatchRequest(
req urlfilter.DNSRequest,
req *urlfilter.DNSRequest,
) (res *urlfilter.DNSResult, ok bool) {
switch req.DNSType {
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
@@ -198,7 +198,7 @@ func (hc *HostsContainer) Close() (err error) {
}
// Upd returns the channel into which the updates are sent. The receivable
// map's values are guaranteed to be of type of *stringutil.Set.
// map's values are guaranteed to be of type of *HostsRecord.
func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) {
return hc.updates
}
@@ -290,7 +290,7 @@ func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err
continue
}
hp.addPairs(ip, hosts)
hp.addRecord(ip, hosts)
}
return nil, true, s.Err()
@@ -335,41 +335,68 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
return ip, hosts
}
// addPair puts the pair of ip and host to the rules builder if needed. For
// each ip the first member of hosts will become the main one.
func (hp *hostsParser) addPairs(ip net.IP, hosts []string) {
// HostsRecord represents a single hosts file record.
type HostsRecord struct {
Aliases *stringutil.Set
Canonical string
}
// Equal returns true if all fields of rec are equal to field in other or they
// both are nil.
func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) {
if rec == nil {
return other == nil
}
return rec.Canonical == other.Canonical && rec.Aliases.Equal(other.Aliases)
}
// addRecord puts the record for the IP address to the rules builder if needed.
// The first host is considered to be the canonical name for the IP address.
// hosts must have at least one name.
func (hp *hostsParser) addRecord(ip net.IP, hosts []string) {
line := strings.Join(append([]string{ip.String()}, hosts...), " ")
var rec *HostsRecord
v, ok := hp.table.Get(ip)
if !ok {
// This ip is added at the first time.
v = stringutil.NewSet()
hp.table.Set(ip, v)
rec = &HostsRecord{
Aliases: stringutil.NewSet(),
}
rec.Canonical, hosts = hosts[0], hosts[1:]
hp.addRules(ip, rec.Canonical, line)
hp.table.Set(ip, rec)
} else {
rec, ok = v.(*HostsRecord)
if !ok {
log.Error("%s: adding pairs: unexpected type %T", hostsContainerPref, v)
return
}
}
var set *stringutil.Set
set, ok = v.(*stringutil.Set)
if !ok {
log.Debug("%s: adding pairs: unexpected value type %T", hostsContainerPref, v)
return
}
processed := strings.Join(append([]string{ip.String()}, hosts...), " ")
for _, h := range hosts {
if set.Has(h) {
for _, host := range hosts {
if rec.Canonical == host || rec.Aliases.Has(host) {
continue
}
set.Add(h)
rec.Aliases.Add(host)
rule, rulePtr := hp.writeRules(h, ip)
hp.translations[rule], hp.translations[rulePtr] = processed, processed
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, h)
hp.addRules(ip, host, line)
}
}
// writeRules writes the actual rule for the qtype and the PTR for the
// host-ip pair into internal builders.
// addRules adds rules and rule translations for the line.
func (hp *hostsParser) addRules(ip net.IP, host, line string) {
rule, rulePtr := hp.writeRules(host, ip)
hp.translations[rule], hp.translations[rulePtr] = line, line
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, host)
}
// writeRules writes the actual rule for the qtype and the PTR for the host-ip
// pair into internal builders.
func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) {
arpa, err := netutil.IPToReversedAddr(ip)
if err != nil {
@@ -417,6 +444,7 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string)
}
// equalSet returns true if the internal hosts table just parsed equals target.
// target's values must be of type *HostsRecord.
func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) {
if target == nil {
// hp.table shouldn't appear nil since it's initialized on each refresh.
@@ -427,22 +455,35 @@ func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) {
return false
}
hp.table.Range(func(ip net.IP, b interface{}) (cont bool) {
// ok is set to true if the target doesn't contain ip or if the
// appropriate hosts set isn't equal to the checked one.
if a, hasIP := target.Get(ip); !hasIP {
ok = true
} else if hosts, aok := a.(*stringutil.Set); aok {
ok = !hosts.Equal(b.(*stringutil.Set))
hp.table.Range(func(ip net.IP, recVal any) (cont bool) {
var targetVal any
targetVal, ok = target.Get(ip)
if !ok {
return false
}
// Continue only if maps has no discrepancies.
return !ok
var rec *HostsRecord
rec, ok = recVal.(*HostsRecord)
if !ok {
log.Error("%s: comparing: unexpected type %T", hostsContainerPref, recVal)
return false
}
var targetRec *HostsRecord
targetRec, ok = targetVal.(*HostsRecord)
if !ok {
log.Error("%s: comparing: target: unexpected type %T", hostsContainerPref, targetVal)
return false
}
ok = rec.Equal(targetRec)
return ok
})
// Return true if every value from the IP map has no discrepancies with the
// appropriate one from the target.
return !ok
return ok
}
// sendUpd tries to send the parsed data to the ch.

View File

@@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter"
@@ -159,31 +160,47 @@ func TestHostsContainer_refresh(t *testing.T) {
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
checkRefresh := func(t *testing.T, wantHosts *stringutil.Set) {
upd, ok := <-hc.Upd()
require.True(t, ok)
require.NotNil(t, upd)
checkRefresh := func(t *testing.T, want *HostsRecord) {
t.Helper()
var ok bool
var upd *netutil.IPMap
select {
case upd, ok = <-hc.Upd():
require.True(t, ok)
require.NotNil(t, upd)
case <-time.After(1 * time.Second):
t.Fatal("did not receive after 1s")
}
assert.Equal(t, 1, upd.Len())
v, ok := upd.Get(ip)
require.True(t, ok)
var set *stringutil.Set
set, ok = v.(*stringutil.Set)
require.True(t, ok)
require.IsType(t, (*HostsRecord)(nil), v)
assert.True(t, set.Equal(wantHosts))
rec, _ := v.(*HostsRecord)
require.NotNil(t, rec)
assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want)
}
t.Run("initial_refresh", func(t *testing.T) {
checkRefresh(t, stringutil.NewSet("hostname"))
checkRefresh(t, &HostsRecord{
Aliases: stringutil.NewSet(),
Canonical: "hostname",
})
})
t.Run("second_refresh", func(t *testing.T) {
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
eventsCh <- event{}
checkRefresh(t, stringutil.NewSet("hostname", "alias"))
checkRefresh(t, &HostsRecord{
Aliases: stringutil.NewSet("alias"),
Canonical: "hostname",
})
})
t.Run("double_refresh", func(t *testing.T) {
@@ -193,7 +210,7 @@ func TestHostsContainer_refresh(t *testing.T) {
// Require the changes are written.
require.Eventually(t, func() bool {
res, ok := hc.MatchRequest(urlfilter.DNSRequest{
res, ok := hc.MatchRequest(&urlfilter.DNSRequest{
Hostname: "hostname",
DNSType: dns.TypeA,
})
@@ -207,7 +224,7 @@ func TestHostsContainer_refresh(t *testing.T) {
// Require the changes are written.
require.Eventually(t, func() bool {
res, ok := hc.MatchRequest(urlfilter.DNSRequest{
res, ok := hc.MatchRequest(&urlfilter.DNSRequest{
Hostname: "hostname",
DNSType: dns.TypeA,
})
@@ -286,6 +303,8 @@ func TestHostsContainer_Translate(t *testing.T) {
OnClose: func() (err error) { panic("not implemented") },
}
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
@@ -358,11 +377,18 @@ func TestHostsContainer_Translate(t *testing.T) {
func TestHostsContainer(t *testing.T) {
const listID = 1234
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
testCases := []struct {
want []*rules.DNSRewrite
req *urlfilter.DNSRequest
name string
req urlfilter.DNSRequest
want []*rules.DNSRewrite
}{{
req: &urlfilter.DNSRequest{
Hostname: "simplehost",
DNSType: dns.TypeA,
},
name: "simple",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 1),
@@ -372,27 +398,12 @@ func TestHostsContainer(t *testing.T) {
Value: net.ParseIP("::1"),
RRType: dns.TypeAAAA,
}},
name: "simple",
req: urlfilter.DNSRequest{
Hostname: "simplehost",
DNSType: dns.TypeA,
},
}, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 0),
RRType: dns.TypeA,
}, {
RCode: dns.RcodeSuccess,
Value: net.ParseIP("::"),
RRType: dns.TypeAAAA,
}},
name: "hello_alias",
req: urlfilter.DNSRequest{
req: &urlfilter.DNSRequest{
Hostname: "hello.world",
DNSType: dns.TypeA,
},
}, {
name: "hello_alias",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 0),
@@ -402,26 +413,41 @@ func TestHostsContainer(t *testing.T) {
Value: net.ParseIP("::"),
RRType: dns.TypeAAAA,
}},
name: "other_line_alias",
req: urlfilter.DNSRequest{
}, {
req: &urlfilter.DNSRequest{
Hostname: "hello.world.again",
DNSType: dns.TypeA,
},
name: "other_line_alias",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 0),
RRType: dns.TypeA,
}, {
RCode: dns.RcodeSuccess,
Value: net.ParseIP("::"),
RRType: dns.TypeAAAA,
}},
}, {
want: []*rules.DNSRewrite{},
name: "hello_subdomain",
req: urlfilter.DNSRequest{
req: &urlfilter.DNSRequest{
Hostname: "say.hello",
DNSType: dns.TypeA,
},
}, {
name: "hello_subdomain",
want: []*rules.DNSRewrite{},
name: "hello_alias_subdomain",
req: urlfilter.DNSRequest{
}, {
req: &urlfilter.DNSRequest{
Hostname: "say.hello.world",
DNSType: dns.TypeA,
},
name: "hello_alias_subdomain",
want: []*rules.DNSRewrite{},
}, {
req: &urlfilter.DNSRequest{
Hostname: "for.testing",
DNSType: dns.TypeA,
},
name: "lots_of_aliases",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeA,
@@ -431,37 +457,37 @@ func TestHostsContainer(t *testing.T) {
RRType: dns.TypeAAAA,
Value: net.ParseIP("::2"),
}},
name: "lots_of_aliases",
req: urlfilter.DNSRequest{
Hostname: "for.testing",
DNSType: dns.TypeA,
},
}, {
req: &urlfilter.DNSRequest{
Hostname: "1.0.0.1.in-addr.arpa",
DNSType: dns.TypePTR,
},
name: "reverse",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypePTR,
Value: "simplehost.",
}},
name: "reverse",
req: urlfilter.DNSRequest{
Hostname: "1.0.0.1.in-addr.arpa",
DNSType: dns.TypePTR,
},
}, {
want: []*rules.DNSRewrite{},
name: "non-existing",
req: urlfilter.DNSRequest{
Hostname: "nonexisting",
req: &urlfilter.DNSRequest{
Hostname: "nonexistent.example",
DNSType: dns.TypeA,
},
name: "non-existing",
want: []*rules.DNSRewrite{},
}, {
want: nil,
name: "bad_type",
req: urlfilter.DNSRequest{
req: &urlfilter.DNSRequest{
Hostname: "1.0.0.1.in-addr.arpa",
DNSType: dns.TypeSRV,
},
name: "bad_type",
want: nil,
}, {
req: &urlfilter.DNSRequest{
Hostname: "domain",
DNSType: dns.TypeA,
},
name: "issue_4216_4_6",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeA,
@@ -471,12 +497,12 @@ func TestHostsContainer(t *testing.T) {
RRType: dns.TypeAAAA,
Value: net.ParseIP("::42"),
}},
name: "issue_4216_4_6",
req: urlfilter.DNSRequest{
Hostname: "domain",
}, {
req: &urlfilter.DNSRequest{
Hostname: "domain4",
DNSType: dns.TypeA,
},
}, {
name: "issue_4216_4",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeA,
@@ -486,12 +512,12 @@ func TestHostsContainer(t *testing.T) {
RRType: dns.TypeA,
Value: net.IPv4(1, 3, 5, 7),
}},
name: "issue_4216_4",
req: urlfilter.DNSRequest{
Hostname: "domain4",
DNSType: dns.TypeA,
},
}, {
req: &urlfilter.DNSRequest{
Hostname: "domain6",
DNSType: dns.TypeAAAA,
},
name: "issue_4216_6",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeAAAA,
@@ -501,11 +527,6 @@ func TestHostsContainer(t *testing.T) {
RRType: dns.TypeAAAA,
Value: net.ParseIP("::31"),
}},
name: "issue_4216_6",
req: urlfilter.DNSRequest{
Hostname: "domain6",
DNSType: dns.TypeAAAA,
},
}}
stubWatcher := aghtest.FSWatcher{

View File

@@ -2,19 +2,31 @@
package aghnet
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net"
"os/exec"
"strings"
"syscall"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
)
// Variables and functions to substitute in tests.
var (
// aghosRunCommand is the function to run shell commands.
aghosRunCommand = aghos.RunCommand
// netInterfaces is the function to get the available network interfaces.
netInterfaceAddrs = net.InterfaceAddrs
// rootDirFS is the filesystem pointing to the root directory.
rootDirFS = aghos.RootDirFS()
)
// ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about
// the IP being static is available.
const ErrNoStaticIPInfo errors.Error = "no information about static ip"
@@ -32,39 +44,29 @@ func IfaceSetStaticIP(ifaceName string) (err error) {
}
// GatewayIP returns IP address of interface's gateway.
func GatewayIP(ifaceName string) net.IP {
cmd := exec.Command("ip", "route", "show", "dev", ifaceName)
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
d, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
//
// TODO(e.burkov): Investigate if the gateway address may be fetched in another
// way since not every machine has the software installed.
func GatewayIP(ifaceName string) (ip net.IP) {
code, out, err := aghosRunCommand("ip", "route", "show", "dev", ifaceName)
if err != nil {
log.Debug("%s", err)
return nil
} else if code != 0 {
log.Debug("fetching gateway ip: unexpected exit code: %d", code)
return nil
}
fields := strings.Fields(string(d))
fields := bytes.Fields(out)
// The meaningful "ip route" command output should contain the word
// "default" at first field and default gateway IP address at third field.
if len(fields) < 3 || fields[0] != "default" {
if len(fields) < 3 || string(fields[0]) != "default" {
return nil
}
return net.ParseIP(fields[2])
}
// CanBindPort checks if we can bind to the given port.
func CanBindPort(port int) (can bool, err error) {
var addr *net.TCPAddr
addr, err = net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return false, err
}
var listener *net.TCPListener
listener, err = net.ListenTCP("tcp", addr)
if err != nil {
return false, err
}
_ = listener.Close()
return true, nil
return net.ParseIP(string(fields[2]))
}
// CanBindPrivilegedPorts checks if current process can bind to privileged
@@ -99,19 +101,19 @@ func (iface NetInterface) MarshalJSON() ([]byte, error) {
})
}
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and WEB only
// we do not return link-local addresses here
func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and
// WEB only we do not return link-local addresses here.
//
// TODO(e.burkov): Can't properly test the function since it's nontrivial to
// substitute net.Interface.Addrs and the net.InterfaceAddrs can't be used.
func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("couldn't get interfaces: %w", err)
}
if len(ifaces) == 0 {
} else if len(ifaces) == 0 {
return nil, errors.Error("couldn't find any legible interface")
}
var netInterfaces []*NetInterface
for _, iface := range ifaces {
var addrs []net.Addr
addrs, err = iface.Addrs()
@@ -131,27 +133,34 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
ipNet, ok := addr.(*net.IPNet)
if !ok {
// Should be net.IPNet, this is weird.
return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
return nil, fmt.Errorf("got %s that is not net.IPNet, it is %T", addr, addr)
}
// Ignore link-local.
if ipNet.IP.IsLinkLocalUnicast() {
continue
}
netIface.Addresses = append(netIface.Addresses, ipNet.IP)
netIface.Subnets = append(netIface.Subnets, ipNet)
}
// Discard interfaces with no addresses.
if len(netIface.Addresses) != 0 {
netInterfaces = append(netInterfaces, netIface)
netIfaces = append(netIfaces, netIface)
}
}
return netInterfaces, nil
return netIfaces, nil
}
// GetInterfaceByIP returns the name of interface containing provided ip.
func GetInterfaceByIP(ip net.IP) string {
// InterfaceByIP returns the name of the interface bound to ip.
//
// TODO(a.garipov, e.burkov): This function is technically incorrect, since one
// IP address can be shared by multiple interfaces in some configurations.
//
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func InterfaceByIP(ip net.IP) (ifaceName string) {
ifaces, err := GetValidNetInterfacesForWeb()
if err != nil {
return ""
@@ -170,6 +179,8 @@ func GetInterfaceByIP(ip net.IP) string {
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if
// the search fails.
//
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func GetSubnet(ifaceName string) *net.IPNet {
netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil {
@@ -220,29 +231,21 @@ func IsAddrInUse(err error) (ok bool) {
// CollectAllIfacesAddrs returns the slice of all network interfaces IP
// addresses without port number.
func CollectAllIfacesAddrs() (addrs []string, err error) {
var ifaces []net.Interface
ifaces, err = net.Interfaces()
var ifaceAddrs []net.Addr
ifaceAddrs, err = netInterfaceAddrs()
if err != nil {
return nil, fmt.Errorf("getting network interfaces: %w", err)
return nil, fmt.Errorf("getting interfaces addresses: %w", err)
}
for _, iface := range ifaces {
var ifaceAddrs []net.Addr
ifaceAddrs, err = iface.Addrs()
for _, addr := range ifaceAddrs {
cidr := addr.String()
var ip net.IP
ip, _, err = net.ParseCIDR(cidr)
if err != nil {
return nil, fmt.Errorf("getting addresses for %q: %w", iface.Name, err)
return nil, fmt.Errorf("parsing cidr: %w", err)
}
for _, addr := range ifaceAddrs {
cidr := addr.String()
var ip net.IP
ip, _, err = net.ParseCIDR(cidr)
if err != nil {
return nil, fmt.Errorf("parsing cidr: %w", err)
}
addrs = append(addrs, ip.String())
}
addrs = append(addrs, ip.String())
}
return addrs, nil

View File

@@ -0,0 +1,10 @@
//go:build darwin || freebsd || openbsd
// +build darwin freebsd openbsd
package aghnet
import "github.com/AdguardTeam/AdGuardHome/internal/aghos"
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}

View File

@@ -4,10 +4,11 @@
package aghnet
import (
"bufio"
"bytes"
"fmt"
"os"
"io"
"regexp"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
@@ -23,11 +24,7 @@ type hardwarePortInfo struct {
static bool
}
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(ifaceName string) (bool, error) {
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
portInfo, err := getCurrentHardwarePortInfo(ifaceName)
if err != nil {
return false, err
@@ -36,9 +33,10 @@ func ifaceHasStaticIP(ifaceName string) (bool, error) {
return portInfo.static, nil
}
// getCurrentHardwarePortInfo gets information for the specified network interface.
// getCurrentHardwarePortInfo gets information for the specified network
// interface.
func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
// First of all we should find hardware port name
// First of all we should find hardware port name.
m := getNetworkSetupHardwareReports()
hardwarePort, ok := m[ifaceName]
if !ok {
@@ -48,6 +46,10 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
return getHardwarePortInfo(hardwarePort)
}
// hardwareReportsReg is the regular expression matching the lines of
// networksetup command output lines containing the interface information.
var hardwareReportsReg = regexp.MustCompile("Hardware Port: (.*?)\nDevice: (.*?)\n")
// getNetworkSetupHardwareReports parses the output of the `networksetup
// -listallhardwareports` command it returns a map where the key is the
// interface name, and the value is the "hardware port" returns nil if it fails
@@ -56,54 +58,44 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
// TODO(e.burkov): There should be more proper approach than parsing the
// command output. For example, see
// https://developer.apple.com/documentation/systemconfiguration.
func getNetworkSetupHardwareReports() map[string]string {
_, out, err := aghos.RunCommand("networksetup", "-listallhardwareports")
func getNetworkSetupHardwareReports() (reports map[string]string) {
_, out, err := aghosRunCommand("networksetup", "-listallhardwareports")
if err != nil {
return nil
}
re, err := regexp.Compile("Hardware Port: (.*?)\nDevice: (.*?)\n")
if err != nil {
return nil
reports = make(map[string]string)
matches := hardwareReportsReg.FindAllSubmatch(out, -1)
for _, m := range matches {
reports[string(m[2])] = string(m[1])
}
m := make(map[string]string)
matches := re.FindAllStringSubmatch(out, -1)
for i := range matches {
port := matches[i][1]
device := matches[i][2]
m[device] = port
}
return m
return reports
}
func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
h := hardwarePortInfo{}
// hardwarePortReg is the regular expression matching the lines of networksetup
// command output lines containing the port information.
var hardwarePortReg = regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n")
_, out, err := aghos.RunCommand("networksetup", "-getinfo", hardwarePort)
func getHardwarePortInfo(hardwarePort string) (h hardwarePortInfo, err error) {
_, out, err := aghosRunCommand("networksetup", "-getinfo", hardwarePort)
if err != nil {
return h, err
}
re := regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n")
match := re.FindStringSubmatch(out)
if len(match) == 0 {
match := hardwarePortReg.FindSubmatch(out)
if len(match) != 4 {
return h, errors.Error("could not find hardware port info")
}
h.name = hardwarePort
h.ip = match[1]
h.subnet = match[2]
h.gatewayIP = match[3]
if strings.Index(out, "Manual Configuration") == 0 {
h.static = true
}
return h, nil
return hardwarePortInfo{
name: hardwarePort,
ip: string(match[1]),
subnet: string(match[2]),
gatewayIP: string(match[3]),
static: bytes.Index(out, []byte("Manual Configuration")) == 0,
}, nil
}
func ifaceSetStaticIP(ifaceName string) (err error) {
@@ -113,7 +105,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
}
if portInfo.static {
return errors.Error("IP address is already static")
return errors.Error("ip address is already static")
}
dnsAddrs, err := getEtcResolvConfServers()
@@ -121,50 +113,62 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
return err
}
args := make([]string, 0)
args = append(args, "-setdnsservers", portInfo.name)
args = append(args, dnsAddrs...)
args := append([]string{"-setdnsservers", portInfo.name}, dnsAddrs...)
// Setting DNS servers is necessary when configuring a static IP
code, _, err := aghos.RunCommand("networksetup", args...)
code, _, err := aghosRunCommand("networksetup", args...)
if err != nil {
return err
}
if code != 0 {
} else if code != 0 {
return fmt.Errorf("failed to set DNS servers, code=%d", code)
}
// Actually configures hardware port to have static IP
code, _, err = aghos.RunCommand("networksetup", "-setmanual",
portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP)
code, _, err = aghosRunCommand(
"networksetup",
"-setmanual",
portInfo.name,
portInfo.ip,
portInfo.subnet,
portInfo.gatewayIP,
)
if err != nil {
return err
}
if code != 0 {
} else if code != 0 {
return fmt.Errorf("failed to set DNS servers, code=%d", code)
}
return nil
}
// etcResolvConfReg is the regular expression matching the lines of resolv.conf
// file containing a name server information.
var etcResolvConfReg = regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)")
// getEtcResolvConfServers returns a list of nameservers configured in
// /etc/resolv.conf.
func getEtcResolvConfServers() ([]string, error) {
body, err := os.ReadFile("/etc/resolv.conf")
func getEtcResolvConfServers() (addrs []string, err error) {
const filename = "etc/resolv.conf"
_, err = aghos.FileWalker(func(r io.Reader) (_ []string, _ bool, err error) {
sc := bufio.NewScanner(r)
for sc.Scan() {
matches := etcResolvConfReg.FindAllStringSubmatch(sc.Text(), -1)
if len(matches) == 0 {
continue
}
for _, m := range matches {
addrs = append(addrs, m[1])
}
}
return nil, false, sc.Err()
}).Walk(rootDirFS, filename)
if err != nil {
return nil, err
}
re := regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)")
matches := re.FindAllStringSubmatch(string(body), -1)
if len(matches) == 0 {
return nil, errors.Error("found no DNS servers in /etc/resolv.conf")
}
addrs := make([]string, 0)
for i := range matches {
addrs = append(addrs, matches[i][1])
return nil, fmt.Errorf("parsing etc/resolv.conf file: %w", err)
} else if len(addrs) == 0 {
return nil, fmt.Errorf("found no dns servers in %s", filename)
}
return addrs, nil

View File

@@ -0,0 +1,261 @@
package aghnet
import (
"io/fs"
"testing"
"testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
)
func TestIfaceHasStaticIP(t *testing.T) {
testCases := []struct {
name string
shell mapShell
ifaceName string
wantHas assert.BoolAssertionFunc
wantErrMsg string
}{{
name: "success",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
},
ifaceName: "en0",
wantHas: assert.False,
wantErrMsg: ``,
}, {
name: "success_static",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "Manual Configuration\nIP address: 1.2.3.4\n" +
"Subnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
},
ifaceName: "en0",
wantHas: assert.True,
wantErrMsg: ``,
}, {
name: "reports_error",
shell: theOnlyCmd(
"networksetup -listallhardwareports",
0,
"",
errors.Error("can't list"),
),
ifaceName: "en0",
wantHas: assert.False,
wantErrMsg: `could not find hardware port for en0`,
}, {
name: "port_error",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: errors.Error("can't get"),
out: ``,
code: 0,
},
},
ifaceName: "en0",
wantHas: assert.False,
wantErrMsg: `can't get`,
}, {
name: "port_bad_output",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "nothing meaningful",
code: 0,
},
},
ifaceName: "en0",
wantHas: assert.False,
wantErrMsg: `could not find hardware port info`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
has, err := IfaceHasStaticIP(tc.ifaceName)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
tc.wantHas(t, has)
})
}
}
func TestIfaceSetStaticIP(t *testing.T) {
succFsys := fstest.MapFS{
"etc/resolv.conf": &fstest.MapFile{
Data: []byte(`nameserver 1.1.1.1`),
},
}
panicFsys := &aghtest.FS{
OnOpen: func(name string) (fs.File, error) { panic("not implemented") },
}
testCases := []struct {
name string
shell mapShell
fsys fs.FS
wantErrMsg string
}{{
name: "success",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
"networksetup -setdnsservers hwport 1.1.1.1": {
err: nil,
out: "",
code: 0,
},
"networksetup -setmanual hwport 1.2.3.4 255.255.255.0 1.2.3.1": {
err: nil,
out: "",
code: 0,
},
},
fsys: succFsys,
wantErrMsg: ``,
}, {
name: "static_already",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "Manual Configuration\nIP address: 1.2.3.4\n" +
"Subnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
},
fsys: panicFsys,
wantErrMsg: `ip address is already static`,
}, {
name: "reports_error",
shell: theOnlyCmd(
"networksetup -listallhardwareports",
0,
"",
errors.Error("can't list"),
),
fsys: panicFsys,
wantErrMsg: `could not find hardware port for en0`,
}, {
name: "resolv_conf_error",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
},
fsys: fstest.MapFS{
"etc/resolv.conf": &fstest.MapFile{
Data: []byte("this resolv.conf is invalid"),
},
},
wantErrMsg: `found no dns servers in etc/resolv.conf`,
}, {
name: "set_dns_error",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
"networksetup -setdnsservers hwport 1.1.1.1": {
err: errors.Error("can't set"),
out: "",
code: 0,
},
},
fsys: succFsys,
wantErrMsg: `can't set`,
}, {
name: "set_manual_error",
shell: mapShell{
"networksetup -listallhardwareports": {
err: nil,
out: "Hardware Port: hwport\nDevice: en0\n",
code: 0,
},
"networksetup -getinfo hwport": {
err: nil,
out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n",
code: 0,
},
"networksetup -setdnsservers hwport 1.1.1.1": {
err: nil,
out: "",
code: 0,
},
"networksetup -setmanual hwport 1.2.3.4 255.255.255.0 1.2.3.1": {
err: errors.Error("can't set"),
out: "",
code: 0,
},
},
fsys: succFsys,
wantErrMsg: `can't set`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
substRootDirFS(t, tc.fsys)
err := IfaceSetStaticIP("en0")
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}

View File

@@ -13,16 +13,12 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
const rcConfFilename = "etc/rc.conf"
walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig)
return walker.Walk(aghos.RootDirFS(), rcConfFilename)
return walker.Walk(rootDirFS, rcConfFilename)
}
// rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to

View File

@@ -4,56 +4,74 @@
package aghnet
import (
"strings"
"io/fs"
"testing"
"testing/fstest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRcConfStaticConfig(t *testing.T) {
const iface interfaceName = `em0`
const nl = "\n"
func TestIfaceHasStaticIP(t *testing.T) {
const (
ifaceName = `em0`
rcConf = "etc/rc.conf"
)
testCases := []struct {
name string
rcconfData string
wantCont bool
name string
rootFsys fs.FS
wantHas assert.BoolAssertionFunc
}{{
name: "simple",
rcconfData: `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl,
wantCont: false,
name: "simple",
rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
Data: []byte(`ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl),
}},
wantHas: assert.True,
}, {
name: "case_insensitiveness",
rcconfData: `ifconfig_em0="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl,
wantCont: false,
name: "case_insensitiveness",
rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
Data: []byte(`ifconfig_` + ifaceName + `="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl),
}},
wantHas: assert.True,
}, {
name: "comments_and_trash",
rcconfData: `# comment 1` + nl +
`` + nl +
`# comment 2` + nl +
`ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl,
wantCont: false,
rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
Data: []byte(`# comment 1` + nl +
`` + nl +
`# comment 2` + nl +
`ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl,
),
}},
wantHas: assert.True,
}, {
name: "aliases",
rcconfData: `ifconfig_em0_alias="inet 127.0.0.1/24"` + nl +
`ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl,
wantCont: false,
rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
Data: []byte(`ifconfig_` + ifaceName + `_alias="inet 127.0.0.1/24"` + nl +
`ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl,
),
}},
wantHas: assert.True,
}, {
name: "incorrect_config",
rcconfData: `ifconfig_em0="inet6 127.0.0.253 netmask 0xffffffff"` + nl +
`ifconfig_em0="inet 256.256.256.256 netmask 0xffffffff"` + nl +
`ifconfig_em0=""` + nl,
wantCont: true,
rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{
Data: []byte(
`ifconfig_` + ifaceName + `="inet6 127.0.0.253 netmask 0xffffffff"` + nl +
`ifconfig_` + ifaceName + `="inet 256.256.256.256 netmask 0xffffffff"` + nl +
`ifconfig_` + ifaceName + `=""` + nl,
),
}},
wantHas: assert.False,
}}
for _, tc := range testCases {
r := strings.NewReader(tc.rcconfData)
t.Run(tc.name, func(t *testing.T) {
_, cont, err := iface.rcConfStaticConfig(r)
substRootDirFS(t, tc.rootFsys)
has, err := IfaceHasStaticIP(ifaceName)
require.NoError(t, err)
assert.Equal(t, tc.wantCont, cont)
tc.wantHas(t, has)
})
}
}

View File

@@ -13,16 +13,44 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/renameio/maybe"
"golang.org/x/sys/unix"
)
// dhcpсdConf is the name of /etc/dhcpcd.conf file in the root filesystem.
const dhcpcdConf = "etc/dhcpcd.conf"
func canBindPrivilegedPorts() (can bool, err error) {
res, err := unix.PrctlRetInt(
unix.PR_CAP_AMBIENT,
unix.PR_CAP_AMBIENT_IS_SET,
unix.CAP_NET_BIND_SERVICE,
0,
0,
)
if err != nil {
if errors.Is(err, unix.EINVAL) {
// Older versions of Linux kernel do not support this. Print a
// warning and check admin rights.
log.Info("warning: cannot check capability cap_net_bind_service: %s", err)
} else {
return false, err
}
}
// Don't check the error because it's always nil on Linux.
adm, _ := aghos.HaveAdminRights()
return res == 1 || adm, nil
}
// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
// have a static IP.
func (n interfaceName) dhcpcdStaticConfig(r io.Reader) (subsources []string, cont bool, err error) {
s := bufio.NewScanner(r)
ifaceFound := findIfaceLine(s, string(n))
if !ifaceFound {
if !findIfaceLine(s, string(n)) {
return nil, true, s.Err()
}
@@ -61,9 +89,9 @@ func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool,
fields := strings.Fields(line)
fieldsNum := len(fields)
// Man page interfaces(5) declares that interface definition
// should consist of the key word "iface" followed by interface
// name, and method at fourth field.
// Man page interfaces(5) declares that interface definition should
// consist of the key word "iface" followed by interface name, and
// method at fourth field.
if fieldsNum >= 4 &&
fields[0] == "iface" && fields[1] == string(n) && fields[3] == "static" {
return nil, false, nil
@@ -78,10 +106,10 @@ func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool,
}
func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
// TODO(a.garipov): Currently, this function returns the first
// definitive result. So if /etc/dhcpcd.conf has a static IP while
// /etc/network/interfaces doesn't, it will return true. Perhaps this
// is not the most desirable behavior.
// TODO(a.garipov): Currently, this function returns the first definitive
// result. So if /etc/dhcpcd.conf has and /etc/network/interfaces has no
// static IP configuration, it will return true. Perhaps this is not the
// most desirable behavior.
iface := interfaceName(ifaceName)
@@ -90,17 +118,15 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
filename string
}{{
FileWalker: iface.dhcpcdStaticConfig,
filename: "etc/dhcpcd.conf",
filename: dhcpcdConf,
}, {
FileWalker: iface.ifacesStaticConfig,
filename: "etc/network/interfaces",
}} {
has, err = pair.Walk(aghos.RootDirFS(), pair.filename)
has, err = pair.Walk(rootDirFS, pair.filename)
if err != nil {
return false, err
}
if has {
} else if has {
return true, nil
}
}
@@ -108,14 +134,6 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
return false, ErrNoStaticIPInfo
}
func canBindPrivilegedPorts() (can bool, err error) {
cnbs, err := unix.PrctlRetInt(unix.PR_CAP_AMBIENT, unix.PR_CAP_AMBIENT_IS_SET, unix.CAP_NET_BIND_SERVICE, 0, 0)
// Don't check the error because it's always nil on Linux.
adm, _ := aghos.HaveAdminRights()
return cnbs == 1 || adm, err
}
// findIfaceLine scans s until it finds the line that declares an interface with
// the given name. If findIfaceLine can't find the line, it returns false.
func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
@@ -131,23 +149,23 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
}
// ifaceSetStaticIP configures the system to retain its current IP on the
// interface through dhcpdc.conf.
// interface through dhcpcd.conf.
func ifaceSetStaticIP(ifaceName string) (err error) {
ipNet := GetSubnet(ifaceName)
if ipNet.IP == nil {
return errors.Error("can't get IP address")
}
gatewayIP := GatewayIP(ifaceName)
add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP, ipNet.IP)
body, err := os.ReadFile("/etc/dhcpcd.conf")
body, err := os.ReadFile(dhcpcdConf)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
gatewayIP := GatewayIP(ifaceName)
add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP)
body = append(body, []byte(add)...)
err = maybe.WriteFile("/etc/dhcpcd.conf", body, 0o644)
err = maybe.WriteFile(dhcpcdConf, body, 0o644)
if err != nil {
return fmt.Errorf("writing conf: %w", err)
}
@@ -157,22 +175,24 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
// dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that
// configure the interface to have a static IP.
func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gatewayIP, dnsIP net.IP) (conf string) {
var body []byte
add := fmt.Sprintf(
"\n# %[1]s added by AdGuard Home.\ninterface %[1]s\nstatic ip_address=%s\n",
func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf string) {
b := &strings.Builder{}
stringutil.WriteToBuilder(
b,
"\n# ",
ifaceName,
ipNet)
body = append(body, []byte(add)...)
" added by AdGuard Home.\ninterface ",
ifaceName,
"\nstatic ip_address=",
ipNet.String(),
"\n",
)
if gatewayIP != nil {
add = fmt.Sprintf("static routers=%s\n", gatewayIP)
body = append(body, []byte(add)...)
if gwIP != nil {
stringutil.WriteToBuilder(b, "static routers=", gwIP.String(), "\n")
}
add = fmt.Sprintf("static domain_name_servers=%s\n\n", dnsIP)
body = append(body, []byte(add)...)
stringutil.WriteToBuilder(b, "static domain_name_servers=", ipNet.IP.String(), "\n\n")
return string(body)
return b.String()
}

View File

@@ -4,152 +4,124 @@
package aghnet
import (
"bytes"
"net"
"io/fs"
"testing"
"testing/fstest"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDHCPCDStaticConfig(t *testing.T) {
const iface interfaceName = `wlan0`
func TestHasStaticIP(t *testing.T) {
const ifaceName = "wlan0"
const (
dhcpcd = "etc/dhcpcd.conf"
netifaces = "etc/network/interfaces"
)
testCases := []struct {
name string
data []byte
wantCont bool
}{{
name: "has_not",
data: []byte(`#comment` + nl +
`# comment` + nl +
`interface eth0` + nl +
`static ip_address=192.168.0.1/24` + nl +
`# interface ` + iface + nl +
`static ip_address=192.168.1.1/24` + nl +
`# comment` + nl,
),
wantCont: true,
}, {
name: "has",
data: []byte(`#comment` + nl +
`# comment` + nl +
`interface eth0` + nl +
`static ip_address=192.168.0.1/24` + nl +
`# interface ` + iface + nl +
`static ip_address=192.168.1.1/24` + nl +
`# comment` + nl +
`interface ` + iface + nl +
`# comment` + nl +
`static ip_address=192.168.2.1/24` + nl,
),
wantCont: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := bytes.NewReader(tc.data)
_, cont, err := iface.dhcpcdStaticConfig(r)
require.NoError(t, err)
assert.Equal(t, tc.wantCont, cont)
})
}
}
func TestIfacesStaticConfig(t *testing.T) {
const iface interfaceName = `enp0s3`
testCases := []struct {
name string
data []byte
wantCont bool
wantPatterns []string
}{{
name: "has_not",
data: []byte(`allow-hotplug ` + iface + nl +
`#iface enp0s3 inet static` + nl +
`# address 192.168.0.200` + nl +
`# netmask 255.255.255.0` + nl +
`# gateway 192.168.0.1` + nl +
`iface ` + iface + ` inet dhcp` + nl,
),
wantCont: true,
wantPatterns: []string{},
}, {
name: "has",
data: []byte(`allow-hotplug ` + iface + nl +
`iface ` + iface + ` inet static` + nl +
` address 192.168.0.200` + nl +
` netmask 255.255.255.0` + nl +
` gateway 192.168.0.1` + nl +
`#iface ` + iface + ` inet dhcp` + nl,
),
wantCont: false,
wantPatterns: []string{},
}, {
name: "return_patterns",
data: []byte(`source hello` + nl +
`source world` + nl +
`#iface ` + iface + ` inet static` + nl,
),
wantCont: true,
wantPatterns: []string{"hello", "world"},
}, {
// This one tests if the first found valid interface prevents
// checking files under the `source` directive.
name: "ignore_patterns",
data: []byte(`source hello` + nl +
`source world` + nl +
`iface ` + iface + ` inet static` + nl,
),
wantCont: false,
wantPatterns: []string{},
}}
for _, tc := range testCases {
r := bytes.NewReader(tc.data)
t.Run(tc.name, func(t *testing.T) {
patterns, has, err := iface.ifacesStaticConfig(r)
require.NoError(t, err)
assert.Equal(t, tc.wantCont, has)
assert.ElementsMatch(t, tc.wantPatterns, patterns)
})
}
}
func TestSetStaticIPdhcpcdConf(t *testing.T) {
testCases := []struct {
rootFsys fs.FS
name string
dhcpcdConf string
routers net.IP
wantHas assert.BoolAssertionFunc
wantErrMsg string
}{{
name: "with_gateway",
dhcpcdConf: nl + `# wlan0 added by AdGuard Home.` + nl +
`interface wlan0` + nl +
`static ip_address=192.168.0.2/24` + nl +
`static routers=192.168.0.1` + nl +
`static domain_name_servers=192.168.0.2` + nl + nl,
routers: net.IP{192, 168, 0, 1},
rootFsys: fstest.MapFS{
dhcpcd: &fstest.MapFile{
Data: []byte(`#comment` + nl +
`# comment` + nl +
`interface eth0` + nl +
`static ip_address=192.168.0.1/24` + nl +
`# interface ` + ifaceName + nl +
`static ip_address=192.168.1.1/24` + nl +
`# comment` + nl,
),
},
},
name: "dhcpcd_has_not",
wantHas: assert.False,
wantErrMsg: `no information about static ip`,
}, {
name: "without_gateway",
dhcpcdConf: nl + `# wlan0 added by AdGuard Home.` + nl +
`interface wlan0` + nl +
`static ip_address=192.168.0.2/24` + nl +
`static domain_name_servers=192.168.0.2` + nl + nl,
routers: nil,
rootFsys: fstest.MapFS{
dhcpcd: &fstest.MapFile{
Data: []byte(`#comment` + nl +
`# comment` + nl +
`interface ` + ifaceName + nl +
`static ip_address=192.168.0.1/24` + nl +
`# interface ` + ifaceName + nl +
`static ip_address=192.168.1.1/24` + nl +
`# comment` + nl,
),
},
},
name: "dhcpcd_has",
wantHas: assert.True,
wantErrMsg: ``,
}, {
rootFsys: fstest.MapFS{
netifaces: &fstest.MapFile{
Data: []byte(`allow-hotplug ` + ifaceName + nl +
`#iface enp0s3 inet static` + nl +
`# address 192.168.0.200` + nl +
`# netmask 255.255.255.0` + nl +
`# gateway 192.168.0.1` + nl +
`iface ` + ifaceName + ` inet dhcp` + nl,
),
},
},
name: "netifaces_has_not",
wantHas: assert.False,
wantErrMsg: `no information about static ip`,
}, {
rootFsys: fstest.MapFS{
netifaces: &fstest.MapFile{
Data: []byte(`allow-hotplug ` + ifaceName + nl +
`iface ` + ifaceName + ` inet static` + nl +
` address 192.168.0.200` + nl +
` netmask 255.255.255.0` + nl +
` gateway 192.168.0.1` + nl +
`#iface ` + ifaceName + ` inet dhcp` + nl,
),
},
},
name: "netifaces_has",
wantHas: assert.True,
wantErrMsg: ``,
}, {
rootFsys: fstest.MapFS{
netifaces: &fstest.MapFile{
Data: []byte(`source hello` + nl +
`#iface ` + ifaceName + ` inet static` + nl,
),
},
"hello": &fstest.MapFile{
Data: []byte(`iface ` + ifaceName + ` inet static` + nl),
},
},
name: "netifaces_another_file",
wantHas: assert.True,
wantErrMsg: ``,
}, {
rootFsys: fstest.MapFS{
netifaces: &fstest.MapFile{
Data: []byte(`source hello` + nl +
`iface ` + ifaceName + ` inet static` + nl,
),
},
},
name: "netifaces_ignore_another",
wantHas: assert.True,
wantErrMsg: ``,
}}
ipNet := &net.IPNet{
IP: net.IP{192, 168, 0, 2},
Mask: net.IPMask{255, 255, 255, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := dhcpcdConfIface("wlan0", ipNet, tc.routers, net.IP{192, 168, 0, 2})
assert.Equal(t, tc.dhcpcdConf, s)
substRootDirFS(t, tc.rootFsys)
has, err := IfaceHasStaticIP(ifaceName)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
tc.wantHas(t, has)
})
}
}

View File

@@ -13,14 +13,10 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
filename := fmt.Sprintf("etc/hostname.%s", ifaceName)
return aghos.FileWalker(hostnameIfStaticConfig).Walk(aghos.RootDirFS(), filename)
return aghos.FileWalker(hostnameIfStaticConfig).Walk(rootDirFS, filename)
}
// hostnameIfStaticConfig checks if the interface is configured by

View File

@@ -4,49 +4,69 @@
package aghnet
import (
"strings"
"fmt"
"io/fs"
"testing"
"testing/fstest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHostnameIfStaticConfig(t *testing.T) {
const nl = "\n"
func TestIfaceHasStaticIP(t *testing.T) {
const ifaceName = "em0"
confFile := fmt.Sprintf("etc/hostname.%s", ifaceName)
testCases := []struct {
name string
rcconfData string
wantHas bool
name string
rootFsys fs.FS
wantHas assert.BoolAssertionFunc
}{{
name: "simple",
rcconfData: `inet 127.0.0.253` + nl,
wantHas: true,
name: "simple",
rootFsys: fstest.MapFS{
confFile: &fstest.MapFile{
Data: []byte(`inet 127.0.0.253` + nl),
},
},
wantHas: assert.True,
}, {
name: "case_sensitiveness",
rcconfData: `InEt 127.0.0.253` + nl,
wantHas: false,
name: "case_sensitiveness",
rootFsys: fstest.MapFS{
confFile: &fstest.MapFile{
Data: []byte(`InEt 127.0.0.253` + nl),
},
},
wantHas: assert.False,
}, {
name: "comments_and_trash",
rcconfData: `# comment 1` + nl +
`` + nl +
`# inet 127.0.0.253` + nl +
`inet` + nl,
wantHas: false,
rootFsys: fstest.MapFS{
confFile: &fstest.MapFile{
Data: []byte(`# comment 1` + nl + nl +
`# inet 127.0.0.253` + nl +
`inet` + nl,
),
},
},
wantHas: assert.False,
}, {
name: "incorrect_config",
rcconfData: `inet6 127.0.0.253` + nl +
`inet 256.256.256.256` + nl,
wantHas: false,
rootFsys: fstest.MapFS{
confFile: &fstest.MapFile{
Data: []byte(`inet6 127.0.0.253` + nl + `inet 256.256.256.256` + nl),
},
},
wantHas: assert.False,
}}
for _, tc := range testCases {
r := strings.NewReader(tc.rcconfData)
t.Run(tc.name, func(t *testing.T) {
_, has, err := hostnameIfStaticConfig(r)
substRootDirFS(t, tc.rootFsys)
has, err := IfaceHasStaticIP(ifaceName)
require.NoError(t, err)
assert.Equal(t, tc.wantHas, has)
tc.wantHas(t, has)
})
}
}

View File

@@ -1,26 +1,138 @@
package aghnet
import (
"bytes"
"encoding/json"
"fmt"
"io/fs"
"net"
"os"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata")
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
func TestGetInterfaceByIP(t *testing.T) {
// testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata")
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
// package with fsys for tests ran under t.
func substRootDirFS(t testing.TB, fsys fs.FS) {
t.Helper()
prev := rootDirFS
t.Cleanup(func() { rootDirFS = prev })
rootDirFS = fsys
}
// RunCmdFunc is the signature of aghos.RunCommand function.
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
// substShell replaces the the aghos.RunCommand function used throughout the
// package with rc for tests ran under t.
func substShell(t testing.TB, rc RunCmdFunc) {
t.Helper()
prev := aghosRunCommand
t.Cleanup(func() { aghosRunCommand = prev })
aghosRunCommand = rc
}
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
// execution result. It's only needed to simplify testing.
//
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
type mapShell map[string]struct {
err error
out string
code int
}
// theOnlyCmd returns mapShell that only handles a single command and arguments
// combination from cmd.
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
return mapShell{cmd: {code: code, out: out, err: err}}
}
// RunCmd is a RunCmdFunc handled by s.
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
key := strings.Join(append([]string{cmd}, args...), " ")
ret, ok := s[key]
if !ok {
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
}
return ret.code, []byte(ret.out), ret.err
}
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
// throughout the package with f for tests ran under t.
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
t.Helper()
prev := netInterfaceAddrs
t.Cleanup(func() { netInterfaceAddrs = prev })
netInterfaceAddrs = f
}
func TestGatewayIP(t *testing.T) {
const ifaceName = "ifaceName"
const cmd = "ip route show dev " + ifaceName
testCases := []struct {
name string
shell mapShell
want net.IP
}{{
name: "success_v4",
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: net.IP{1, 2, 3, 4}.To16(),
}, {
name: "success_v6",
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: net.IP{
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0xFF, 0xFF,
},
}, {
name: "bad_output",
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: nil,
}, {
name: "err_runcmd",
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: nil,
}, {
name: "bad_code",
shell: theOnlyCmd(cmd, 1, "", nil),
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
assert.Equal(t, tc.want, GatewayIP(ifaceName))
})
}
}
func TestInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb()
require.NoError(t, err)
require.NotEmpty(t, ifaces)
@@ -30,7 +142,7 @@ func TestGetInterfaceByIP(t *testing.T) {
require.NotEmpty(t, iface.Addresses)
for _, ip := range iface.Addresses {
ifaceName := GetInterfaceByIP(ip)
ifaceName := InterfaceByIP(ip)
require.Equal(t, iface.Name, ifaceName)
}
})
@@ -130,3 +242,107 @@ func TestCheckPort(t *testing.T) {
assert.NoError(t, err)
})
}
func TestCollectAllIfacesAddrs(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
}{{
name: "success",
wantErrMsg: ``,
addrs: []net.Addr{&net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
}, &net.IPNet{
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
wantAddrs: nil,
}, {
name: "empty",
wantErrMsg: ``,
addrs: []net.Addr{},
wantAddrs: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
addrs, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantAddrs, addrs)
})
}
t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
}
func TestIsAddrInUse(t *testing.T) {
t.Run("addr_in_use", func(t *testing.T) {
l, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
assert.True(t, IsAddrInUse(err))
})
t.Run("another", func(t *testing.T) {
const anotherErr errors.Error = "not addr in use"
assert.False(t, IsAddrInUse(anotherErr))
})
}
func TestNetInterface_MarshalJSON(t *testing.T) {
const want = `{` +
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
`"flags":"up|multicast",` +
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
`"name":"iface0",` +
`"mtu":1500` +
`}` + "\n"
ip4, ip6 := net.IP{1, 2, 3, 4}, net.IP{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
mask4, mask6 := net.CIDRMask(24, netutil.IPv4BitLen), net.CIDRMask(8, netutil.IPv6BitLen)
iface := &NetInterface{
Addresses: []net.IP{ip4, ip6},
Subnets: []*net.IPNet{{
IP: ip4.Mask(mask4),
Mask: mask4,
}, {
IP: ip6.Mask(mask6),
Mask: mask6,
}},
Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast,
MTU: 1500,
}
b := &bytes.Buffer{}
err := json.NewEncoder(b).Encode(iface)
require.NoError(t, err)
assert.Equal(t, want, b.String())
}

View File

@@ -1,5 +1,5 @@
//go:build !(linux || darwin || freebsd || openbsd)
// +build !linux,!darwin,!freebsd,!openbsd
//go:build windows
// +build windows
package aghnet
@@ -14,7 +14,7 @@ import (
)
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
return true, nil
}
func ifaceHasStaticIP(string) (ok bool, err error) {

View File

@@ -1,158 +0,0 @@
package aghnet
import (
"net"
)
// SubnetDetector describes IP address properties.
type SubnetDetector struct {
// spNets is the collection of special-purpose address registries as defined
// by RFC 6890.
spNets []*net.IPNet
// locServedNets is the collection of locally-served networks as defined by
// RFC 6303.
locServedNets []*net.IPNet
}
// NewSubnetDetector returns a new IP detector.
//
// TODO(a.garipov): Decide whether an error is actually needed.
func NewSubnetDetector() (snd *SubnetDetector, err error) {
spNets := []string{
// "This" network.
"0.0.0.0/8",
// Private-Use Networks.
"10.0.0.0/8",
// Shared Address Space.
"100.64.0.0/10",
// Loopback.
"127.0.0.0/8",
// Link Local.
"169.254.0.0/16",
// Private-Use Networks.
"172.16.0.0/12",
// IETF Protocol Assignments.
"192.0.0.0/24",
// DS-Lite.
"192.0.0.0/29",
// TEST-NET-1
"192.0.2.0/24",
// 6to4 Relay Anycast.
"192.88.99.0/24",
// Private-Use Networks.
"192.168.0.0/16",
// Network Interconnect Device Benchmark Testing.
"198.18.0.0/15",
// TEST-NET-2.
"198.51.100.0/24",
// TEST-NET-3.
"203.0.113.0/24",
// Reserved for Future Use.
"240.0.0.0/4",
// Limited Broadcast.
"255.255.255.255/32",
// Loopback.
"::1/128",
// Unspecified.
"::/128",
// IPv4-IPv6 Translation Address.
"64:ff9b::/96",
// IPv4-Mapped Address. Since this network is used for mapping
// IPv4 addresses, we don't include it.
// "::ffff:0:0/96",
// Discard-Only Prefix.
"100::/64",
// IETF Protocol Assignments.
"2001::/23",
// TEREDO.
"2001::/32",
// Benchmarking.
"2001:2::/48",
// Documentation.
"2001:db8::/32",
// ORCHID.
"2001:10::/28",
// 6to4.
"2002::/16",
// Unique-Local.
"fc00::/7",
// Linked-Scoped Unicast.
"fe80::/10",
}
// TODO(e.burkov): It's a subslice of the slice above. Should be done
// smarter.
locServedNets := []string{
// IPv4.
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"169.254.0.0/16",
"192.0.2.0/24",
"198.51.100.0/24",
"203.0.113.0/24",
"255.255.255.255/32",
// IPv6.
"::/128",
"::1/128",
"fe80::/10",
"2001:db8::/32",
"fd00::/8",
}
snd = &SubnetDetector{
spNets: make([]*net.IPNet, len(spNets)),
locServedNets: make([]*net.IPNet, len(locServedNets)),
}
for i, ipnetStr := range spNets {
var ipnet *net.IPNet
_, ipnet, err = net.ParseCIDR(ipnetStr)
if err != nil {
return nil, err
}
snd.spNets[i] = ipnet
}
for i, ipnetStr := range locServedNets {
var ipnet *net.IPNet
_, ipnet, err = net.ParseCIDR(ipnetStr)
if err != nil {
return nil, err
}
snd.locServedNets[i] = ipnet
}
return snd, nil
}
// anyNetContains ranges through the given ipnets slice searching for the one
// which contains the ip. For internal use only.
//
// TODO(e.burkov): Think about memoization.
func anyNetContains(ipnets *[]*net.IPNet, ip net.IP) (is bool) {
for _, ipnet := range *ipnets {
if ipnet.Contains(ip) {
return true
}
}
return false
}
// IsSpecialNetwork returns true if IP address is contained by any of
// special-purpose IP address registries. It's safe for concurrent use.
func (snd *SubnetDetector) IsSpecialNetwork(ip net.IP) (is bool) {
return anyNetContains(&snd.spNets, ip)
}
// IsLocallyServedNetwork returns true if IP address is contained by any of
// locally-served IP address registries. It's safe for concurrent use.
func (snd *SubnetDetector) IsLocallyServedNetwork(ip net.IP) (is bool) {
return anyNetContains(&snd.locServedNets, ip)
}

View File

@@ -1,252 +0,0 @@
package aghnet
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSubnetDetector_DetectSpecialNetwork(t *testing.T) {
snd, err := NewSubnetDetector()
require.NoError(t, err)
testCases := []struct {
name string
ip net.IP
want bool
}{{
name: "not_specific",
ip: net.ParseIP("8.8.8.8"),
want: false,
}, {
name: "this_host_on_this_network",
ip: net.ParseIP("0.0.0.0"),
want: true,
}, {
name: "private-Use",
ip: net.ParseIP("10.0.0.0"),
want: true,
}, {
name: "shared_address_space",
ip: net.ParseIP("100.64.0.0"),
want: true,
}, {
name: "loopback",
ip: net.ParseIP("127.0.0.0"),
want: true,
}, {
name: "link_local",
ip: net.ParseIP("169.254.0.0"),
want: true,
}, {
name: "private-use",
ip: net.ParseIP("172.16.0.0"),
want: true,
}, {
name: "ietf_protocol_assignments",
ip: net.ParseIP("192.0.0.0"),
want: true,
}, {
name: "ds-lite",
ip: net.ParseIP("192.0.0.0"),
want: true,
}, {
name: "documentation_(test-net-1)",
ip: net.ParseIP("192.0.2.0"),
want: true,
}, {
name: "6to4_relay_anycast",
ip: net.ParseIP("192.88.99.0"),
want: true,
}, {
name: "private-use",
ip: net.ParseIP("192.168.0.0"),
want: true,
}, {
name: "benchmarking",
ip: net.ParseIP("198.18.0.0"),
want: true,
}, {
name: "documentation_(test-net-2)",
ip: net.ParseIP("198.51.100.0"),
want: true,
}, {
name: "documentation_(test-net-3)",
ip: net.ParseIP("203.0.113.0"),
want: true,
}, {
name: "reserved",
ip: net.ParseIP("240.0.0.0"),
want: true,
}, {
name: "limited_broadcast",
ip: net.ParseIP("255.255.255.255"),
want: true,
}, {
name: "loopback_address",
ip: net.ParseIP("::1"),
want: true,
}, {
name: "unspecified_address",
ip: net.ParseIP("::"),
want: true,
}, {
name: "ipv4-ipv6_translation",
ip: net.ParseIP("64:ff9b::"),
want: true,
}, {
name: "discard-only_address_block",
ip: net.ParseIP("100::"),
want: true,
}, {
name: "ietf_protocol_assignments",
ip: net.ParseIP("2001::"),
want: true,
}, {
name: "teredo",
ip: net.ParseIP("2001::"),
want: true,
}, {
name: "benchmarking",
ip: net.ParseIP("2001:2::"),
want: true,
}, {
name: "documentation",
ip: net.ParseIP("2001:db8::"),
want: true,
}, {
name: "orchid",
ip: net.ParseIP("2001:10::"),
want: true,
}, {
name: "6to4",
ip: net.ParseIP("2002::"),
want: true,
}, {
name: "unique-local",
ip: net.ParseIP("fc00::"),
want: true,
}, {
name: "linked-scoped_unicast",
ip: net.ParseIP("fe80::"),
want: true,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, snd.IsSpecialNetwork(tc.ip))
})
}
}
func TestSubnetDetector_DetectLocallyServedNetwork(t *testing.T) {
snd, err := NewSubnetDetector()
require.NoError(t, err)
testCases := []struct {
name string
ip net.IP
want bool
}{{
name: "not_specific",
ip: net.ParseIP("8.8.8.8"),
want: false,
}, {
name: "private-Use",
ip: net.ParseIP("10.0.0.0"),
want: true,
}, {
name: "loopback",
ip: net.ParseIP("127.0.0.0"),
want: true,
}, {
name: "link_local",
ip: net.ParseIP("169.254.0.0"),
want: true,
}, {
name: "private-use",
ip: net.ParseIP("172.16.0.0"),
want: true,
}, {
name: "documentation_(test-net-1)",
ip: net.ParseIP("192.0.2.0"),
want: true,
}, {
name: "private-use",
ip: net.ParseIP("192.168.0.0"),
want: true,
}, {
name: "documentation_(test-net-2)",
ip: net.ParseIP("198.51.100.0"),
want: true,
}, {
name: "documentation_(test-net-3)",
ip: net.ParseIP("203.0.113.0"),
want: true,
}, {
name: "limited_broadcast",
ip: net.ParseIP("255.255.255.255"),
want: true,
}, {
name: "loopback_address",
ip: net.ParseIP("::1"),
want: true,
}, {
name: "unspecified_address",
ip: net.ParseIP("::"),
want: true,
}, {
name: "documentation",
ip: net.ParseIP("2001:db8::"),
want: true,
}, {
name: "linked-scoped_unicast",
ip: net.ParseIP("fe80::"),
want: true,
}, {
name: "locally_assigned",
ip: net.ParseIP("fd00::1"),
want: true,
}, {
name: "not_locally_assigned",
ip: net.ParseIP("fc00::1"),
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, snd.IsLocallyServedNetwork(tc.ip))
})
}
}
func TestSubnetDetector_Detect_parallel(t *testing.T) {
t.Parallel()
snd, err := NewSubnetDetector()
require.NoError(t, err)
testFunc := func() {
for _, ip := range []net.IP{
net.IPv4allrouter,
net.IPv4allsys,
net.IPv4bcast,
net.IPv4zero,
net.IPv6interfacelocalallnodes,
net.IPv6linklocalallnodes,
net.IPv6linklocalallrouters,
net.IPv6loopback,
net.IPv6unspecified,
} {
_ = snd.IsSpecialNetwork(ip)
_ = snd.IsLocallyServedNetwork(ip)
}
}
const goroutinesNum = 50
for i := 0; i < goroutinesNum; i++ {
go testFunc()
}
}

View File

@@ -1,11 +1,5 @@
package aghnet
import (
"time"
"github.com/AdguardTeam/golibs/log"
)
// DefaultRefreshIvl is the default period of time between refreshing cached
// addresses.
// const DefaultRefreshIvl = 5 * time.Minute
@@ -16,39 +10,21 @@ type HostGenFunc func() (host string)
// SystemResolvers helps to work with local resolvers' addresses provided by OS.
type SystemResolvers interface {
// Get returns the slice of local resolvers' addresses. It should be
// safe for concurrent use.
// Get returns the slice of local resolvers' addresses. It must be safe for
// concurrent use.
Get() (rs []string)
// refresh refreshes the local resolvers' addresses cache. It should be
// safe for concurrent use.
// refresh refreshes the local resolvers' addresses cache. It must be safe
// for concurrent use.
refresh() (err error)
}
// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
defer log.OnPanic("systemResolvers")
// TODO(e.burkov): Implement a functionality to stop ticker.
for range tickCh {
err := sr.refresh()
if err != nil {
log.Error("systemResolvers: error in refreshing goroutine: %s", err)
continue
}
log.Debug("systemResolvers: local addresses cache is refreshed")
}
}
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If
// defined by refreshIvl. It disables auto-refreshing if refreshIvl is 0. If
// nil is passed for hostGenFunc, the default generator will be used.
func NewSystemResolvers(
refreshIvl time.Duration,
hostGenFunc HostGenFunc,
) (sr SystemResolvers, err error) {
sr = newSystemResolvers(refreshIvl, hostGenFunc)
sr = newSystemResolvers(hostGenFunc)
// Fill cache.
err = sr.refresh()
@@ -56,11 +32,5 @@ func NewSystemResolvers(
return nil, err
}
if refreshIvl > 0 {
ticker := time.NewTicker(refreshIvl)
go refreshWithTicker(sr, ticker.C)
}
return sr, nil
}

View File

@@ -24,12 +24,15 @@ func defaultHostGen() (host string) {
// systemResolvers is a default implementation of SystemResolvers interface.
type systemResolvers struct {
resolver *net.Resolver
hostGenFunc HostGenFunc
// addrs is the set that contains cached local resolvers' addresses.
addrs *stringutil.Set
// addrsLock protects addrs.
addrsLock sync.RWMutex
// addrs is the set that contains cached local resolvers' addresses.
addrs *stringutil.Set
// resolver is used to fetch the resolvers' addresses.
resolver *net.Resolver
// hostGenFunc generates hosts to resolve.
hostGenFunc HostGenFunc
}
const (
@@ -44,6 +47,7 @@ const (
errUnexpectedHostFormat errors.Error = "unexpected host format"
)
// refresh implements the SystemResolvers interface for *systemResolvers.
func (sr *systemResolvers) refresh() (err error) {
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
@@ -56,7 +60,7 @@ func (sr *systemResolvers) refresh() (err error) {
return err
}
func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr SystemResolvers) {
func newSystemResolvers(hostGenFunc HostGenFunc) (sr SystemResolvers) {
if hostGenFunc == nil {
hostGenFunc = defaultHostGen
}
@@ -76,19 +80,18 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S
func validateDialedHost(host string) (err error) {
defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()
var ipStr string
parts := strings.Split(host, "%")
switch len(parts) {
case 1:
ipStr = host
// host
case 2:
// Remove the zone and check the IP address part.
ipStr = parts[0]
host = parts[0]
default:
return errUnexpectedHostFormat
}
if net.ParseIP(ipStr) == nil {
if _, err = netutil.ParseIP(host); err != nil {
return errBadAddrPassed
}

View File

@@ -6,37 +6,32 @@ package aghnet
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func createTestSystemResolversImp(
func createTestSystemResolversImpl(
t *testing.T,
refreshDur time.Duration,
hostGenFunc HostGenFunc,
) (imp *systemResolvers) {
t.Helper()
sr := createTestSystemResolvers(t, refreshDur, hostGenFunc)
sr := createTestSystemResolvers(t, hostGenFunc)
require.IsType(t, (*systemResolvers)(nil), sr)
var ok bool
imp, ok = sr.(*systemResolvers)
require.True(t, ok)
return imp
return sr.(*systemResolvers)
}
func TestSystemResolvers_Refresh(t *testing.T) {
t.Run("expected_error", func(t *testing.T) {
sr := createTestSystemResolvers(t, 0, nil)
sr := createTestSystemResolvers(t, nil)
assert.NoError(t, sr.refresh())
})
t.Run("unexpected_error", func(t *testing.T) {
_, err := NewSystemResolvers(0, func() string {
_, err := NewSystemResolvers(func() string {
return "127.0.0.1::123"
})
assert.Error(t, err)
@@ -44,7 +39,7 @@ func TestSystemResolvers_Refresh(t *testing.T) {
}
func TestSystemResolvers_DialFunc(t *testing.T) {
imp := createTestSystemResolversImp(t, 0, nil)
imp := createTestSystemResolversImpl(t, nil)
testCases := []struct {
want error
@@ -52,7 +47,7 @@ func TestSystemResolvers_DialFunc(t *testing.T) {
address string
}{{
want: errFakeDial,
name: "valid",
name: "valid_ipv4",
address: "127.0.0.1",
}, {
want: errFakeDial,

View File

@@ -2,7 +2,6 @@ package aghnet
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -10,13 +9,12 @@ import (
func createTestSystemResolvers(
t *testing.T,
refreshDur time.Duration,
hostGenFunc HostGenFunc,
) (sr SystemResolvers) {
t.Helper()
var err error
sr, err = NewSystemResolvers(refreshDur, hostGenFunc)
sr, err = NewSystemResolvers(hostGenFunc)
require.NoError(t, err)
require.NotNil(t, sr)
@@ -24,8 +22,14 @@ func createTestSystemResolvers(
}
func TestSystemResolvers_Get(t *testing.T) {
sr := createTestSystemResolvers(t, 0, nil)
assert.NotEmpty(t, sr.Get())
sr := createTestSystemResolvers(t, nil)
var rs []string
require.NotPanics(t, func() {
rs = sr.Get()
})
assert.NotEmpty(t, rs)
}
// TODO(e.burkov): Write tests for refreshWithTicker.

View File

@@ -11,7 +11,6 @@ import (
"os/exec"
"strings"
"sync"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -27,7 +26,7 @@ type systemResolvers struct {
addrsLock sync.RWMutex
}
func newSystemResolvers(refreshIvl time.Duration, _ HostGenFunc) (sr SystemResolvers) {
func newSystemResolvers(_ HostGenFunc) (sr SystemResolvers) {
return &systemResolvers{}
}

View File

@@ -1,4 +1,6 @@
IP address HW type Flags HW address Mask Device
192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan
::ffff:ffff 0x1 0x0 ef:cd:ab:ef:cd:ab * br-lan
0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec
0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec
1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan
1.2.3.4 0x1 0x2 12:34:56:78:910 * wan

View File

@@ -1,4 +1,4 @@
package aghos
package aghos_test
import (
"testing"

View File

@@ -0,0 +1,57 @@
package aghos
import (
"io/fs"
"path"
"testing"
"testing/fstest"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// errFS is an fs.FS implementation, method Open of which always returns
// errFSOpen.
type errFS struct{}
// errFSOpen is returned from errGlobFS.Open.
const errFSOpen errors.Error = "test open error"
// Open implements the fs.FS interface for *errGlobFS. fsys is always nil and
// err is always errFSOpen.
func (efs *errFS) Open(name string) (fsys fs.File, err error) {
return nil, errFSOpen
}
func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}
t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err)
assert.True(t, ok)
})
t.Run("invalid_argument", func(t *testing.T) {
_, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errFSOpen)
assert.False(t, ok)
})
t.Run("ignore_dirs", func(t *testing.T) {
const dirName = "dir"
testFS := fstest.MapFS{
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
}
patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)
assert.Empty(t, patterns)
assert.True(t, ok)
})
}

View File

@@ -1,13 +1,13 @@
package aghos
package aghos_test
import (
"bufio"
"io"
"io/fs"
"path"
"testing"
"testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -16,7 +16,7 @@ import (
func TestFileWalker_Walk(t *testing.T) {
const attribute = `000`
makeFileWalker := func(_ string) (fw FileWalker) {
makeFileWalker := func(_ string) (fw aghos.FileWalker) {
return func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
@@ -113,7 +113,7 @@ func TestFileWalker_Walk(t *testing.T) {
f := fstest.MapFS{
filename: &fstest.MapFile{Data: []byte("[]")},
}
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
patterns = append(patterns, s.Text())
@@ -134,7 +134,7 @@ func TestFileWalker_Walk(t *testing.T) {
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
}
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
return nil, true, rerr
}).Walk(f, "*")
require.ErrorIs(t, err, rerr)
@@ -142,45 +142,3 @@ func TestFileWalker_Walk(t *testing.T) {
assert.False(t, ok)
})
}
type errFS struct {
fs.GlobFS
}
const errErrFSOpen errors.Error = "this error is always returned"
func (efs *errFS) Open(name string) (fs.File, error) {
return nil, errErrFSOpen
}
func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}
t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err)
assert.True(t, ok)
})
t.Run("invalid_argument", func(t *testing.T) {
_, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errErrFSOpen)
assert.False(t, ok)
})
t.Run("ignore_dirs", func(t *testing.T) {
const dirName = "dir"
testFS := fstest.MapFS{
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
}
patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)
assert.Empty(t, patterns)
assert.True(t, ok)
})
}

View File

@@ -57,20 +57,22 @@ func HaveAdminRights() (bool, error) {
const MaxCmdOutputSize = 64 * 1024
// RunCommand runs shell command.
func RunCommand(command string, arguments ...string) (code int, output string, err error) {
func RunCommand(command string, arguments ...string) (code int, output []byte, err error) {
cmd := exec.Command(command, arguments...)
out, err := cmd.Output()
if len(out) > MaxCmdOutputSize {
out = out[:MaxCmdOutputSize]
}
if errors.As(err, new(*exec.ExitError)) {
return cmd.ProcessState.ExitCode(), string(out), nil
} else if err != nil {
return 1, "", fmt.Errorf("command %q failed: %w: %s", command, err, out)
if err != nil {
if eerr := new(exec.ExitError); errors.As(err, &eerr) {
return eerr.ExitCode(), eerr.Stderr, nil
}
return 1, nil, fmt.Errorf("command %q failed: %w: %s", command, err, out)
}
return cmd.ProcessState.ExitCode(), string(out), nil
return cmd.ProcessState.ExitCode(), out, nil
}
// PIDByCommand searches for process named command and returns its PID ignoring
@@ -173,3 +175,13 @@ func RootDirFS() (fsys fs.FS) {
// behavior is undocumented but it currently works.
return os.DirFS("")
}
// NotifyShutdownSignal notifies c on receiving shutdown signals.
func NotifyShutdownSignal(c chan<- os.Signal) {
notifyShutdownSignal(c)
}
// IsShutdownSignal returns true if sig is a shutdown signal.
func IsShutdownSignal(sig os.Signal) (ok bool) {
return isShutdownSignal(sig)
}

27
internal/aghos/os_unix.go Normal file
View File

@@ -0,0 +1,27 @@
//go:build darwin || freebsd || linux || openbsd
// +build darwin freebsd linux openbsd
package aghos
import (
"os"
"os/signal"
"golang.org/x/sys/unix"
)
func notifyShutdownSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
}
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case
unix.SIGINT,
unix.SIGQUIT,
unix.SIGTERM:
return true
default:
return false
}
}

View File

@@ -4,6 +4,10 @@
package aghos
import (
"os"
"os/signal"
"syscall"
"golang.org/x/sys/windows"
)
@@ -35,3 +39,20 @@ func haveAdminRights() (bool, error) {
func isOpenWrt() (ok bool) {
return false
}
func notifyShutdownSignal(c chan<- os.Signal) {
// syscall.SIGTERM is processed automatically. See go doc os/signal,
// section Windows.
signal.Notify(c, os.Interrupt)
}
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case
os.Interrupt,
syscall.SIGTERM:
return true
default:
return false
}
}

View File

@@ -1,20 +0,0 @@
package aghtest
import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
// Exchanger is a mock aghnet.Exchanger implementation for tests.
type Exchanger struct {
Ups upstream.Upstream
}
// Exchange implements aghnet.Exchanger interface for *Exchanger.
func (e *Exchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
if e.Ups == nil {
e.Ups = &TestErrUpstream{}
}
return e.Ups.Exchange(req)
}

View File

@@ -1,23 +0,0 @@
package aghtest
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
type FSWatcher struct {
OnEvents func() (e <-chan struct{})
OnAdd func(name string) (err error)
OnClose func() (err error)
}
// Events implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
return w.OnEvents()
}
// Add implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Close implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}

View File

@@ -0,0 +1,135 @@
package aghtest
import (
"io/fs"
"net"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
// Interface Mocks
//
// Keep entities in this file in alphabetic order.
// Standard Library
// type check
var _ fs.FS = &FS{}
// FS is a mock [fs.FS] implementation for tests.
type FS struct {
OnOpen func(name string) (fs.File, error)
}
// Open implements the [fs.FS] interface for *FS.
func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}
// type check
var _ fs.GlobFS = &GlobFS{}
// GlobFS is a mock [fs.GlobFS] implementation for tests.
type GlobFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnGlob func(pattern string) ([]string, error)
}
// Glob implements the [fs.GlobFS] interface for *GlobFS.
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}
// type check
var _ fs.StatFS = &StatFS{}
// StatFS is a mock [fs.StatFS] implementation for tests.
type StatFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnStat func(name string) (fs.FileInfo, error)
}
// Stat implements the [fs.StatFS] interface for *StatFS.
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}
// type check
var _ net.Listener = (*Listener)(nil)
// Listener is a mock [net.Listener] implementation for tests.
type Listener struct {
OnAccept func() (conn net.Conn, err error)
OnAddr func() (addr net.Addr)
OnClose func() (err error)
}
// Accept implements the [net.Listener] interface for *Listener.
func (l *Listener) Accept() (conn net.Conn, err error) {
return l.OnAccept()
}
// Addr implements the [net.Listener] interface for *Listener.
func (l *Listener) Addr() (addr net.Addr) {
return l.OnAddr()
}
// Close implements the [net.Listener] interface for *Listener.
func (l *Listener) Close() (err error) {
return l.OnClose()
}
// Module dnsproxy
// type check
var _ upstream.Upstream = (*UpstreamMock)(nil)
// UpstreamMock is a mock [upstream.Upstream] implementation for tests.
//
// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and
// rename it to just Upstream.
type UpstreamMock struct {
OnAddress func() (addr string)
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
}
// Address implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Address() (addr string) {
return u.OnAddress()
}
// Exchange implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
return u.OnExchange(req)
}
// Module AdGuardHome
// type check
var _ aghos.FSWatcher = (*FSWatcher)(nil)
// FSWatcher is a mock [aghos.FSWatcher] implementation for tests.
type FSWatcher struct {
OnEvents func() (e <-chan struct{})
OnAdd func(name string) (err error)
OnClose func() (err error)
}
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
return w.OnEvents()
}
// Add implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}

View File

@@ -0,0 +1,9 @@
package aghtest_test
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
)
// type check
var _ aghos.FSWatcher = (*aghtest.FSWatcher)(nil)

View File

@@ -1,46 +0,0 @@
package aghtest
import "io/fs"
// type check
var _ fs.FS = &FS{}
// FS is a mock fs.FS implementation to use in tests.
type FS struct {
OnOpen func(name string) (fs.File, error)
}
// Open implements the fs.FS interface for *FS.
func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}
// type check
var _ fs.StatFS = &StatFS{}
// StatFS is a mock fs.StatFS implementation to use in tests.
type StatFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnStat func(name string) (fs.FileInfo, error)
}
// Stat implements the fs.StatFS interface for *StatFS.
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}
// type check
var _ fs.GlobFS = &GlobFS{}
// GlobFS is a mock fs.GlobFS implementation to use in tests.
type GlobFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnGlob func(pattern string) ([]string, error)
}
// Glob implements the fs.GlobFS interface for *GlobFS.
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}

View File

@@ -6,12 +6,18 @@ import (
"fmt"
"net"
"strings"
"sync"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
// Additional Upstream Testing Utilities
// Upstream is a mock implementation of upstream.Upstream.
//
// TODO(a.garipov): Replace with UpstreamMock and rename it to just Upstream.
type Upstream struct {
// CName is a map of hostname to canonical name.
CName map[string][]string
@@ -25,6 +31,43 @@ type Upstream struct {
Addr string
}
// RespondTo returns a response with answer if req has class cl, question type
// qt, and target targ.
func RespondTo(t testing.TB, req *dns.Msg, cl, qt uint16, targ, answer string) (resp *dns.Msg) {
t.Helper()
require.NotNil(t, req)
require.Len(t, req.Question, 1)
q := req.Question[0]
targ = dns.Fqdn(targ)
if q.Qclass != cl || q.Qtype != qt || q.Name != targ {
return nil
}
respHdr := dns.RR_Header{
Name: targ,
Rrtype: qt,
Class: cl,
Ttl: 60,
}
resp = new(dns.Msg).SetReply(req)
switch qt {
case dns.TypePTR:
resp.Answer = []dns.RR{
&dns.PTR{
Hdr: respHdr,
Ptr: answer,
},
}
default:
t.Fatalf("unsupported question type: %s", dns.Type(qt))
}
return resp
}
// Exchange implements the upstream.Upstream interface for *Upstream.
//
// TODO(a.garipov): Split further into handlers.
@@ -76,74 +119,57 @@ func (u *Upstream) Address() string {
return u.Addr
}
// TestBlockUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestBlockUpstream struct {
Hostname string
// lock protects reqNum.
lock sync.RWMutex
reqNum int
Block bool
}
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
u.lock.Lock()
defer u.lock.Unlock()
u.reqNum++
hash := sha256.Sum256([]byte(u.Hostname))
hashToReturn := hex.EncodeToString(hash[:])
if !u.Block {
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
// NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that
// supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is
// true, hostname's actual hash is returned, blocking it. Otherwise, it returns
// a different hash.
func NewBlockUpstream(hostname string, shouldBlock bool) (u *UpstreamMock) {
hash := sha256.Sum256([]byte(hostname))
hashStr := hex.EncodeToString(hash[:])
if !shouldBlock {
hashStr = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
}
m := &dns.Msg{}
m.SetReply(r)
m.Answer = []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: r.Question[0].Name,
},
Txt: []string{
hashToReturn,
},
ans := &dns.TXT{
Hdr: dns.RR_Header{
Name: "",
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 60,
},
Txt: []string{hashStr},
}
respTmpl := &dns.Msg{
Answer: []dns.RR{ans},
}
return &UpstreamMock{
OnAddress: func() (addr string) {
return "sbpc.upstream.example"
},
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = respTmpl.Copy()
resp.SetReply(req)
resp.Answer[0].(*dns.TXT).Hdr.Name = req.Question[0].Name
return resp, nil
},
}
return m, nil
}
// Address always returns an empty string.
func (u *TestBlockUpstream) Address() string {
return ""
}
// ErrUpstream is the error returned from the [*UpstreamMock] created by
// [NewErrorUpstream].
const ErrUpstream errors.Error = "test upstream error"
// RequestsCount returns the number of handled requests. It's safe for
// concurrent use.
func (u *TestBlockUpstream) RequestsCount() int {
u.lock.Lock()
defer u.lock.Unlock()
return u.reqNum
}
// TestErrUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestErrUpstream struct {
// The error returned by Exchange may be unwrapped to the Err.
Err error
}
// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, fmt.Errorf("errupstream: %w", u.Err)
}
// Address always returns an empty string.
func (u *TestErrUpstream) Address() string {
return ""
// NewErrorUpstream returns an [*UpstreamMock] that returns [ErrUpstream] from
// its Exchange method.
func NewErrorUpstream() (u *UpstreamMock) {
return &UpstreamMock{
OnAddress: func() (addr string) {
return "error.upstream.example"
},
OnExchange: func(_ *dns.Msg) (resp *dns.Msg, err error) {
return nil, errors.Error("test upstream error")
},
}
}

View File

@@ -16,6 +16,8 @@ import (
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4"
"github.com/mdlayher/ethernet"
//lint:ignore SA1019 See the TODO in go.mod.
"github.com/mdlayher/raw"
)
@@ -49,16 +51,15 @@ type dhcpConn struct {
}
// newDHCPConn creates the special connection for DHCP server.
func (s *v4Server) newDHCPConn(ifi *net.Interface) (c net.PacketConn, err error) {
// Create the raw connection.
func (s *v4Server) newDHCPConn(iface *net.Interface) (c net.PacketConn, err error) {
var ucast net.PacketConn
if ucast, err = raw.ListenPacket(ifi, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
if ucast, err = raw.ListenPacket(iface, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
return nil, fmt.Errorf("creating raw udp connection: %w", err)
}
// Create the UDP connection.
var bcast net.PacketConn
bcast, err = server4.NewIPv4UDPConn(ifi.Name, &net.UDPAddr{
bcast, err = server4.NewIPv4UDPConn(iface.Name, &net.UDPAddr{
// TODO(e.burkov): Listening on zeroes makes the server handle
// requests from all the interfaces. Inspect the ways to
// specify the interface-specific listening addresses.
@@ -75,7 +76,7 @@ func (s *v4Server) newDHCPConn(ifi *net.Interface) (c net.PacketConn, err error)
udpConn: bcast,
bcastIP: s.conf.broadcastIP,
rawConn: ucast,
srcMAC: ifi.HardwareAddr,
srcMAC: iface.HardwareAddr,
srcIP: s.conf.dnsIPAddrs[0],
}, nil
}

View File

@@ -11,9 +11,11 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/mdlayher/raw"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
//lint:ignore SA1019 See the TODO in go.mod.
"github.com/mdlayher/raw"
)
func TestDHCPConn_WriteTo_common(t *testing.T) {

View File

@@ -5,11 +5,11 @@ import (
"encoding/json"
"fmt"
"net"
"net/http"
"path/filepath"
"runtime"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
)
@@ -126,7 +126,7 @@ type ServerConfig struct {
ConfigModified func() `yaml:"-"`
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
Enabled bool `yaml:"enabled"`
InterfaceName string `yaml:"interface_name"`

View File

@@ -10,6 +10,7 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
@@ -145,7 +146,7 @@ type dhcpServerConfigJSON struct {
V4 *v4ServerConfJSON `json:"v4"`
V6 *v6ServerConfJSON `json:"v6"`
InterfaceName string `json:"interface_name"`
Enabled nullBool `json:"enabled"`
Enabled aghalg.NullBool `json:"enabled"`
}
func (s *Server) handleDHCPSetConfigV4(
@@ -156,7 +157,7 @@ func (s *Server) handleDHCPSetConfigV4(
}
v4Conf := v4JSONToServerConf(conf.V4)
v4Conf.Enabled = conf.Enabled == nbTrue
v4Conf.Enabled = conf.Enabled == aghalg.NBTrue
if len(v4Conf.RangeStart) == 0 {
v4Conf.Enabled = false
}
@@ -183,7 +184,7 @@ func (s *Server) handleDHCPSetConfigV6(
}
v6Conf := v6JSONToServerConf(conf.V6)
v6Conf.Enabled = conf.Enabled == nbTrue
v6Conf.Enabled = conf.Enabled == aghalg.NBTrue
if len(v6Conf.RangeStart) == 0 {
v6Conf.Enabled = false
}
@@ -206,7 +207,7 @@ func (s *Server) handleDHCPSetConfigV6(
func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
conf := &dhcpServerConfigJSON{}
conf.Enabled = boolToNullBool(s.conf.Enabled)
conf.Enabled = aghalg.BoolToNullBool(s.conf.Enabled)
conf.InterfaceName = s.conf.InterfaceName
err := json.NewDecoder(r.Body).Decode(conf)
@@ -230,7 +231,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
return
}
if conf.Enabled == nbTrue && !v4Enabled && !v6Enabled {
if conf.Enabled == aghalg.NBTrue && !v4Enabled && !v6Enabled {
aghhttp.Error(r, w, http.StatusBadRequest, "dhcpv4 or dhcpv6 configuration must be complete")
return
@@ -243,8 +244,8 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
return
}
if conf.Enabled != nbNull {
s.conf.Enabled = conf.Enabled == nbTrue
if conf.Enabled != aghalg.NBNull {
s.conf.Enabled = conf.Enabled == aghalg.NBTrue
}
if conf.InterfaceName != "" {
@@ -279,11 +280,11 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
type netInterfaceJSON struct {
Name string `json:"name"`
GatewayIP net.IP `json:"gateway_ip"`
HardwareAddr string `json:"hardware_address"`
Flags string `json:"flags"`
GatewayIP net.IP `json:"gateway_ip"`
Addrs4 []net.IP `json:"ipv4_addresses"`
Addrs6 []net.IP `json:"ipv6_addresses"`
Flags string `json:"flags"`
}
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
@@ -497,7 +498,6 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
}
ip4 := l.IP.To4()
if ip4 == nil {
l.IP = l.IP.To16()

View File

@@ -1,58 +0,0 @@
package dhcpd
import (
"bytes"
"fmt"
)
// nullBool is a nullable boolean. Use these in JSON requests and responses
// instead of pointers to bool.
//
// TODO(a.garipov): Inspect uses of *bool, move this type into some new package
// if we need it somewhere else.
type nullBool uint8
// nullBool values
const (
nbNull nullBool = iota
nbTrue
nbFalse
)
// String implements the fmt.Stringer interface for nullBool.
func (nb nullBool) String() (s string) {
switch nb {
case nbNull:
return "null"
case nbTrue:
return "true"
case nbFalse:
return "false"
}
return fmt.Sprintf("!invalid nullBool %d", uint8(nb))
}
// boolToNullBool converts a bool into a nullBool.
func boolToNullBool(cond bool) (nb nullBool) {
if cond {
return nbTrue
}
return nbFalse
}
// UnmarshalJSON implements the json.Unmarshaler interface for *nullBool.
func (nb *nullBool) UnmarshalJSON(b []byte) (err error) {
if len(b) == 0 || bytes.Equal(b, []byte("null")) {
*nb = nbNull
} else if bytes.Equal(b, []byte("true")) {
*nb = nbTrue
} else if bytes.Equal(b, []byte("false")) {
*nb = nbFalse
} else {
return fmt.Errorf("invalid nullBool value %q", b)
}
return nil
}

View File

@@ -1,66 +0,0 @@
package dhcpd
import (
"encoding/json"
"testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNullBool_UnmarshalJSON(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
data []byte
want nullBool
}{{
name: "empty",
wantErrMsg: "",
data: []byte{},
want: nbNull,
}, {
name: "null",
wantErrMsg: "",
data: []byte("null"),
want: nbNull,
}, {
name: "true",
wantErrMsg: "",
data: []byte("true"),
want: nbTrue,
}, {
name: "false",
wantErrMsg: "",
data: []byte("false"),
want: nbFalse,
}, {
name: "invalid",
wantErrMsg: `invalid nullBool value "invalid"`,
data: []byte("invalid"),
want: nbNull,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var got nullBool
err := got.UnmarshalJSON(tc.data)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
t.Run("json", func(t *testing.T) {
want := nbTrue
var got struct {
A nullBool
}
err := json.Unmarshal([]byte(`{"A":true}`), &got)
require.NoError(t, err)
assert.Equal(t, want, got.A)
})
}

View File

@@ -18,17 +18,14 @@ import (
// The aliases for DHCP option types available for explicit declaration.
const (
hexTyp = "hex"
ipTyp = "ip"
ipsTyp = "ips"
textTyp = "text"
typHex = "hex"
typIP = "ip"
typIPs = "ips"
typText = "text"
typDel = "del"
)
// parseDHCPOptionHex parses a DHCP option as a hex-encoded string. For
// example:
//
// 252 hex 736f636b733a2f2f70726f78792e6578616d706c652e6f7267
//
// parseDHCPOptionHex parses a DHCP option as a hex-encoded string.
func parseDHCPOptionHex(s string) (val dhcpv4.OptionValue, err error) {
var data []byte
data, err = hex.DecodeString(s)
@@ -39,10 +36,7 @@ func parseDHCPOptionHex(s string) (val dhcpv4.OptionValue, err error) {
return dhcpv4.OptionGeneric{Data: data}, nil
}
// parseDHCPOptionIP parses a DHCP option as a single IP address. For example:
//
// 6 ip 192.168.1.1
//
// parseDHCPOptionIP parses a DHCP option as a single IP address.
func parseDHCPOptionIP(s string) (val dhcpv4.OptionValue, err error) {
var ip net.IP
// All DHCPv4 options require IPv4, so don't put the 16-byte version.
@@ -58,10 +52,7 @@ func parseDHCPOptionIP(s string) (val dhcpv4.OptionValue, err error) {
}
// parseDHCPOptionIPs parses a DHCP option as a comma-separates list of IP
// addresses. For example:
//
// 6 ips 192.168.1.1,192.168.1.2
//
// addresses.
func parseDHCPOptionIPs(s string) (val dhcpv4.OptionValue, err error) {
var ips dhcpv4.IPs
var ip net.IP
@@ -78,23 +69,53 @@ func parseDHCPOptionIPs(s string) (val dhcpv4.OptionValue, err error) {
}
// parseDHCPOptionText parses a DHCP option as a simple UTF-8 encoded
// text. For example:
//
// 252 text http://192.168.1.1/wpad.dat
//
// text.
func parseDHCPOptionText(s string) (val dhcpv4.OptionValue) {
return dhcpv4.OptionGeneric{Data: []byte(s)}
}
// parseDHCPOption parses an option. See the documentation of parseDHCPOption*
// for more info.
// parseDHCPOptionVal parses a DHCP option value considering typ. For the del
// option the value string is ignored. The examples of possible value pairs:
//
// - hex 736f636b733a2f2f70726f78792e6578616d706c652e6f7267
// - ip 192.168.1.1
// - ips 192.168.1.1,192.168.1.2
// - text http://192.168.1.1/wpad.dat
// - del
func parseDHCPOptionVal(typ, valStr string) (val dhcpv4.OptionValue, err error) {
switch typ {
case typHex:
val, err = parseDHCPOptionHex(valStr)
case typIP:
val, err = parseDHCPOptionIP(valStr)
case typIPs:
val, err = parseDHCPOptionIPs(valStr)
case typText:
val = parseDHCPOptionText(valStr)
case typDel:
val = dhcpv4.OptionGeneric{Data: nil}
default:
err = fmt.Errorf("unknown option type %q", typ)
}
return val, err
}
// parseDHCPOption parses an option. See the documentation of
// parseDHCPOptionVal for more info.
func parseDHCPOption(s string) (opt dhcpv4.Option, err error) {
defer func() { err = errors.Annotate(err, "invalid option string %q: %w", s) }()
s = strings.TrimSpace(s)
parts := strings.SplitN(s, " ", 3)
if len(parts) < 3 {
return opt, errors.Error("need at least three fields")
var valStr string
if pl := len(parts); pl < 3 {
if pl < 2 || parts[1] != typDel {
return opt, errors.Error("bad option format")
}
} else {
valStr = parts[2]
}
var code64 uint64
@@ -103,27 +124,16 @@ func parseDHCPOption(s string) (opt dhcpv4.Option, err error) {
return opt, fmt.Errorf("parsing option code: %w", err)
}
var optVal dhcpv4.OptionValue
switch typ, val := parts[1], parts[2]; typ {
case hexTyp:
optVal, err = parseDHCPOptionHex(val)
case ipTyp:
optVal, err = parseDHCPOptionIP(val)
case ipsTyp:
optVal, err = parseDHCPOptionIPs(val)
case textTyp:
optVal = parseDHCPOptionText(val)
default:
return opt, fmt.Errorf("unknown option type %q", typ)
}
val, err := parseDHCPOptionVal(parts[1], valStr)
if err != nil {
// Don't wrap an error since it's informative enough as is and there
// also the deferred annotation.
return opt, err
}
return dhcpv4.Option{
Code: dhcpv4.GenericOptionCode(code64),
Value: optVal,
Value: val,
}, nil
}
@@ -139,40 +149,45 @@ func prepareOptions(conf V4ServerConf) (opts dhcpv4.Options) {
// See also https://datatracker.ietf.org/doc/html/rfc1122,
// https://datatracker.ietf.org/doc/html/rfc1123, and
// https://datatracker.ietf.org/doc/html/rfc2132.
opts = dhcpv4.Options{
opts = dhcpv4.OptionsFromList(
// IP-Layer Per Host
dhcpv4.OptionNonLocalSourceRouting.Code(): []byte{0},
dhcpv4.OptGeneric(dhcpv4.OptionNonLocalSourceRouting, []byte{0}),
// Set the current recommended default time to live for the
// Internet Protocol which is 64, see
// https://datatracker.ietf.org/doc/html/rfc1700.
dhcpv4.OptionDefaultIPTTL.Code(): []byte{64},
dhcpv4.OptGeneric(dhcpv4.OptionDefaultIPTTL, []byte{0x40}),
// IP-Layer Per Interface
dhcpv4.OptionPerformMaskDiscovery.Code(): []byte{0},
dhcpv4.OptionMaskSupplier.Code(): []byte{0},
dhcpv4.OptionPerformRouterDiscovery.Code(): []byte{1},
dhcpv4.OptGeneric(dhcpv4.OptionPerformMaskDiscovery, []byte{0}),
dhcpv4.OptGeneric(dhcpv4.OptionMaskSupplier, []byte{0}),
dhcpv4.OptGeneric(dhcpv4.OptionPerformRouterDiscovery, []byte{1}),
// The all-routers address is preferred wherever possible, see
// https://datatracker.ietf.org/doc/html/rfc1256#section-5.1.
dhcpv4.OptionRouterSolicitationAddress.Code(): netutil.IPv4allrouter(),
dhcpv4.OptionBroadcastAddress.Code(): netutil.IPv4bcast(),
dhcpv4.Option{
Code: dhcpv4.OptionRouterSolicitationAddress,
Value: dhcpv4.IP(netutil.IPv4allrouter()),
},
dhcpv4.OptBroadcastAddress(netutil.IPv4bcast()),
// Link-Layer Per Interface
dhcpv4.OptionTrailerEncapsulation.Code(): []byte{0},
dhcpv4.OptionEthernetEncapsulation.Code(): []byte{0},
dhcpv4.OptGeneric(dhcpv4.OptionTrailerEncapsulation, []byte{0}),
dhcpv4.OptGeneric(dhcpv4.OptionEthernetEncapsulation, []byte{0}),
// TCP Per Host
dhcpv4.OptionTCPKeepaliveInterval.Code(): dhcpv4.Duration(0).ToBytes(),
dhcpv4.OptionTCPKeepaliveGarbage.Code(): []byte{0},
dhcpv4.Option{
Code: dhcpv4.OptionTCPKeepaliveInterval,
Value: dhcpv4.Duration(0),
},
dhcpv4.OptGeneric(dhcpv4.OptionTCPKeepaliveGarbage, []byte{0}),
// Values From Configuration
dhcpv4.OptionRouter.Code(): netutil.CloneIP(conf.subnet.IP),
dhcpv4.OptionSubnetMask.Code(): dhcpv4.IPMask(conf.subnet.Mask).ToBytes(),
}
dhcpv4.OptRouter(conf.subnet.IP),
dhcpv4.OptSubnetMask(conf.subnet.Mask),
)
// Set values for explicitly configured options.
for i, o := range conf.Options {

View File

@@ -8,103 +8,111 @@ import (
"net"
"testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseOpt(t *testing.T) {
testCases := []struct {
name string
in string
wantErrMsg string
wantOpt dhcpv4.Option
wantErrMsg string
}{{
name: "hex_success",
in: "6 hex c0a80101c0a80102",
wantErrMsg: "",
wantOpt: dhcpv4.OptDNS(
net.IP{0xC0, 0xA8, 0x01, 0x01},
net.IP{0xC0, 0xA8, 0x01, 0x02},
name: "hex_success",
in: "6 hex c0a80101c0a80102",
wantOpt: dhcpv4.OptGeneric(
dhcpv4.GenericOptionCode(6),
[]byte{
0xC0, 0xA8, 0x01, 0x01,
0xC0, 0xA8, 0x01, 0x02,
},
),
wantErrMsg: "",
}, {
name: "ip_success",
in: "6 ip 1.2.3.4",
name: "ip_success",
in: "6 ip 1.2.3.4",
wantOpt: dhcpv4.Option{
Code: dhcpv4.GenericOptionCode(6),
Value: dhcpv4.IP(net.IP{0x01, 0x02, 0x03, 0x04}),
},
wantErrMsg: "",
wantOpt: dhcpv4.OptDNS(
net.IP{0x01, 0x02, 0x03, 0x04},
),
}, {
name: "ip_fail_v6",
in: "6 ip ::1234",
wantErrMsg: "invalid option string \"6 ip ::1234\": bad ipv4 address \"::1234\"",
wantOpt: dhcpv4.Option{},
wantErrMsg: "invalid option string \"6 ip ::1234\": bad ipv4 address \"::1234\"",
}, {
name: "ips_success",
in: "6 ips 192.168.1.1,192.168.1.2",
name: "ips_success",
in: "6 ips 192.168.1.1,192.168.1.2",
wantOpt: dhcpv4.Option{
Code: dhcpv4.GenericOptionCode(6),
Value: dhcpv4.IPs([]net.IP{
{0xC0, 0xA8, 0x01, 0x01},
{0xC0, 0xA8, 0x01, 0x02},
}),
},
wantErrMsg: "",
wantOpt: dhcpv4.OptDNS(
net.IP{0xC0, 0xA8, 0x01, 0x01},
net.IP{0xC0, 0xA8, 0x01, 0x02},
),
}, {
name: "text_success",
in: "252 text http://192.168.1.1/",
wantErrMsg: "",
name: "text_success",
in: "252 text http://192.168.1.1/",
wantOpt: dhcpv4.OptGeneric(
dhcpv4.GenericOptionCode(252),
[]byte("http://192.168.1.1/"),
),
wantErrMsg: "",
}, {
name: "del_success",
in: "61 del",
wantOpt: dhcpv4.Option{
Code: dhcpv4.GenericOptionCode(dhcpv4.OptionClientIdentifier),
Value: dhcpv4.OptionGeneric{Data: nil},
},
wantErrMsg: "",
}, {
name: "bad_parts",
in: "6 ip",
wantErrMsg: `invalid option string "6 ip": need at least three fields`,
wantOpt: dhcpv4.Option{},
wantErrMsg: `invalid option string "6 ip": bad option format`,
}, {
name: "bad_code",
in: "256 ip 1.1.1.1",
name: "bad_code",
in: "256 ip 1.1.1.1",
wantOpt: dhcpv4.Option{},
wantErrMsg: `invalid option string "256 ip 1.1.1.1": parsing option code: ` +
`strconv.ParseUint: parsing "256": value out of range`,
wantOpt: dhcpv4.Option{},
}, {
name: "bad_type",
in: "6 bad 1.1.1.1",
wantErrMsg: `invalid option string "6 bad 1.1.1.1": unknown option type "bad"`,
wantOpt: dhcpv4.Option{},
wantErrMsg: `invalid option string "6 bad 1.1.1.1": unknown option type "bad"`,
}, {
name: "hex_error",
in: "6 hex ZZZ",
name: "hex_error",
in: "6 hex ZZZ",
wantOpt: dhcpv4.Option{},
wantErrMsg: `invalid option string "6 hex ZZZ": decoding hex: ` +
`encoding/hex: invalid byte: U+005A 'Z'`,
wantOpt: dhcpv4.Option{},
}, {
name: "ip_error",
in: "6 ip 1.2.3.x",
wantErrMsg: "invalid option string \"6 ip 1.2.3.x\": bad ipv4 address \"1.2.3.x\"",
wantOpt: dhcpv4.Option{},
wantErrMsg: "invalid option string \"6 ip 1.2.3.x\": bad ipv4 address \"1.2.3.x\"",
}, {
name: "ips_error",
in: "6 ips 192.168.1.1,192.168.1.x",
name: "ips_error",
in: "6 ips 192.168.1.1,192.168.1.x",
wantOpt: dhcpv4.Option{},
wantErrMsg: "invalid option string \"6 ips 192.168.1.1,192.168.1.x\": " +
"parsing ip at index 1: bad ipv4 address \"192.168.1.x\"",
wantOpt: dhcpv4.Option{},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
opt, err := parseDHCPOption(tc.in)
if tc.wantErrMsg != "" {
require.Error(t, err)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
return
}
require.NoError(t, err)
assert.Equal(t, tc.wantOpt.Code.Code(), opt.Code.Code())
assert.Equal(t, tc.wantOpt.Value.ToBytes(), opt.Value.ToBytes())
// assert.Equal(t, tc.wantOpt.Code.Code(), opt.Code.Code())
// assert.Equal(t, tc.wantOpt.Value.ToBytes(), opt.Value.ToBytes())
assert.Equal(t, tc.wantOpt, opt)
})
}
}
@@ -127,51 +135,80 @@ func TestPrepareOptions(t *testing.T) {
testCases := []struct {
name string
opts []string
checks dhcpv4.Options
opts []string
}{{
name: "all_default",
checks: allDefault,
opts: nil,
}, {
name: "configured_ip",
opts: []string{
fmt.Sprintf("%d ip %s", dhcpv4.OptionBroadcastAddress, oneIP),
},
checks: dhcpv4.Options{
dhcpv4.OptionBroadcastAddress.Code(): oneIP,
},
opts: []string{
fmt.Sprintf("%d ip %s", dhcpv4.OptionBroadcastAddress, oneIP),
},
}, {
name: "configured_ips",
opts: []string{
fmt.Sprintf("%d ips %s,%s", dhcpv4.OptionDomainNameServer, oneIP, otherIP),
},
checks: dhcpv4.Options{
dhcpv4.OptionDomainNameServer.Code(): append(oneIP, otherIP...),
},
opts: []string{
fmt.Sprintf("%d ips %s,%s", dhcpv4.OptionDomainNameServer, oneIP, otherIP),
},
}, {
name: "configured_bad",
name: "configured_bad",
checks: allDefault,
opts: []string{
"20 hex",
"23 hex abc",
"32 ips 1,2,3,4",
"28 256.256.256.256",
},
checks: allDefault,
}, {
name: "configured_del",
checks: dhcpv4.Options{
dhcpv4.OptionBroadcastAddress.Code(): nil,
},
opts: []string{
"28 del",
},
}, {
name: "rewritten_del",
checks: dhcpv4.Options{
dhcpv4.OptionBroadcastAddress.Code(): []byte{255, 255, 255, 255},
},
opts: []string{
"28 del",
"28 ip 255.255.255.255",
},
}, {
name: "configured_and_del",
checks: dhcpv4.Options{
123: []byte("cba"),
},
opts: []string{
"123 text abc",
"123 del",
"123 text cba",
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.name == "configured_del" {
assert.True(t, true)
}
opts := prepareOptions(V4ServerConf{
// Just to avoid nil pointer dereference.
subnet: &net.IPNet{},
Options: tc.opts,
})
for c, v := range tc.checks {
optVal := opts.Get(dhcpv4.GenericOptionCode(c))
require.NotNil(t, optVal)
assert.Len(t, optVal, len(v))
assert.Equal(t, v, optVal)
val := opts.Get(dhcpv4.GenericOptionCode(c))
assert.Lenf(t, val, len(v), "Code: %v", c)
assert.Equal(t, v, val)
}
})
}

View File

@@ -20,6 +20,9 @@ import (
"github.com/go-ping/ping"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4"
"golang.org/x/exp/slices"
//lint:ignore SA1019 See the TODO in go.mod.
"github.com/mdlayher/raw"
)
@@ -91,6 +94,9 @@ func (s *v4Server) validHostnameForClient(cliHostname string, ip net.IP) (hostna
if hostname == "" {
hostname = aghnet.GenerateHostname(ip)
} else if s.leaseHosts.Has(hostname) {
log.Info("dhcpv4: hostname %q already exists", hostname)
hostname = aghnet.GenerateHostname(ip)
}
err = netutil.ValidateDomainName(hostname)
@@ -250,11 +256,11 @@ func (s *v4Server) rmLeaseByIndex(i int) {
// Remove a dynamic lease with the same properties
// Return error if a static lease is found
func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
for i := 0; i < len(s.leases); i++ {
l := s.leases[i]
for i, l := range s.leases {
isStatic := l.IsStatic()
if bytes.Equal(l.HWAddr, lease.HWAddr) {
if l.IsStatic() {
if bytes.Equal(l.HWAddr, lease.HWAddr) || l.IP.Equal(lease.IP) {
if isStatic {
return errors.Error("static lease already exists")
}
@@ -266,20 +272,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
l = s.leases[i]
}
if l.IP.Equal(lease.IP) {
if l.IsStatic() {
return errors.Error("static lease already exists")
}
s.rmLeaseByIndex(i)
if i == len(s.leases) {
break
}
l = s.leases[i]
}
if l.Hostname == lease.Hostname {
if !isStatic && l.Hostname == lease.Hostname {
l.Hostname = ""
}
}
@@ -287,6 +280,10 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
return nil
}
// ErrDupHostname is returned by addLease when the added lease has a not empty
// non-unique hostname.
const ErrDupHostname = errors.Error("hostname is not unique")
// addLease adds a dynamic or static lease.
func (s *v4Server) addLease(l *Lease) (err error) {
r := s.conf.ipRange
@@ -302,13 +299,17 @@ func (s *v4Server) addLease(l *Lease) (err error) {
return fmt.Errorf("lease %s (%s) out of range, not adding", l.IP, l.HWAddr)
}
s.leases = append(s.leases, l)
s.leasedOffsets.set(offset, true)
if l.Hostname != "" {
if s.leaseHosts.Has(l.Hostname) {
return ErrDupHostname
}
s.leaseHosts.Add(l.Hostname)
}
s.leases = append(s.leases, l)
s.leasedOffsets.set(offset, true)
return nil
}
@@ -333,12 +334,16 @@ func (s *v4Server) rmLease(lease *Lease) (err error) {
return errors.Error("lease not found")
}
// AddStaticLease adds a static lease. It is safe for concurrent use.
// AddStaticLease implements the DHCPServer interface for *v4Server. It is safe
// for concurrent use.
func (s *v4Server) AddStaticLease(l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv4: adding static lease: %w") }()
if ip4 := l.IP.To4(); ip4 == nil {
ip := l.IP.To4()
if ip == nil {
return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
} else if gwIP := s.conf.GatewayIP; gwIP.Equal(ip) {
return fmt.Errorf("can't assign the gateway IP %s to the lease", gwIP)
}
l.Expiry = time.Unix(leaseExpireStatic, 0)
@@ -359,10 +364,11 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
return fmt.Errorf("validating hostname: %w", err)
}
// Don't check for hostname uniqueness, since we try to emulate
// dnsmasq here, which means that rmDynamicLease below will
// simply empty the hostname of the dynamic lease if there even
// is one.
// Don't check for hostname uniqueness, since we try to emulate dnsmasq
// here, which means that rmDynamicLease below will simply empty the
// hostname of the dynamic lease if there even is one. In case a static
// lease with the same name already exists, addLease will return an
// error and the lease won't be added.
l.Hostname = hostname
}
@@ -377,7 +383,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
if err != nil {
err = fmt.Errorf(
"removing dynamic leases for %s (%s): %w",
l.IP,
ip,
l.HWAddr,
err,
)
@@ -387,7 +393,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
err = s.addLease(l)
if err != nil {
err = fmt.Errorf("adding static lease for %s (%s): %w", l.IP, l.HWAddr, err)
err = fmt.Errorf("adding static lease for %s (%s): %w", ip, l.HWAddr, err)
return
}
@@ -517,11 +523,7 @@ func (s *v4Server) findExpiredLease() int {
// reserveLease reserves a lease for a client by its MAC-address. It returns
// nil if it couldn't allocate a new lease.
func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease, err error) {
l = &Lease{
HWAddr: make([]byte, len(mac)),
}
copy(l.HWAddr, mac)
l = &Lease{HWAddr: slices.Clone(mac)}
l.IP = s.nextIP()
if l.IP == nil {
@@ -614,33 +616,25 @@ func (s *v4Server) processDiscover(req, resp *dhcpv4.DHCPv4) (l *Lease, err erro
return l, nil
}
type optFQDN struct {
name string
}
// OptionFQDN returns a DHCPv4 option for sending the FQDN to the client
// requested another hostname.
//
// See https://datatracker.ietf.org/doc/html/rfc4702.
func OptionFQDN(fqdn string) (opt dhcpv4.Option) {
optData := []byte{
// Set only S and O DHCP client FQDN option flags.
//
// See https://datatracker.ietf.org/doc/html/rfc4702#section-2.1.
1<<0 | 1<<1,
// The RCODE fields should be set to 0xFF in the server responses.
//
// See https://datatracker.ietf.org/doc/html/rfc4702#section-2.2.
0xFF,
0xFF,
}
optData = append(optData, fqdn...)
func (o *optFQDN) String() string {
return "optFQDN"
}
// flags[1]
// A-RR[1]
// PTR-RR[1]
// name[]
func (o *optFQDN) ToBytes() []byte {
b := make([]byte, 3+len(o.name))
i := 0
b[i] = 0x03 // f_server_overrides | f_server
i++
b[i] = 255 // A-RR
i++
b[i] = 255 // PTR-RR
i++
copy(b[i:], []byte(o.name))
return b
return dhcpv4.OptGeneric(dhcpv4.OptionFQDN, optData)
}
// checkLease checks if the pair of mac and ip is already leased. The mismatch
@@ -673,6 +667,8 @@ func (s *v4Server) checkLease(mac net.HardwareAddr, ip net.IP) (lease *Lease, mi
// processRequest is the handler for the DHCP Request request.
func (s *v4Server) processRequest(req, resp *dhcpv4.DHCPv4) (lease *Lease, needsReply bool) {
mac := req.ClientHWAddr
// TODO(e.burkov): The IP address can only be requested in DHCPDISCOVER
// message.
reqIP := req.RequestedIPAddress()
if reqIP == nil {
reqIP = req.ClientIPAddr
@@ -705,24 +701,17 @@ func (s *v4Server) processRequest(req, resp *dhcpv4.DHCPv4) (lease *Lease, needs
if !lease.IsStatic() {
cliHostname := req.HostName()
hostname := s.validHostnameForClient(cliHostname, reqIP)
if hostname != lease.Hostname && s.leaseHosts.Has(hostname) {
log.Info("dhcpv4: hostname %q already exists", hostname)
lease.Hostname = ""
} else {
if lease.Hostname != hostname {
lease.Hostname = hostname
resp.UpdateOption(dhcpv4.OptHostName(hostname))
}
s.commitLease(lease)
} else if lease.Hostname != "" {
o := &optFQDN{
name: lease.Hostname,
}
fqdn := dhcpv4.Option{
Code: dhcpv4.OptionFQDN,
Value: o,
}
resp.UpdateOption(fqdn)
// TODO(e.burkov): This option is used to update the server's DNS
// mapping. The option should only be answered when it has been
// requested.
resp.UpdateOption(OptionFQDN(lease.Hostname))
}
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck))
@@ -845,7 +834,7 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
// TODO(a.garipov): Refactor this into handlers.
var l *Lease
switch req.MessageType() {
switch mt := req.MessageType(); mt {
case dhcpv4.MessageTypeDiscover:
l, err = s.processDiscover(req, resp)
if err != nil {
@@ -886,28 +875,32 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
resp.YourIPAddr = netutil.CloneIP(l.IP)
}
// Set IP address lease time for all DHCPOFFER messages and DHCPACK
// messages replied for DHCPREQUEST.
// Set IP address lease time for all DHCPOFFER messages and DHCPACK messages
// replied for DHCPREQUEST.
//
// TODO(e.burkov): Inspect why this is always set to configured value.
resp.UpdateOption(dhcpv4.OptIPAddressLeaseTime(s.conf.leaseTime))
// Delete options explicitly configured to be removed.
for code := range resp.Options {
if val, ok := s.options[code]; ok && val == nil {
delete(resp.Options, code)
}
}
// Update values for each explicitly configured parameter requested by
// client.
//
// See https://datatracker.ietf.org/doc/html/rfc2131#section-4.3.1.
requested := req.ParameterRequestList()
for _, code := range requested {
if configured := s.options; configured.Has(code) {
resp.UpdateOption(dhcpv4.OptGeneric(code, configured.Get(code)))
if val := s.options.Get(code); val != nil {
resp.UpdateOption(dhcpv4.Option{
Code: code,
Value: dhcpv4.OptionGeneric{Data: s.options.Get(code)},
})
}
}
// Update the value of Domain Name Server option separately from others if
// not assigned yet since its value is set after server's creating.
if requested.Has(dhcpv4.OptionDomainNameServer) &&
!resp.Options.Has(dhcpv4.OptionDomainNameServer) {
resp.UpdateOption(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
}
return 1
}
@@ -935,6 +928,7 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4
resp, err := dhcpv4.NewReplyFromRequest(req)
if err != nil {
log.Debug("dhcpv4: dhcpv4.New: %s", err)
return
}
@@ -1031,6 +1025,11 @@ func (s *v4Server) Start() (err error) {
// No available IP addresses which may appear later.
return nil
}
// Update the value of Domain Name Server option separately from others if
// not assigned yet since its value is available only at server's start.
if !s.options.Has(dhcpv4.OptionDomainNameServer) {
s.options.Update(dhcpv4.OptDNS(dnsIPAddrs...))
}
s.conf.dnsIPAddrs = dnsIPAddrs

View File

@@ -4,16 +4,29 @@
package dhcpd
import (
"fmt"
"net"
"strings"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/mdlayher/raw"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
//lint:ignore SA1019 See the TODO in go.mod.
"github.com/mdlayher/raw"
)
var (
DefaultRangeStart = net.IP{192, 168, 10, 100}
DefaultRangeEnd = net.IP{192, 168, 10, 200}
DefaultGatewayIP = net.IP{192, 168, 10, 1}
DefaultSelfIP = net.IP{192, 168, 10, 2}
DefaultSubnetMask = net.IP{255, 255, 255, 0}
)
func notify4(flags uint32) {
@@ -24,11 +37,12 @@ func notify4(flags uint32) {
func defaultV4ServerConf() (conf V4ServerConf) {
return V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
RangeStart: DefaultRangeStart,
RangeEnd: DefaultRangeEnd,
GatewayIP: DefaultGatewayIP,
SubnetMask: DefaultSubnetMask,
notify: notify4,
dnsIPAddrs: []net.IP{DefaultSelfIP},
}
}
@@ -44,44 +58,228 @@ func defaultSrv(t *testing.T) (s DHCPServer) {
return s
}
func TestV4_AddRemove_static(t *testing.T) {
func TestV4Server_leasing(t *testing.T) {
const (
staticName = "static-client"
anotherName = "another-client"
)
staticIP := net.IP{192, 168, 10, 10}
anotherIP := DefaultRangeStart
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
s := defaultSrv(t)
t.Run("add_static", func(t *testing.T) {
err := s.AddStaticLease(&Lease{
Expiry: time.Unix(leaseExpireStatic, 0),
Hostname: staticName,
HWAddr: staticMAC,
IP: staticIP,
})
require.NoError(t, err)
t.Run("same_name", func(t *testing.T) {
err = s.AddStaticLease(&Lease{
Expiry: time.Unix(leaseExpireStatic, 0),
Hostname: staticName,
HWAddr: anotherMAC,
IP: anotherIP,
})
assert.ErrorIs(t, err, ErrDupHostname)
})
t.Run("same_mac", func(t *testing.T) {
wantErrMsg := "dhcpv4: adding static lease: removing " +
"dynamic leases for " + anotherIP.String() +
" (" + staticMAC.String() + "): static lease already exists"
err = s.AddStaticLease(&Lease{
Expiry: time.Unix(leaseExpireStatic, 0),
Hostname: anotherName,
HWAddr: staticMAC,
IP: anotherIP,
})
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
t.Run("same_ip", func(t *testing.T) {
wantErrMsg := "dhcpv4: adding static lease: removing " +
"dynamic leases for " + staticIP.String() +
" (" + anotherMAC.String() + "): static lease already exists"
err = s.AddStaticLease(&Lease{
Expiry: time.Unix(leaseExpireStatic, 0),
Hostname: anotherName,
HWAddr: anotherMAC,
IP: staticIP,
})
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
})
t.Run("add_dynamic", func(t *testing.T) {
s4, ok := s.(*v4Server)
require.True(t, ok)
discoverAnOffer := func(
t *testing.T,
name string,
ip net.IP,
mac net.HardwareAddr,
) (resp *dhcpv4.DHCPv4) {
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return s.ResetLeases(s.GetLeases(LeasesStatic))
})
req, err := dhcpv4.NewDiscovery(
mac,
dhcpv4.WithOption(dhcpv4.OptHostName(name)),
dhcpv4.WithOption(dhcpv4.OptRequestedIPAddress(ip)),
dhcpv4.WithOption(dhcpv4.OptClientIdentifier([]byte{1, 2, 3, 4, 5, 6, 8})),
dhcpv4.WithGatewayIP(DefaultGatewayIP),
)
require.NoError(t, err)
resp = &dhcpv4.DHCPv4{}
res := s4.process(req, resp)
require.Positive(t, res)
require.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
resp.ClientHWAddr = mac
return resp
}
t.Run("same_name", func(t *testing.T) {
resp := discoverAnOffer(t, staticName, anotherIP, anotherMAC)
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
dhcpv4.OptHostName(staticName),
))
require.NoError(t, err)
res := s4.process(req, resp)
require.Positive(t, res)
assert.Equal(t, aghnet.GenerateHostname(resp.YourIPAddr), resp.HostName())
})
t.Run("same_mac", func(t *testing.T) {
resp := discoverAnOffer(t, anotherName, anotherIP, staticMAC)
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
dhcpv4.OptHostName(anotherName),
))
require.NoError(t, err)
res := s4.process(req, resp)
require.Positive(t, res)
fqdnOptData := resp.Options.Get(dhcpv4.OptionFQDN)
require.Len(t, fqdnOptData, 3+len(staticName))
assert.Equal(t, []uint8(staticName), fqdnOptData[3:])
assert.Equal(t, staticIP, resp.YourIPAddr)
})
t.Run("same_ip", func(t *testing.T) {
resp := discoverAnOffer(t, anotherName, staticIP, anotherMAC)
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
dhcpv4.OptHostName(anotherName),
))
require.NoError(t, err)
res := s4.process(req, resp)
require.Positive(t, res)
assert.NotEqual(t, staticIP, resp.YourIPAddr)
})
})
}
func TestV4Server_AddRemove_static(t *testing.T) {
s := defaultSrv(t)
ls := s.GetLeases(LeasesStatic)
assert.Empty(t, ls)
require.Empty(t, ls)
// Add static lease.
l := &Lease{
Hostname: "static-1.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
testCases := []struct {
lease *Lease
name string
wantErrMsg string
}{{
lease: &Lease{
Hostname: "success.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
},
name: "success",
wantErrMsg: "",
}, {
lease: &Lease{
Hostname: "probably-router.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: DefaultGatewayIP,
},
name: "with_gateway_ip",
wantErrMsg: "dhcpv4: adding static lease: " +
"can't assign the gateway IP 192.168.10.1 to the lease",
}, {
lease: &Lease{
Hostname: "ip6.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.ParseIP("ffff::1"),
},
name: "ipv6",
wantErrMsg: `dhcpv4: adding static lease: ` +
`invalid ip "ffff::1", only ipv4 is supported`,
}, {
lease: &Lease{
Hostname: "bad-mac.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
},
name: "bad_mac",
wantErrMsg: `dhcpv4: adding static lease: bad mac address "aa:aa": ` +
`bad mac address length 2, allowed: [6 8 20]`,
}, {
lease: &Lease{
Hostname: "bad-lbl-.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
},
name: "bad_hostname",
wantErrMsg: `dhcpv4: adding static lease: validating hostname: ` +
`bad domain name "bad-lbl-.local": ` +
`bad domain name label "bad-lbl-": bad domain name label rune '-'`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := s.AddStaticLease(tc.lease)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
if tc.wantErrMsg != "" {
return
}
err = s.RemoveStaticLease(&Lease{
IP: tc.lease.IP,
HWAddr: tc.lease.HWAddr,
})
diffErrMsg := fmt.Sprintf("dhcpv4: lease for ip %s is different: %+v", tc.lease.IP, tc.lease)
testutil.AssertErrorMsg(t, diffErrMsg, err)
// Remove static lease.
err = s.RemoveStaticLease(tc.lease)
require.NoError(t, err)
})
ls = s.GetLeases(LeasesStatic)
require.Emptyf(t, ls, "after %s", tc.name)
}
err := s.AddStaticLease(l)
require.NoError(t, err)
err = s.AddStaticLease(l)
assert.Error(t, err)
ls = s.GetLeases(LeasesStatic)
require.Len(t, ls, 1)
assert.True(t, l.IP.Equal(ls[0].IP))
assert.Equal(t, l.HWAddr, ls[0].HWAddr)
assert.True(t, ls[0].IsStatic())
// Try to remove static lease.
err = s.RemoveStaticLease(&Lease{
IP: net.IP{192, 168, 10, 110},
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
})
assert.Error(t, err)
// Remove static lease.
err = s.RemoveStaticLease(l)
require.NoError(t, err)
ls = s.GetLeases(LeasesStatic)
assert.Empty(t, ls)
}
func TestV4_AddReplace(t *testing.T) {
@@ -147,6 +345,8 @@ func TestV4Server_Process_optionsPriority(t *testing.T) {
stringutil.WriteToBuilder(b, ",", ip.String())
}
conf.Options = []string{b.String()}
} else {
defer func() { s.options.Update(dhcpv4.OptDNS(defaultIP)) }()
}
ss, err := v4Create(conf)
@@ -209,6 +409,7 @@ func TestV4StaticLease_Get(t *testing.T) {
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
s.options.Update(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
l := &Lease{
Hostname: "static-1.local",
@@ -297,6 +498,7 @@ func TestV4DynamicLease_Get(t *testing.T) {
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
s.options.Update(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
var req, resp *dhcpv4.DHCPv4
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}

View File

@@ -214,7 +214,7 @@ func validateAccessSet(list *accessListJSON) (err error) {
}
merged := allowed.Merge(disallowed)
err = merged.Validate(aghalg.StringIsBefore)
err = merged.Validate()
if err != nil {
return fmt.Errorf("items in allowed and disallowed clients intersect: %w", err)
}
@@ -223,13 +223,13 @@ func validateAccessSet(list *accessListJSON) (err error) {
}
// validateStrUniq returns an informative error if clients are not unique.
func validateStrUniq(clients []string) (uc aghalg.UniqChecker, err error) {
uc = make(aghalg.UniqChecker, len(clients))
func validateStrUniq(clients []string) (uc aghalg.UniqChecker[string], err error) {
uc = make(aghalg.UniqChecker[string], len(clients))
for _, c := range clients {
uc.Add(c)
}
return uc, uc.Validate(aghalg.StringIsBefore)
return uc, uc.Validate()
}
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {

View File

@@ -65,7 +65,7 @@ func clientIDFromClientServerName(
return "", err
}
return clientID, nil
return strings.ToLower(clientID), nil
}
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
@@ -104,7 +104,7 @@ func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err e
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
return strings.ToLower(clientID), nil
}
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
@@ -112,8 +112,8 @@ type tlsConn interface {
ConnectionState() (cs tls.ConnectionState)
}
// quicSession is a narrow interface for quic.Session to simplify testing.
type quicSession interface {
// quicConnection is a narrow interface for quic.Connection to simplify testing.
type quicConnection interface {
ConnectionState() (cs quic.ConnectionState)
}
@@ -148,16 +148,16 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
cliSrvName = tc.ConnectionState().ServerName
case proxy.ProtoQUIC:
qs, ok := pctx.QUICSession.(quicSession)
conn, ok := pctx.QUICConnection.(quicConnection)
if !ok {
return "", fmt.Errorf(
"proxy ctx quic session of proto %s is %T, want quic.Session",
"proxy ctx quic conn of proto %s is %T, want quic.Connection",
proto,
pctx.QUICSession,
pctx.QUICConnection,
)
}
cliSrvName = qs.ConnectionState().TLS.ServerName
cliSrvName = conn.ConnectionState().TLS.ServerName
}
clientID, err = clientIDFromClientServerName(

View File

@@ -29,17 +29,18 @@ func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) {
return cs
}
// testQUICSession is a quicSession for tests.
type testQUICSession struct {
// Session is embedded here simply to make testQUICSession a quic.Session
// without actually implementing all methods.
quic.Session
// testQUICConnection is a quicConnection for tests.
type testQUICConnection struct {
// Connection is embedded here simply to make testQUICConnection a
// quic.Connection without actually implementing all methods.
quic.Connection
serverName string
}
// ConnectionState implements the quicSession interface for testQUICSession.
func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
// ConnectionState implements the quicConnection interface for
// testQUICConnection.
func (c testQUICConnection) ConnectionState() (cs quic.ConnectionState) {
cs.TLS.ServerName = c.serverName
return cs
@@ -143,6 +144,22 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
wantErrMsg: `clientid check: client server name "cli.myexample.com" ` +
`doesn't match host server name "example.com"`,
strictSNI: true,
}, {
name: "tls_case",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "InSeNsItIvE.example.com",
wantClientID: "insensitive",
wantErrMsg: ``,
strictSNI: true,
}, {
name: "quic_case",
proto: proxy.ProtoQUIC,
hostSrvName: "example.com",
cliSrvName: "InSeNsItIvE.example.com",
wantClientID: "insensitive",
wantErrMsg: ``,
strictSNI: true,
}}
for _, tc := range testCases {
@@ -163,17 +180,17 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
}
}
var qs quic.Session
var qconn quic.Connection
if tc.proto == proxy.ProtoQUIC {
qs = testQUICSession{
qconn = testQUICConnection{
serverName: tc.cliSrvName,
}
}
pctx := &proxy.DNSContext{
Proto: tc.proto,
Conn: conn,
QUICSession: qs,
Proto: tc.proto,
Conn: conn,
QUICConnection: qconn,
}
clientID, err := srv.clientIDFromDNSContext(pctx)
@@ -210,6 +227,11 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
path: "/dns-query/cli/",
wantClientID: "cli",
wantErrMsg: "",
}, {
name: "clientid_case",
path: "/dns-query/InSeNsItIvE",
wantClientID: "insensitive",
wantErrMsg: ``,
}, {
name: "bad_url",
path: "/foo",

View File

@@ -5,12 +5,12 @@ import (
"crypto/x509"
"fmt"
"net"
"net/http"
"os"
"sort"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
@@ -122,6 +122,7 @@ type FilteringConfig struct {
EnableDNSSEC bool `yaml:"enable_dnssec"` // Set AD flag in outcoming DNS request
EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option
MaxGoroutines uint32 `yaml:"max_goroutines"` // Max. number of parallel goroutines for processing incoming requests
HandleDDR bool `yaml:"handle_ddr"` // Handle DDR requests
// IpsetList is the ipset configuration that allows AdGuard Home to add
// IP addresses of the specified domain names to an ipset list. Syntax:
@@ -133,8 +134,9 @@ type FilteringConfig struct {
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
type TLSConfig struct {
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
HTTPSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
// Reject connection if the client uses server name (in SNI) that doesn't match the certificate
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"`
@@ -151,7 +153,7 @@ type TLSConfig struct {
PrivateKeyData []byte `yaml:"-" json:"-"`
// ServerName is the hostname of the server. Currently, it is only being
// used for ClientID checking.
// used for ClientID checking and Discovery of Designated Resolvers (DDR).
ServerName string `yaml:"-" json:"-"`
cert tls.Certificate
@@ -191,7 +193,7 @@ type ServerConfig struct {
ConfigModified func()
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
HTTPRegister aghhttp.RegisterFunc
// ResolveClients signals if the RDNS should resolve clients' addresses.
ResolveClients bool
@@ -276,6 +278,11 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
return proxyConfig, nil
}
const (
defaultSafeBrowsingBlockHost = "standard-block.dns.adguard.com"
defaultParentalBlockHost = "family-block.dns.adguard.com"
)
// initDefaultSettings initializes default settings if nothing
// is configured
func (s *Server) initDefaultSettings() {
@@ -287,12 +294,12 @@ func (s *Server) initDefaultSettings() {
s.conf.BootstrapDNS = defaultBootstrap
}
if len(s.conf.ParentalBlockHost) == 0 {
s.conf.ParentalBlockHost = parentalBlockHost
if s.conf.ParentalBlockHost == "" {
s.conf.ParentalBlockHost = defaultParentalBlockHost
}
if len(s.conf.SafeBrowsingBlockHost) == 0 {
s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost
if s.conf.SafeBrowsingBlockHost == "" {
s.conf.SafeBrowsingBlockHost = defaultSafeBrowsingBlockHost
}
if s.conf.UDPListenAddrs == nil {

View File

@@ -76,6 +76,10 @@ const (
resultCodeError
)
// ddrHostFQDN is the FQDN used in Discovery of Designated Resolvers (DDR) requests.
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
const ddrHostFQDN = "_dns.resolver.arpa."
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
ctx := &dnsContext{
@@ -94,10 +98,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
mods := []modProcessFunc{
s.processRecursion,
s.processInitial,
s.processDDRQuery,
s.processDetermineLocal,
s.processInternalHosts,
s.processDHCPHosts,
s.processRestrictLocal,
s.processInternalIPAddrs,
s.processDHCPAddrs,
s.processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream,
@@ -135,7 +140,6 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish
}
return resultCodeSuccess
@@ -226,12 +230,10 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
)
}
lowhost := strings.ToLower(l.Hostname)
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
ip := netutil.CloneIP(l.IP)
ipToHost.Set(l.IP, lowhost)
ip := make(net.IP, 4)
copy(ip, l.IP.To4())
ipToHost.Set(ip, lowhost)
hostToIP[lowhost] = ip
}
@@ -242,6 +244,98 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
s.setTableIPToHost(ipToHost)
}
// processDDRQuery responds to SVCB query for a special use domain name
// _dns.resolver.arpa. The result contains different types of encryption
// supported by current user configuration.
//
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
func (s *Server) processDDRQuery(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
question := d.Req.Question[0]
if !s.conf.HandleDDR {
return resultCodeSuccess
}
if question.Name == ddrHostFQDN {
if s.dnsProxy.TLSListenAddr == nil && s.conf.HTTPSListenAddrs == nil &&
s.dnsProxy.QUICListenAddr == nil || question.Qtype != dns.TypeSVCB {
d.Res = s.makeResponse(d.Req)
return resultCodeFinish
}
d.Res = s.makeDDRResponse(d.Req)
return resultCodeFinish
}
return resultCodeSuccess
}
// makeDDRResponse creates DDR answer according to server configuration. The
// contructed SVCB resource records have the priority of 1 for each entry,
// similar to examples provided by https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
//
// TODO(a.meshkov): Consider setting the priority values based on the protocol.
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
resp = s.makeResponse(req)
// TODO(e.burkov): Think about storing the FQDN version of the server's
// name somewhere.
domainName := dns.Fqdn(s.conf.ServerName)
for _, addr := range s.conf.HTTPSListenAddrs {
values := []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"h2"}},
&dns.SVCBPort{Port: uint16(addr.Port)},
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
}
ans := &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: 1,
Target: domainName,
Value: values,
}
resp.Answer = append(resp.Answer, ans)
}
for _, addr := range s.dnsProxy.TLSListenAddr {
values := []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"dot"}},
&dns.SVCBPort{Port: uint16(addr.Port)},
}
ans := &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: 1,
Target: domainName,
Value: values,
}
resp.Answer = append(resp.Answer, ans)
}
for _, addr := range s.dnsProxy.QUICListenAddr {
values := []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"doq"}},
&dns.SVCBPort{Port: uint16(addr.Port)},
}
ans := &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: 1,
Target: domainName,
Value: values,
}
resp.Answer = append(resp.Answer, ans)
}
return resp
}
// processDetermineLocal determines if the client's IP address is from
// locally-served network and saves the result into the context.
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
@@ -252,7 +346,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
return rc
}
dctx.isLocalClient = s.subnetDetector.IsLocallyServedNetwork(ip)
dctx.isLocalClient = s.privateNets.Contains(ip)
return rc
}
@@ -280,11 +374,11 @@ func (s *Server) hostToIP(host string) (ip net.IP, ok bool) {
return ip, true
}
// processInternalHosts respond to A requests if the target hostname is known to
// processDHCPHosts respond to A requests if the target hostname is known to
// the server.
//
// TODO(a.garipov): Adapt to AAAA as well.
func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
if !s.dhcpServer.Enabled() {
return resultCodeSuccess
}
@@ -299,11 +393,10 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
reqHost := strings.ToLower(q.Name)
reqHost := strings.ToLower(q.Name[:len(q.Name)-1])
// TODO(a.garipov): Move everything related to DHCP local domain to the DHCP
// server.
host := strings.TrimSuffix(reqHost, s.localDomainSuffix)
if host == reqHost {
if !strings.HasSuffix(reqHost, s.localDomainSuffix) {
return resultCodeSuccess
}
@@ -316,7 +409,7 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeFinish
}
ip, ok := s.hostToIP(host)
ip, ok := s.hostToIP(reqHost)
if !ok {
// TODO(e.burkov): Inspect special cases when user want to apply some
// rules handled by other processors to the hosts with TLD.
@@ -373,8 +466,8 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
// Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally-served or at least
// don't need to be inaccessible externally.
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
// don't need to be accessible externally.
if !s.privateNets.Contains(ip) {
log.Debug("dns: addr %s is not from locally-served network", ip)
return resultCodeSuccess
@@ -413,7 +506,7 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
return "", false
}
var v interface{}
var v any
v, ok = s.tableIPToHost.Get(ip)
if !ok {
return "", false
@@ -430,7 +523,7 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
// Respond to PTR requests if the target IP is leased by our DHCP server and the
// requestor is inside the local network.
func (s *Server) processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
func (s *Server) processDHCPAddrs(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
if d.Res != nil {
return resultCodeSuccess
@@ -481,7 +574,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
if !s.privateNets.Contains(ip) {
return resultCodeSuccess
}

View File

@@ -4,35 +4,212 @@ import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServer_ProcessDetermineLocal(t *testing.T) {
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
s := &Server{
subnetDetector: snd,
const (
ddrTestDomainName = "dns.example.net"
ddrTestFQDN = ddrTestDomainName + "."
)
func TestServer_ProcessDDRQuery(t *testing.T) {
dohSVCB := &dns.SVCB{
Priority: 1,
Target: ddrTestFQDN,
Value: []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"h2"}},
&dns.SVCBPort{Port: 8044},
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
},
}
dotSVCB := &dns.SVCB{
Priority: 1,
Target: ddrTestFQDN,
Value: []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"dot"}},
&dns.SVCBPort{Port: 8043},
},
}
doqSVCB := &dns.SVCB{
Priority: 1,
Target: ddrTestFQDN,
Value: []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"doq"}},
&dns.SVCBPort{Port: 8042},
},
}
testCases := []struct {
name string
host string
want []*dns.SVCB
wantRes resultCode
portDoH int
portDoT int
portDoQ int
qtype uint16
ddrEnabled bool
}{{
name: "pass_host",
wantRes: resultCodeSuccess,
host: "example.net.",
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoH: 8043,
}, {
name: "pass_qtype",
wantRes: resultCodeFinish,
host: ddrHostFQDN,
qtype: dns.TypeA,
ddrEnabled: true,
portDoH: 8043,
}, {
name: "pass_disabled_tls",
wantRes: resultCodeFinish,
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
}, {
name: "pass_disabled_ddr",
wantRes: resultCodeSuccess,
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: false,
portDoH: 8043,
}, {
name: "dot",
wantRes: resultCodeFinish,
want: []*dns.SVCB{dotSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoT: 8043,
}, {
name: "doh",
wantRes: resultCodeFinish,
want: []*dns.SVCB{dohSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoH: 8044,
}, {
name: "doq",
wantRes: resultCodeFinish,
want: []*dns.SVCB{doqSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoQ: 8042,
}, {
name: "dot_doh",
wantRes: resultCodeFinish,
want: []*dns.SVCB{dotSVCB, dohSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoT: 8043,
portDoH: 8044,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled)
req := createTestMessageWithType(tc.host, tc.qtype)
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req,
},
}
res := s.processDDRQuery(dctx)
require.Equal(t, tc.wantRes, res)
if tc.wantRes != resultCodeFinish {
return
}
msg := dctx.proxyCtx.Res
require.NotNil(t, msg)
for _, v := range tc.want {
v.Hdr = s.hdr(req, dns.TypeSVCB)
}
assert.ElementsMatch(t, tc.want, msg.Answer)
})
}
}
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
t.Helper()
proxyConf := proxy.Config{}
if portDoT > 0 {
proxyConf.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
}
if portDoQ > 0 {
proxyConf.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
}
s = &Server{
dnsProxy: &proxy.Proxy{
Config: proxyConf,
},
conf: ServerConfig{
FilteringConfig: FilteringConfig{
HandleDDR: ddrEnabled,
},
TLSConfig: TLSConfig{
ServerName: ddrTestDomainName,
},
},
}
if portDoH > 0 {
s.conf.TLSConfig.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
}
return s
}
func TestServer_ProcessDetermineLocal(t *testing.T) {
s := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliIP net.IP
want bool
}{{
want: assert.True,
name: "local",
cliIP: net.IP{192, 168, 0, 1},
want: true,
}, {
want: assert.False,
name: "external",
cliIP: net.IP{250, 249, 0, 1},
want: false,
}, {
want: assert.False,
name: "invalid",
cliIP: net.IP{1, 2, 3, 4, 5},
}, {
want: assert.False,
name: "nil",
cliIP: nil,
}}
for _, tc := range testCases {
@@ -47,12 +224,12 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
}
s.processDetermineLocal(dctx)
assert.Equal(t, tc.want, dctx.isLocalClient)
tc.want(t, dctx.isLocalClient)
})
}
}
func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
knownIP := net.IP{1, 2, 3, 4}
testCases := []struct {
@@ -93,7 +270,7 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
dhcpServer: &testDHCP{},
localDomainSuffix: defaultLocalDomainSuffix,
tableHostToIP: hostToIPTable{
"example": knownIP,
"example." + defaultLocalDomainSuffix: knownIP,
},
}
@@ -115,7 +292,7 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
isLocalClient: tc.isLocalCli,
}
res := s.processInternalHosts(dctx)
res := s.processDHCPHosts(dctx)
require.Equal(t, tc.wantRes, res)
pctx := dctx.proxyCtx
if tc.wantRes == resultCodeFinish {
@@ -141,10 +318,10 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
}
}
func TestServer_ProcessInternalHosts(t *testing.T) {
func TestServer_ProcessDHCPHosts(t *testing.T) {
const (
examplecom = "example.com"
examplelan = "example.lan"
examplelan = "example." + defaultLocalDomainSuffix
)
knownIP := net.IP{1, 2, 3, 4}
@@ -193,41 +370,41 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
}, {
name: "success_custom_suffix",
host: "example.custom",
suffix: ".custom.",
suffix: "custom",
wantIP: knownIP,
wantRes: resultCodeSuccess,
qtyp: dns.TypeA,
}}
for _, tc := range testCases {
s := &Server{
dhcpServer: &testDHCP{},
localDomainSuffix: tc.suffix,
tableHostToIP: hostToIPTable{
"example." + tc.suffix: knownIP,
},
}
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: dns.Fqdn(tc.host),
Qtype: tc.qtyp,
Qclass: dns.ClassINET,
}},
}
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req,
},
isLocalClient: true,
}
t.Run(tc.name, func(t *testing.T) {
s := &Server{
dhcpServer: &testDHCP{},
localDomainSuffix: tc.suffix,
tableHostToIP: hostToIPTable{
"example": knownIP,
},
}
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: dns.Fqdn(tc.host),
Qtype: tc.qtyp,
Qclass: dns.ClassINET,
}},
}
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req,
},
isLocalClient: true,
}
res := s.processInternalHosts(dctx)
res := s.processDHCPHosts(dctx)
pctx := dctx.proxyCtx
assert.Equal(t, tc.wantRes, res)
if tc.wantRes == resultCodeFinish {

View File

@@ -33,11 +33,6 @@ const DefaultTimeout = 10 * time.Second
// requests between the BeforeRequestHandler stage and the actual processing.
const defaultClientIDCacheCount = 1024
const (
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
parentalBlockHost = "family-block.dns.adguard.com"
)
var defaultDNS = []string{
"https://dns10.quad9.net/dns-query",
}
@@ -66,7 +61,7 @@ type Server struct {
dnsFilter *filtering.DNSFilter // DNS filter instance
dhcpServer dhcpd.ServerInterface // DHCP server instance (optional)
queryLog querylog.QueryLog // Query log instance
stats stats.Stats
stats stats.Interface
access *accessCtx
// localDomainSuffix is the suffix used to detect internal hosts. It
@@ -74,7 +69,7 @@ type Server struct {
localDomainSuffix string
ipset ipsetCtx
subnetDetector *aghnet.SubnetDetector
privateNets netutil.SubnetSet
localResolvers *proxy.Proxy
sysResolvers aghnet.SystemResolvers
recDetector *recursionDetector
@@ -107,28 +102,17 @@ type Server struct {
// when no suffix is provided.
//
// See the documentation for Server.localDomainSuffix.
const defaultLocalDomainSuffix = ".lan."
const defaultLocalDomainSuffix = "lan"
// DNSCreateParams are parameters to create a new server.
type DNSCreateParams struct {
DNSFilter *filtering.DNSFilter
Stats stats.Stats
QueryLog querylog.QueryLog
DHCPServer dhcpd.ServerInterface
SubnetDetector *aghnet.SubnetDetector
Anonymizer *aghnet.IPMut
LocalDomain string
}
// domainNameToSuffix converts a domain name into a local domain suffix.
func domainNameToSuffix(tld string) (suffix string) {
l := len(tld) + 2
b := make([]byte, l)
b[0] = '.'
copy(b[1:], tld)
b[l-1] = '.'
return string(b)
DNSFilter *filtering.DNSFilter
Stats stats.Interface
QueryLog querylog.QueryLog
DHCPServer dhcpd.ServerInterface
PrivateNets netutil.SubnetSet
Anonymizer *aghnet.IPMut
LocalDomain string
}
const (
@@ -151,7 +135,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
return nil, fmt.Errorf("local domain: %w", err)
}
localDomainSuffix = domainNameToSuffix(p.LocalDomain)
localDomainSuffix = p.LocalDomain
}
if p.Anonymizer == nil {
@@ -161,7 +145,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
dnsFilter: p.DNSFilter,
stats: p.Stats,
queryLog: p.QueryLog,
subnetDetector: p.SubnetDetector,
privateNets: p.PrivateNets,
localDomainSuffix: localDomainSuffix,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{
@@ -173,7 +157,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
// TODO(e.burkov): Enable the refresher after the actual implementation
// passes the public testing.
s.sysResolvers, err = aghnet.NewSystemResolvers(0, nil)
s.sysResolvers, err = aghnet.NewSystemResolvers(nil)
if err != nil {
return nil, fmt.Errorf("initializing system resolvers: %w", err)
}
@@ -314,14 +298,16 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
StartTime: time.Now(),
}
resolver := s.internalProxy
if s.subnetDetector.IsLocallyServedNetwork(ip) {
var resolver *proxy.Proxy
if s.privateNets.Contains(ip) {
if !s.conf.UsePrivateRDNS {
return "", nil
}
resolver = s.localResolvers
s.recDetector.add(*req)
} else {
resolver = s.internalProxy
}
if err = resolver.Resolve(ctx); err != nil {

View File

@@ -17,13 +17,14 @@ import (
"testing/fstest"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns"
@@ -69,14 +70,11 @@ func createTestServer(
f := filtering.New(filterConf, filters)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
var err error
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -770,16 +768,11 @@ func TestBlockedCustomIP(t *testing.T) {
Data: []byte(rules),
}}
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
f := filtering.New(&filtering.Config{}, filters)
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -860,10 +853,7 @@ func TestBlockedByHosts(t *testing.T) {
func TestBlockedBySafeBrowsing(t *testing.T) {
const hostname = "wmconvirus.narod.ru"
sbUps := &aghtest.TestBlockUpstream{
Hostname: hostname,
Block: true,
}
sbUps := aghtest.NewBlockUpstream(hostname, true)
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
filterConf := &filtering.Config{
@@ -913,15 +903,10 @@ func TestRewrite(t *testing.T) {
f := filtering.New(c, nil)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -1000,7 +985,7 @@ func TestRewrite(t *testing.T) {
}
}
func publicKey(priv interface{}) interface{} {
func publicKey(priv any) any {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
@@ -1028,36 +1013,33 @@ func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
func TestPTRResponseFromDHCPLeases(t *testing.T) {
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
const localDomain = "lan"
var s *Server
s, err = NewServer(DNSCreateParams{
DNSFilter: filtering.New(&filtering.Config{}, nil),
DHCPServer: &testDHCP{},
SubnetDetector: snd,
s, err := NewServer(DNSCreateParams{
DNSFilter: filtering.New(&filtering.Config{}, nil),
DHCPServer: &testDHCP{},
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
LocalDomain: localDomain,
})
require.NoError(t, err)
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
s.conf.ProtectionEnabled = true
err = s.Prepare(nil)
require.NoError(t, err)
err = s.Start()
require.NoError(t, err)
t.Cleanup(s.Close)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR)
resp, err := dns.Exchange(req, addr.String())
require.NoError(t, err)
require.NoErrorf(t, err, "%s", addr)
require.Len(t, resp.Answer, 1)
@@ -1066,7 +1048,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok)
assert.Equal(t, "myhost.", ptr.Ptr)
assert.Equal(t, dns.Fqdn("myhost."+localDomain), ptr.Ptr)
}
func TestPTRResponseFromHosts(t *testing.T) {
@@ -1105,16 +1087,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
}, nil)
flt.SetEnabled(true)
var snd *aghnet.SubnetDetector
snd, err = aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: flt,
SubnetDetector: snd,
DHCPServer: &testDHCP{},
DNSFilter: flt,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -1197,25 +1174,48 @@ func TestNewServer(t *testing.T) {
}
func TestServer_Exchange(t *testing.T) {
extUpstream := &aghtest.Upstream{
Reverse: map[string][]string{
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
const (
onesHost = "one.one.one.one"
localDomainHost = "local.domain"
)
var (
onesIP = net.IP{1, 1, 1, 1}
localIP = net.IP{192, 168, 1, 1}
)
revExtIPv4, err := netutil.IPToReversedAddr(onesIP)
require.NoError(t, err)
extUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "external.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = aghalg.Coalesce(
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revExtIPv4, onesHost),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
return resp, nil
},
}
locUpstream := &aghtest.Upstream{
Reverse: map[string][]string{
"1.1.168.192.in-addr.arpa.": {"local.domain"},
"2.1.168.192.in-addr.arpa.": {},
revLocIPv4, err := netutil.IPToReversedAddr(localIP)
require.NoError(t, err)
locUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "local.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = aghalg.Coalesce(
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revLocIPv4, localDomainHost),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
return resp, nil
},
}
upstreamErr := errors.Error("upstream error")
errUpstream := &aghtest.TestErrUpstream{
Err: upstreamErr,
}
nonPtrUpstream := &aghtest.TestBlockUpstream{
Hostname: "some-host",
Block: true,
}
errUpstream := aghtest.NewErrorUpstream()
nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
srv := NewCustomServer(&proxy.Proxy{
Config: proxy.Config{
@@ -1227,11 +1227,8 @@ func TestServer_Exchange(t *testing.T) {
srv.conf.ResolveClients = true
srv.conf.UsePrivateRDNS = true
var err error
srv.subnetDetector, err = aghnet.NewSubnetDetector()
require.NoError(t, err)
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
localIP := net.IP{192, 168, 1, 1}
testCases := []struct {
name string
want string
@@ -1240,20 +1237,20 @@ func TestServer_Exchange(t *testing.T) {
req net.IP
}{{
name: "external_good",
want: "one.one.one.one",
want: onesHost,
wantErr: nil,
locUpstream: nil,
req: net.IP{1, 1, 1, 1},
req: onesIP,
}, {
name: "local_good",
want: "local.domain",
want: localDomainHost,
wantErr: nil,
locUpstream: locUpstream,
req: localIP,
}, {
name: "upstream_error",
want: "",
wantErr: upstreamErr,
wantErr: aghtest.ErrUpstream,
locUpstream: errUpstream,
req: localIP,
}, {

View File

@@ -22,7 +22,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
Preference: 32,
}
svcbVal := &rules.DNSSVCB{
Params: map[string]string{"alpn": "h3"},
Params: map[string]string{"alpn": "h3", "dohpath": "/dns-query"},
Target: dns.Fqdn(domain),
Priority: 32,
}
@@ -164,10 +164,20 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
require.Len(t, d.Res.Answer, 1)
ans, ok := d.Res.Answer[0].(*dns.SVCB)
require.True(t, ok)
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
assert.Equal(t, svcbVal.Params["alpn"], ans.Value[0].String())
require.True(t, ok)
require.Len(t, ans.Value, 2)
assert.ElementsMatch(
t,
[]dns.SVCBKey{dns.SVCB_ALPN, dns.SVCB_DOHPATH},
[]dns.SVCBKey{ans.Value[0].Key(), ans.Value[1].Key()},
)
assert.ElementsMatch(
t,
[]string{svcbVal.Params["alpn"], svcbVal.Params["dohpath"]},
[]string{ans.Value[0].String(), ans.Value[1].String()},
)
assert.Equal(t, svcbVal.Target, ans.Target)
assert.Equal(t, svcbVal.Priority, ans.Priority)
})
@@ -186,8 +196,18 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
ans, ok := d.Res.Answer[0].(*dns.HTTPS)
require.True(t, ok)
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
assert.Equal(t, svcbVal.Params["alpn"], ans.Value[0].String())
require.Len(t, ans.Value, 2)
assert.ElementsMatch(
t,
[]dns.SVCBKey{dns.SVCB_ALPN, dns.SVCB_DOHPATH},
[]dns.SVCBKey{ans.Value[0].Key(), ans.Value[1].Key()},
)
assert.ElementsMatch(
t,
[]string{svcbVal.Params["alpn"], svcbVal.Params["dohpath"]},
[]string{ans.Value[0].String(), ans.Value[1].String()},
)
assert.Equal(t, svcbVal.Target, ans.Target)
assert.Equal(t, svcbVal.Priority, ans.Priority)
})

View File

@@ -4,7 +4,6 @@ import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
@@ -39,14 +38,10 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
f := filtering.New(&filtering.Config{}, filters)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)

View File

@@ -5,12 +5,10 @@ import (
"fmt"
"net"
"net/http"
"sort"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
@@ -18,6 +16,8 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
type dnsConfig struct {
@@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (err error) {
}
// validate returns an error if any field of req is invalid.
func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) {
if req.Upstreams != nil {
err = ValidateUpstreams(*req.Upstreams)
if err != nil {
@@ -176,7 +176,7 @@ func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
}
if req.LocalPTRUpstreams != nil {
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, snd)
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
if err != nil {
return fmt.Errorf("validating private upstream servers: %w", err)
}
@@ -224,7 +224,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
return
}
err = req.validate(s.subnetDetector)
err = req.validate(s.privateNets)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -350,17 +350,6 @@ func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}
// LocalNetChecker is used to check if the IP address belongs to a local
// network.
type LocalNetChecker interface {
// IsLocallyServedNetwork returns true if ip is contained in any of address
// registries defined by RFC 6303.
IsLocallyServedNetwork(ip net.IP) (ok bool)
}
// type check
var _ LocalNetChecker = (*aghnet.SubnetDetector)(nil)
// newUpstreamConfig validates upstreams and returns an appropriate upstream
// configuration or nil if it can't be built.
//
@@ -375,6 +364,21 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
return nil, nil
}
for _, u := range upstreams {
var ups string
var domains []string
ups, domains, err = separateUpstream(u)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
_, err = validateUpstream(ups, domains)
if err != nil {
return nil, fmt.Errorf("validating upstream %q: %w", u, err)
}
}
conf, err = proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{Bootstrap: []string{}, Timeout: DefaultTimeout},
@@ -385,13 +389,6 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
return nil, errors.Error("no default upstreams specified")
}
for _, u := range upstreams {
_, err = validateUpstream(u)
if err != nil {
return nil, err
}
}
return conf, nil
}
@@ -405,25 +402,11 @@ func ValidateUpstreams(upstreams []string) (err error) {
return err
}
// stringKeysSorted returns the sorted slice of string keys of m.
//
// TODO(e.burkov): Use generics in Go 1.18. Move into golibs.
func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
sorted = make([]string, 0, len(m))
for s := range m {
sorted = append(sorted, s)
}
sort.Strings(sorted)
return sorted
}
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. lnc must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err error) {
// a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := newUpstreamConfig(upstreams)
if err != nil {
return err
@@ -433,9 +416,11 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
return nil
}
var errs []error
keys := maps.Keys(conf.DomainReservedUpstreams)
slices.Sort(keys)
for _, domain := range stringKeysSorted(conf.DomainReservedUpstreams) {
var errs []error
for _, domain := range keys {
var subnet *net.IPNet
subnet, err = netutil.SubnetFromReversedAddr(domain)
if err != nil {
@@ -444,7 +429,7 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
continue
}
if !lnc.IsLocallyServedNetwork(subnet.IP) {
if !privateNets.Contains(subnet.IP) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
@@ -461,16 +446,14 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
var protocols = []string{"udp://", "tcp://", "tls://", "https://", "sdns://", "quic://"}
func validateUpstream(u string) (useDefault bool, err error) {
// Check if the user tries to specify upstream for domain.
var isDomainSpec bool
u, isDomainSpec, err = separateUpstream(u)
if err != nil {
return !isDomainSpec, err
}
// validateUpstream returns an error if u alongside with domains is not a valid
// upstream configuration. useDefault is true if the upstream is
// domain-specific and is configured to point at the default upstream server
// which is validated separately. The upstream is considered domain-specific
// only if domains is at least not nil.
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
// The special server address '#' means that default server must be used.
if useDefault = !isDomainSpec; u == "#" && isDomainSpec {
if useDefault = u == "#" && domains != nil; useDefault {
return useDefault, nil
}
@@ -497,12 +480,14 @@ func validateUpstream(u string) (useDefault bool, err error) {
return useDefault, nil
}
// separateUpstream returns the upstream without the specified domains.
// isDomainSpec is true when the upstream is domains-specific.
func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, err error) {
// separateUpstream returns the upstream and the specified domains. domains is
// nil when the upstream is not domains-specific. Otherwise it may also be
// empty.
func separateUpstream(upstreamStr string) (ups string, domains []string, err error) {
if !strings.HasPrefix(upstreamStr, "[/") {
return upstreamStr, false, nil
return upstreamStr, nil, nil
}
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
parts := strings.Split(upstreamStr[2:], "/]")
@@ -510,39 +495,46 @@ func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, e
case 2:
// Go on.
case 1:
return "", false, errors.Error("missing separator")
return "", nil, errors.Error("missing separator")
default:
return "", true, errors.Error("duplicated separator")
return "", []string{}, errors.Error("duplicated separator")
}
var domains string
domains, upstream = parts[0], parts[1]
for i, host := range strings.Split(domains, "/") {
for i, host := range strings.Split(parts[0], "/") {
if host == "" {
continue
}
err = netutil.ValidateDomainName(host)
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
if err != nil {
return "", true, fmt.Errorf("domain at index %d: %w", i, err)
return "", domains, fmt.Errorf("domain at index %d: %w", i, err)
}
domains = append(domains, host)
}
return upstream, true, nil
return parts[1], domains, nil
}
// excFunc is a signature of function to check if upstream exchanges correctly.
type excFunc func(u upstream.Upstream) (err error)
// healthCheckFunc is a signature of function to check if upstream exchanges
// properly.
type healthCheckFunc func(u upstream.Upstream) (err error)
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
// testTLD is the special-use fully-qualified domain name for testing the
// DNS server reachability.
//
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
const testTLD = "test."
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: "google-public-dns-a.google.com.",
Name: testTLD,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
@@ -552,12 +544,8 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
reply, err = u.Exchange(req)
if err != nil {
return fmt.Errorf("couldn't communicate with upstream: %w", err)
}
if len(reply.Answer) != 1 {
return fmt.Errorf("wrong response")
} else if a, ok := reply.Answer[0].(*dns.A); !ok || !a.A.Equal(net.IP{8, 8, 8, 8}) {
return fmt.Errorf("wrong response")
} else if len(reply.Answer) != 0 {
return errors.Error("wrong response")
}
return nil
@@ -565,14 +553,22 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
// checkPrivateUpstreamExc checks if the upstream for resolving private
// addresses exchanges correctly.
//
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
// address resolution.
//
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
const inAddrArpaTLD = "in-addr.arpa."
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: "1.0.0.127.in-addr.arpa.",
Name: inAddrArpaTLD,
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
}},
@@ -585,46 +581,66 @@ func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
return nil
}
func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFunc) (err error) {
if IsCommentOrEmpty(input) {
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
// the tested upstream domain-specific and therefore consider its errors
// non-critical.
//
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
// warnings (non-critical errors) is desired.
type domainSpecificTestError struct {
error
}
// checkDNS checks the upstream server defined by upstreamConfigStr using
// healthCheck for actually exchange messages. It uses bootstrap to resolve the
// upstream's address.
func checkDNS(
upstreamConfigStr string,
bootstrap []string,
timeout time.Duration,
healthCheck healthCheckFunc,
) (err error) {
if IsCommentOrEmpty(upstreamConfigStr) {
return nil
}
// Separate upstream from domains list.
var useDefault bool
if useDefault, err = validateUpstream(input); err != nil {
upstreamAddr, domains, err := separateUpstream(upstreamConfigStr)
if err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
}
// No need to check this DNS server.
if !useDefault {
useDefault, err := validateUpstream(upstreamAddr, domains)
if err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
} else if useDefault {
return nil
}
if input, _, err = separateUpstream(input); err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
}
if len(bootstrap) == 0 {
bootstrap = defaultBootstrap
}
log.Debug("checking if upstream %s works", input)
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
var u upstream.Upstream
u, err = upstream.AddressToUpstream(input, &upstream.Options{
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
Bootstrap: bootstrap,
Timeout: timeout,
})
if err != nil {
return fmt.Errorf("failed to choose upstream for %q: %w", input, err)
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
}
if err = ef(u); err != nil {
return fmt.Errorf("upstream %q fails to exchange: %w", input, err)
if err = healthCheck(u); err != nil {
err = fmt.Errorf("upstream %q fails to exchange: %w", upstreamAddr, err)
if domains != nil {
return domainSpecificTestError{error: err}
}
return err
}
log.Debug("upstream %s is ok", input)
log.Debug("dnsforward: upstream %q is ok", upstreamAddr)
return nil
}
@@ -647,6 +663,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
if err != nil {
log.Info("%v", err)
result[host] = err.Error()
if _, ok := err.(domainSpecificTestError); ok {
result[host] = fmt.Sprintf("WARNING: %s", result[host])
}
continue
}
@@ -662,6 +681,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
// above, we rewriting the error for it. These cases should be
// handled properly instead.
result[host] = err.Error()
if _, ok := err.(domainSpecificTestError); ok {
result[host] = fmt.Sprintf("WARNING: %s", result[host])
}
continue
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -33,7 +34,7 @@ func (fsr *fakeSystemResolvers) Get() (rs []string) {
return nil
}
func loadTestData(t *testing.T, casesFileName string, cases interface{}) {
func loadTestData(t *testing.T, casesFileName string, cases any) {
t.Helper()
var f *os.File
@@ -184,7 +185,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
wantSet: "",
}, {
name: "upstream_dns_bad",
wantSet: `validating upstream servers: bad ipport address "!!!": ` +
wantSet: `validating upstream servers: ` +
`validating upstream "!!!": bad ipport address "!!!": ` +
`address !!!: missing port in address`,
}, {
name: "bootstraps_bad",
@@ -255,112 +257,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
}
}
func TestValidateUpstream(t *testing.T) {
testCases := []struct {
wantDef assert.BoolAssertionFunc
name string
upstream string
wantErr string
}{{
wantDef: assert.True,
name: "invalid",
upstream: "1.2.3.4.5",
wantErr: `bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "123.3.7m",
wantErr: `bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "htttps://google.com/dns-query",
wantErr: `wrong protocol`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "[/host.com]tls://dns.adguard.com",
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
}, {
wantDef: assert.True,
name: "invalid",
upstream: "[host.ru]#",
wantErr: `bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "1.1.1.1",
wantErr: ``,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "tls://1.1.1.1",
wantErr: ``,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "https://dns.adguard.com/dns-query",
wantErr: ``,
}, {
wantDef: assert.True,
name: "valid_default",
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
wantErr: ``,
}, {
wantDef: assert.True,
name: "default_udp_host",
upstream: "udp://dns.google",
}, {
wantDef: assert.True,
name: "default_udp_ip",
upstream: "udp://8.8.8.8",
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/host.com/]1.1.1.1",
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[//]tls://1.1.1.1",
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/www.host.com/]#",
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/host.com/google.com/]8.8.8.8",
wantErr: ``,
}, {
wantDef: assert.False,
name: "valid",
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
wantErr: ``,
}, {
wantDef: assert.False,
name: "idna",
upstream: "[/пример.рф/]8.8.8.8",
wantErr: ``,
}, {
wantDef: assert.False,
name: "bad_domain",
upstream: "[/!/]8.8.8.8",
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
defaultUpstream, err := validateUpstream(tc.upstream)
testutil.AssertErrorMsg(t, tc.wantErr, err)
tc.wantDef(t, defaultUpstream)
})
}
}
func TestValidateUpstreams(t *testing.T) {
testCases := []struct {
name string
@@ -375,7 +271,7 @@ func TestValidateUpstreams(t *testing.T) {
wantErr: ``,
set: []string{"# comment"},
}, {
name: "valid_no_default",
name: "no_default",
wantErr: `no default upstreams specified`,
set: []string{
"[/host.com/]1.1.1.1",
@@ -385,7 +281,7 @@ func TestValidateUpstreams(t *testing.T) {
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
},
}, {
name: "valid_with_default",
name: "with_default",
wantErr: ``,
set: []string{
"[/host.com/]1.1.1.1",
@@ -397,8 +293,46 @@ func TestValidateUpstreams(t *testing.T) {
},
}, {
name: "invalid",
wantErr: `cannot prepare the upstream dhcp://fake.dns ([]): unsupported url scheme: dhcp`,
wantErr: `validating upstream "dhcp://fake.dns": wrong protocol`,
set: []string{"dhcp://fake.dns"},
}, {
name: "invalid",
wantErr: `validating upstream "1.2.3.4.5": bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
set: []string{"1.2.3.4.5"},
}, {
name: "invalid",
wantErr: `validating upstream "123.3.7m": bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
set: []string{"123.3.7m"},
}, {
name: "invalid",
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
set: []string{"[/host.com]tls://dns.adguard.com"},
}, {
name: "invalid",
wantErr: `validating upstream "[host.ru]#": bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
set: []string{"[host.ru]#"},
}, {
name: "valid_default",
wantErr: ``,
set: []string{
"1.1.1.1",
"tls://1.1.1.1",
"https://dns.adguard.com/dns-query",
"sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
"udp://dns.google",
"udp://8.8.8.8",
"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
"[/пример.рф/]8.8.8.8",
},
}, {
name: "bad_domain",
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
set: []string{"[/!/]8.8.8.8"},
}}
for _, tc := range testCases {
@@ -410,8 +344,7 @@ func TestValidateUpstreams(t *testing.T) {
}
func TestValidateUpstreamsPrivate(t *testing.T) {
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
testCases := []struct {
name string
@@ -452,7 +385,7 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
set := []string{"192.168.0.1", tc.u}
t.Run(tc.name, func(t *testing.T) {
err = ValidateUpstreamsPrivate(set, snd)
err := ValidateUpstreamsPrivate(set, ss)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}

View File

@@ -83,7 +83,7 @@ func TestRecursionDetector_Suspect(t *testing.T) {
testCases := []struct {
name string
msg dns.Msg
want bool
want int
}{{
name: "simple",
msg: dns.Msg{
@@ -95,24 +95,18 @@ func TestRecursionDetector_Suspect(t *testing.T) {
Qtype: dns.TypeA,
}},
},
want: true,
want: 1,
}, {
name: "unencumbered",
msg: dns.Msg{},
want: false,
want: 0,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(rd.clear)
rd.add(tc.msg)
if tc.want {
assert.Equal(t, 1, rd.recentRequests.Stats().Count)
} else {
assert.Zero(t, rd.recentRequests.Stats().Count)
}
assert.Equal(t, tc.want, rd.recentRequests.Stats().Count)
})
}
}

View File

@@ -64,9 +64,9 @@ func (s *Server) logQuery(
Answer: pctx.Res,
OrigAnswer: dctx.origResp,
Result: dctx.result,
Elapsed: elapsed,
ClientID: dctx.clientID,
ClientIP: ip,
Elapsed: elapsed,
AuthenticatedData: dctx.responseAD,
}

View File

@@ -34,7 +34,7 @@ func (l *testQueryLog) Add(p *querylog.AddParams) {
type testStats struct {
// Stats is embedded here simply to make testStats a stats.Stats without
// actually implementing all methods.
stats.Stats
stats.Interface
lastEntry stats.Entry
}

View File

@@ -32,12 +32,16 @@ func (s *Server) genAnswerHTTPS(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTT
// github.com/miekg/dns module.
var strToSVCBKey = map[string]dns.SVCBKey{
"alpn": dns.SVCB_ALPN,
"echconfig": dns.SVCB_ECHCONFIG,
"ech": dns.SVCB_ECHCONFIG,
"ipv4hint": dns.SVCB_IPV4HINT,
"ipv6hint": dns.SVCB_IPV6HINT,
"mandatory": dns.SVCB_MANDATORY,
"no-default-alpn": dns.SVCB_NO_DEFAULT_ALPN,
"port": dns.SVCB_PORT,
// TODO(a.garipov): This is the previous name for the parameter that has
// since been changed. Remove this in v0.109.0.
"echconfig": dns.SVCB_ECHCONFIG,
}
// svcbKeyHandler is a handler for one SVCB parameter key.
@@ -51,10 +55,10 @@ var svcbKeyHandlers = map[string]svcbKeyHandler{
}
},
"echconfig": func(valStr string) (val dns.SVCBKeyValue) {
"ech": func(valStr string) (val dns.SVCBKeyValue) {
ech, err := base64.StdEncoding.DecodeString(valStr)
if err != nil {
log.Debug("can't parse svcb/https echconfig: %s; ignoring", err)
log.Debug("can't parse svcb/https ech: %s; ignoring", err)
return nil
}
@@ -119,6 +123,32 @@ var svcbKeyHandlers = map[string]svcbKeyHandler{
Port: uint16(port64),
}
},
// TODO(a.garipov): This is the previous name for the parameter that has
// since been changed. Remove this in v0.109.0.
"echconfig": func(valStr string) (val dns.SVCBKeyValue) {
log.Info(
`warning: svcb/https record parameter name "echconfig" is deprecated; ` +
`use "ech" instead`,
)
ech, err := base64.StdEncoding.DecodeString(valStr)
if err != nil {
log.Debug("can't parse svcb/https ech: %s; ignoring", err)
return nil
}
return &dns.SVCBECHConfig{
ECH: ech,
}
},
"dohpath": func(valStr string) (val dns.SVCBKeyValue) {
return &dns.SVCBDoHPath{
Template: valStr,
}
},
}
// genAnswerSVCB returns a properly initialized SVCB resource record.

View File

@@ -87,14 +87,18 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
svcb: dnssvcb("alpn", "h3"),
want: wantsvcb(&dns.SVCBAlpn{Alpn: []string{"h3"}}),
name: "alpn",
}, {
svcb: dnssvcb("ech", "AAAA"),
want: wantsvcb(&dns.SVCBECHConfig{ECH: []byte{0, 0, 0}}),
name: "ech",
}, {
svcb: dnssvcb("echconfig", "AAAA"),
want: wantsvcb(&dns.SVCBECHConfig{ECH: []byte{0, 0, 0}}),
name: "echconfig",
name: "ech_deprecated",
}, {
svcb: dnssvcb("echconfig", "%BAD%"),
want: wantsvcb(nil),
name: "echconfig_invalid",
name: "ech_invalid",
}, {
svcb: dnssvcb("ipv4hint", "127.0.0.1"),
want: wantsvcb(&dns.SVCBIPv4Hint{Hint: []net.IP{ip4}}),
@@ -123,6 +127,10 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
svcb: dnssvcb("no-default-alpn", ""),
want: wantsvcb(&dns.SVCBNoDefaultAlpn{}),
name: "no_default_alpn",
}, {
svcb: dnssvcb("dohpath", "/dns-query"),
want: wantsvcb(&dns.SVCBDoHPath{Template: "/dns-query"}),
name: "dohpath",
}, {
svcb: dnssvcb("port", "8080"),
want: wantsvcb(&dns.SVCBPort{Port: 8080}),

View File

@@ -7,87 +7,181 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"golang.org/x/exp/slices"
)
var serviceRules map[string][]*rules.NetworkRule // service name -> filtering rules
// svc represents a single blocked service.
type svc struct {
name string
rules []string
}
// servicesData contains raw blocked service data.
//
// Keep in sync with:
// client/src/helpers/constants.js
// client/src/components/ui/Icons.js
var serviceRulesArray = []svc{
{"whatsapp", []string{"||whatsapp.net^", "||whatsapp.com^"}},
{"facebook", []string{
// - client/src/helpers/constants.js
// - client/src/components/ui/Icons.js
var servicesData = []svc{{
name: "whatsapp",
rules: []string{
"||wa.me^",
"||whatsapp.com^",
"||whatsapp.net^",
},
}, {
name: "facebook",
rules: []string{
"||facebook.com^",
"||facebook.net^",
"||fbcdn.net^",
"||accountkit.com^",
"||fb.me^",
"||fb.com^",
"||fb.gg^",
"||fbsbx.com^",
"||fbwat.ch^",
"||messenger.com^",
"||facebookcorewwwi.onion^",
"||fbcdn.com^",
"||fb.watch^",
}},
{"twitter", []string{"||twitter.com^", "||twttr.com^", "||t.co^", "||twimg.com^"}},
{"youtube", []string{
"||youtube.com^",
"||ytimg.com^",
"||youtu.be^",
},
}, {
name: "twitter",
rules: []string{
"||t.co^",
"||twimg.com^",
"||twitter.com^",
"||twttr.com^",
},
}, {
name: "youtube",
rules: []string{
"||googlevideo.com^",
"||youtubei.googleapis.com^",
"||youtube-nocookie.com^",
"||wide-youtube.l.google.com^",
"||youtu.be^",
"||youtube",
}},
{"twitch", []string{"||twitch.tv^", "||ttvnw.net^", "||jtvnw.net^", "||twitchcdn.net^"}},
{"netflix", []string{"||nflxext.com^", "||netflix.com^", "||nflximg.net^", "||nflxvideo.net^", "||nflxso.net^"}},
{"instagram", []string{"||instagram.com^", "||cdninstagram.com^"}},
{"snapchat", []string{
"||youtube-nocookie.com^",
"||youtube.com^",
"||youtubei.googleapis.com^",
"||youtubekids.com^",
"||ytimg.com^",
},
}, {
name: "twitch",
rules: []string{
"||jtvnw.net^",
"||ttvnw.net^",
"||twitch.tv^",
"||twitchcdn.net^",
},
}, {
name: "netflix",
rules: []string{
"||nflxext.com^",
"||netflix.com^",
"||nflximg.net^",
"||nflxvideo.net^",
"||nflxso.net^",
},
}, {
name: "instagram",
rules: []string{"||instagram.com^", "||cdninstagram.com^"},
}, {
name: "snapchat",
rules: []string{
"||snapchat.com^",
"||sc-cdn.net^",
"||snap-dev.net^",
"||snapkit.co",
"||snapads.com^",
"||impala-media-production.s3.amazonaws.com^",
}},
{"discord", []string{"||discord.gg^", "||discordapp.net^", "||discordapp.com^", "||discord.com^", "||discord.media^"}},
{"ok", []string{"||ok.ru^"}},
{"skype", []string{"||skype.com^", "||skypeassets.com^"}},
{"vk", []string{"||vk.com^", "||userapi.com^", "||vk-cdn.net^", "||vkuservideo.net^"}},
{"origin", []string{"||origin.com^", "||signin.ea.com^", "||accounts.ea.com^"}},
{"steam", []string{
},
}, {
name: "discord",
rules: []string{
"||discord.gg^",
"||discordapp.net^",
"||discordapp.com^",
"||discord.com^",
"||discord.gift",
"||discord.media^",
},
}, {
name: "ok",
rules: []string{"||ok.ru^"},
}, {
name: "skype",
rules: []string{
"||edge-skype-com.s-0001.s-msedge.net^",
"||skype-edf.akadns.net^",
"||skype.com^",
"||skypeassets.com^",
"||skypedata.akadns.net^",
},
}, {
name: "vk",
rules: []string{
"||userapi.com^",
"||vk-cdn.net^",
"||vk.com^",
"||vkuservideo.net^",
},
}, {
name: "origin",
rules: []string{
"||accounts.ea.com^",
"||origin.com^",
"||signin.ea.com^",
},
}, {
name: "steam",
rules: []string{
"||steam.com^",
"||steampowered.com^",
"||steamcommunity.com^",
"||steamstatic.com^",
"||steamstore-a.akamaihd.net^",
"||steamcdn-a.akamaihd.net^",
}},
{"epic_games", []string{"||epicgames.com^", "||easyanticheat.net^", "||easy.ac^", "||eac-cdn.com^"}},
{"reddit", []string{"||reddit.com^", "||redditstatic.com^", "||redditmedia.com^", "||redd.it^"}},
{"mail_ru", []string{"||mail.ru^"}},
{"cloudflare", []string{
"||cloudflare.com^",
"||cloudflare-dns.com^",
"||cloudflare.net^",
"||cloudflareinsights.com^",
"||cloudflarestream.com^",
"||cloudflareresolve.com^",
"||cloudflareclient.com^",
"||cloudflarebolt.com^",
"||cloudflarestatus.com^",
"||cloudflare.cn^",
"||one.one^",
"||warp.plus^",
},
}, {
name: "epic_games",
rules: []string{"||epicgames.com^", "||easyanticheat.net^", "||easy.ac^", "||eac-cdn.com^"},
}, {
name: "reddit",
rules: []string{"||reddit.com^", "||redditstatic.com^", "||redditmedia.com^", "||redd.it^"},
}, {
name: "mail_ru",
rules: []string{"||mail.ru^"},
}, {
name: "cloudflare",
rules: []string{
"||1.1.1.1^",
"||argotunnel.com^",
"||cloudflare-dns.com^",
"||cloudflare-ipfs.com^",
"||cloudflare-quic.com^",
"||cloudflare.cn^",
"||cloudflare.com^",
"||cloudflare.net^",
"||cloudflareapps.com^",
"||cloudflarebolt.com^",
"||cloudflareclient.com^",
"||cloudflareinsights.com^",
"||cloudflareresolve.com^",
"||cloudflarestatus.com^",
"||cloudflarestream.com^",
"||cloudflarewarp.com^",
"||dns4torpnlfs2ifuz2s2yf3fc7rdmsbhm6rw75euj35pac6ap25zgqad.onion^",
}},
{"amazon", []string{
"||one.one^",
"||pages.dev^",
"||trycloudflare.com^",
"||videodelivery.net^",
"||warp.plus^",
"||workers.dev^",
},
}, {
name: "amazon",
rules: []string{
"||amazon.com^",
"||media-amazon.com^",
"||primevideo.com^",
@@ -111,11 +205,14 @@ var serviceRulesArray = []svc{
"||amazon.com.br^",
"||amazon.co.jp^",
"||amazon.com.mx^",
"||amazon.com.tr^",
"||amazon.co.uk^",
"||createspace.com^",
"||aws",
}},
{"ebay", []string{
},
}, {
name: "ebay",
rules: []string{
"||ebay.com^",
"||ebayimg.com^",
"||ebaystatic.com^",
@@ -141,8 +238,10 @@ var serviceRulesArray = []svc{
"||ebay.com.my^",
"||ebay.com.sg^",
"||ebay.co.uk^",
}},
{"tiktok", []string{
},
}, {
name: "tiktok",
rules: []string{
"||tiktok.com^",
"||tiktokcdn.com^",
"||musical.ly^",
@@ -156,65 +255,95 @@ var serviceRulesArray = []svc{
"||toutiaocloud.net^",
"||bdurl.com^",
"||bytecdn.cn^",
"||bytedapm.com^",
"||byteimg.com^",
"||byteoversea.com^",
"||ixigua.com^",
"||muscdn.com^",
"||bytedance.map.fastly.net^",
"||douyin.com^",
"||tiktokv.com^",
}},
{"vimeo", []string{
"||toutiaovod.com^",
"||douyincdn.com^",
},
}, {
name: "vimeo",
rules: []string{
"*vod-adaptive.akamaized.net^",
"||vimeo.com^",
"||vimeocdn.com^",
"*vod-adaptive.akamaized.net^",
}},
{"pinterest", []string{
"||pinterest.*^",
},
}, {
name: "pinterest",
rules: []string{
"||pinimg.com^",
}},
{"imgur", []string{
"||imgur.com^",
}},
{"dailymotion", []string{
"||pinterest.*^",
},
}, {
name: "imgur",
rules: []string{"||imgur.com^"},
}, {
name: "dailymotion",
rules: []string{
"||dailymotion.com^",
"||dm-event.net^",
"||dmcdn.net^",
}},
{"qq", []string{
// block qq.com and subdomains excluding WeChat domains
"||qq.com^$denyallow=wx*.qq.com|weixin.qq.com",
},
}, {
name: "qq",
rules: []string{
// Block qq.com and subdomains excluding WeChat's domains.
"||qq.com^$denyallow=wx.qq.com|weixin.qq.com",
"||qqzaixian.com^",
}},
{"wechat", []string{
"||qq-video.cdn-go.cn^",
"||url.cn^",
},
}, {
name: "wechat",
rules: []string{
"||wechat.com^",
"||weixin.qq.com.cn^",
"||weixin.qq.com^",
"||weixinbridge.com^",
"||wx.qq.com^",
}},
{"viber", []string{
"||viber.com^",
}},
{"weibo", []string{
},
}, {
name: "viber",
rules: []string{"||viber.com^"},
}, {
name: "weibo",
rules: []string{
"||weibo.cn^",
"||weibo.com^",
}},
{"9gag", []string{
"||weibocdn.com^",
},
}, {
name: "9gag",
rules: []string{
"||9cache.com^",
"||9gag.com^",
}},
{"telegram", []string{
},
}, {
name: "telegram",
rules: []string{
"||t.me^",
"||telegram.me^",
"||telegram.org^",
}},
{"disneyplus", []string{
},
}, {
name: "disneyplus",
rules: []string{
"||disney-plus.net^",
"||disneyplus.com^",
"||disney.playback.edge.bamgrid.com^",
"||media.dssott.com^",
}},
{"hulu", []string{
"||hulu.com^",
}},
{"spotify", []string{
},
}, {
name: "hulu",
rules: []string{"||hulu.com^"},
}, {
name: "spotify",
rules: []string{
"/_spotify-connect._tcp.local/",
"||spotify.com^",
"||scdn.co^",
@@ -226,29 +355,59 @@ var serviceRulesArray = []svc{
"||audio4-ak-spotify-com.akamaized.net^",
"||heads-ak-spotify-com.akamaized.net^",
"||heads4-ak-spotify-com.akamaized.net^",
}},
{"tinder", []string{
},
}, {
name: "tinder",
rules: []string{
"||gotinder.com^",
"||tinder.com^",
"||tindersparks.com^",
}},
}
},
}, {
name: "bilibili",
rules: []string{
"||biliapi.net^",
"||bilibili.com^",
"||biligame.com^",
"||bilivideo.cn^",
"||bilivideo.com^",
"||dreamcast.hk^",
"||hdslb.com^",
},
}}
// convert array to map
// serviceRules maps a service ID to its filtering rules.
var serviceRules map[string][]*rules.NetworkRule
// serviceIDs contains service IDs sorted alphabetically.
var serviceIDs []string
// initBlockedServices initializes package-level blocked service data.
func initBlockedServices() {
serviceRules = make(map[string][]*rules.NetworkRule)
for _, s := range serviceRulesArray {
netRules := []*rules.NetworkRule{}
l := len(servicesData)
serviceIDs = make([]string, l)
serviceRules = make(map[string][]*rules.NetworkRule, l)
for i, s := range servicesData {
netRules := make([]*rules.NetworkRule, 0, len(s.rules))
for _, text := range s.rules {
rule, err := rules.NewNetworkRule(text, BlockedSvcsListID)
if err != nil {
log.Error("rules.NewNetworkRule: %s rule: %s", err, text)
log.Error("parsing blocked service %q rule %q: %s", s.name, text, err)
continue
}
netRules = append(netRules, rule)
}
serviceIDs[i] = s.name
serviceRules[s.name] = netRules
}
slices.Sort(serviceIDs)
log.Debug("filtering: initialized %d services", l)
}
// BlockedSvcKnown - return TRUE if a blocked service name is known
@@ -280,6 +439,16 @@ func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string, global
}
}
func (d *DNSFilter) handleBlockedServicesAvailableServices(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(serviceIDs)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding available services: %s", err)
return
}
}
func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
d.confLock.RLock()
list := d.Config.BlockedServices
@@ -288,7 +457,7 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(list)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding services: %s", err)
return
}
@@ -314,6 +483,7 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
// registerBlockedServicesHandlers - register HTTP handlers
func (d *DNSFilter) registerBlockedServicesHandlers() {
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
d.Config.HTTPRegister(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
}

View File

@@ -61,22 +61,22 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
testCasesA := []struct {
name string
want []interface{}
want []any
rcode int
dtyp uint16
}{{
name: "a-record",
rcode: dns.RcodeSuccess,
want: []interface{}{ipv4p1},
want: []any{ipv4p1},
dtyp: dns.TypeA,
}, {
name: "aaaa-record",
want: []interface{}{ipv6p1},
want: []any{ipv6p1},
rcode: dns.RcodeSuccess,
dtyp: dns.TypeAAAA,
}, {
name: "txt-record",
want: []interface{}{"hello-world"},
want: []any{"hello-world"},
rcode: dns.RcodeSuccess,
dtyp: dns.TypeTXT,
}, {
@@ -86,22 +86,22 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
dtyp: 0,
}, {
name: "a-records",
want: []interface{}{ipv4p1, ipv4p2},
want: []any{ipv4p1, ipv4p2},
rcode: dns.RcodeSuccess,
dtyp: dns.TypeA,
}, {
name: "aaaa-records",
want: []interface{}{ipv6p1, ipv6p2},
want: []any{ipv6p1, ipv6p2},
rcode: dns.RcodeSuccess,
dtyp: dns.TypeAAAA,
}, {
name: "disable-one",
want: []interface{}{ipv4p2},
want: []any{ipv4p2},
rcode: dns.RcodeSuccess,
dtyp: dns.TypeA,
}, {
name: "disable-cname",
want: []interface{}{ipv4p1},
want: []any{ipv4p1},
rcode: dns.RcodeSuccess,
dtyp: dns.TypeA,
}}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"io/fs"
"net"
"net/http"
"os"
"runtime"
"runtime/debug"
@@ -14,6 +13,7 @@ import (
"sync"
"sync/atomic"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
@@ -94,7 +94,7 @@ type Config struct {
ConfigModified func() `yaml:"-"`
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
// CustomResolver is the resolver used by DNSFilter.
CustomResolver Resolver `yaml:"-"`
@@ -296,9 +296,11 @@ func cloneRewrites(entries []*LegacyRewrite) (clone []*LegacyRewrite) {
return clone
}
// SetFilters - set new filters (synchronously or asynchronously)
// When filters are set asynchronously, the old filters continue working until the new filters are ready.
// In this case the caller must ensure that the old filter files are intact.
// SetFilters sets new filters, synchronously or asynchronously. When filters
// are set asynchronously, the old filters continue working until the new
// filters are ready.
//
// In this case the caller must ensure that the old filter files are intact.
func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool) error {
if async {
params := filtersInitializerParams{
@@ -471,7 +473,7 @@ func (d *DNSFilter) matchSysHosts(
return res, nil
}
dnsres, _ := d.EtcHosts.MatchRequest(urlfilter.DNSRequest{
dnsres, _ := d.EtcHosts.MatchRequest(&urlfilter.DNSRequest{
Hostname: host,
SortedClientTags: setts.ClientTags,
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
@@ -802,7 +804,7 @@ func (d *DNSFilter) matchHost(
return Result{}, nil
}
ureq := urlfilter.DNSRequest{
ureq := &urlfilter.DNSRequest{
Hostname: host,
SortedClientTags: setts.ClientTags,
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.

View File

@@ -21,6 +21,11 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
const (
sbBlocked = "wmconvirus.narod.ru"
pcBlocked = "pornhub.com"
)
var setts = Settings{
ProtectionEnabled: true,
}
@@ -173,43 +178,37 @@ func TestSafeBrowsing(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})
d.checkMatch(t, matching)
require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
d.checkMatch(t, sbBlocked)
d.checkMatch(t, "test."+matching)
require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
d.checkMatch(t, "test."+sbBlocked)
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "pornhub.com")
d.checkMatchEmpty(t, pcBlocked)
// Cached result.
d.safeBrowsingServer = "127.0.0.1"
d.checkMatch(t, matching)
d.checkMatchEmpty(t, "pornhub.com")
d.checkMatch(t, sbBlocked)
d.checkMatchEmpty(t, pcBlocked)
d.safeBrowsingServer = defaultSafebrowsingServer
}
func TestParallelSB(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
t.Run("group", func(t *testing.T) {
for i := 0; i < 100; i++ {
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
t.Parallel()
d.checkMatch(t, matching)
d.checkMatch(t, "test."+matching)
d.checkMatch(t, sbBlocked)
d.checkMatch(t, "test."+sbBlocked)
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "pornhub.com")
d.checkMatchEmpty(t, pcBlocked)
})
}
})
@@ -382,23 +381,19 @@ func TestParentalControl(t *testing.T) {
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "pornhub.com"
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})
d.checkMatch(t, matching)
require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
d.checkMatch(t, pcBlocked)
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
d.checkMatch(t, "www."+matching)
d.checkMatch(t, "www."+pcBlocked)
d.checkMatchEmpty(t, "www.yandex.ru")
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "api.jquery.com")
// Test cached result.
d.parentalServer = "127.0.0.1"
d.checkMatch(t, matching)
d.checkMatch(t, pcBlocked)
d.checkMatchEmpty(t, "yandex.ru")
}
@@ -445,7 +440,7 @@ func TestMatching(t *testing.T) {
}, {
name: "sanity",
rules: "||doubleclick.net^",
host: "wmconvirus.narod.ru",
host: sbBlocked,
wantIsFiltered: false,
wantReason: NotFilteredNotFound,
wantDNSType: dns.TypeA,
@@ -765,14 +760,9 @@ func TestClientSettings(t *testing.T) {
}},
)
t.Cleanup(d.Close)
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
Hostname: "pornhub.com",
Block: true,
})
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: "wmconvirus.narod.ru",
Block: true,
})
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
type testCase struct {
name string
@@ -787,12 +777,12 @@ func TestClientSettings(t *testing.T) {
wantReason: FilteredBlockList,
}, {
name: "parental",
host: "pornhub.com",
host: pcBlocked,
before: true,
wantReason: FilteredParental,
}, {
name: "safebrowsing",
host: "wmconvirus.narod.ru",
host: sbBlocked,
before: false,
wantReason: FilteredSafeBrowsing,
}, {
@@ -836,33 +826,29 @@ func TestClientSettings(t *testing.T) {
func BenchmarkSafeBrowsing(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: blocked,
Block: true,
})
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
for n := 0; n < b.N; n++ {
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
require.NoError(b, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
}
}
func BenchmarkSafeBrowsingParallel(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: blocked,
Block: true,
})
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
require.NoError(b, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
}
})
}

View File

@@ -130,10 +130,9 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
//
// The sorting priority:
//
// A and AAAA > CNAME
// wildcard > exact
// lower level wildcard > higher level wildcard
//
// 1. A and AAAA > CNAME
// 2. wildcard > exact
// 3. lower level wildcard > higher level wildcard
type rewritesSorted []*LegacyRewrite
// Len implements the sort.Interface interface for legacyRewritesSorted.

View File

@@ -24,10 +24,11 @@ import (
// Safe browsing and parental control methods.
// TODO(a.garipov): Make configurable.
const (
dnsTimeout = 3 * time.Second
defaultSafebrowsingServer = `https://dns-family.adguard.com/dns-query`
defaultParentalServer = `https://dns-family.adguard.com/dns-query`
defaultSafebrowsingServer = `https://family.adguard-dns.com/dns-query`
defaultParentalServer = `https://family.adguard-dns.com/dns-query`
sbTXTSuffix = `sb.dns.adguard.com.`
pcTXTSuffix = `pc.dns.adguard.com.`
)
@@ -313,7 +314,7 @@ func (d *DNSFilter) checkSafeBrowsing(
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("SafeBrowsing lookup for %s", host)
defer timer.LogElapsed("safebrowsing lookup for %q", host)
}
sctx := &sbCtx{
@@ -347,7 +348,7 @@ func (d *DNSFilter) checkParental(
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("Parental lookup for %s", host)
defer timer.LogElapsed("parental lookup for %q", host)
}
sctx := &sbCtx{

View File

@@ -74,21 +74,20 @@ func TestSafeBrowsingCache(t *testing.T) {
c.hashToHost[hash] = "sub.host.com"
assert.Equal(t, -1, c.getCached())
// match "sub.host.com" from cache,
// but another hash for "nonexisting.com" is not in cache
// which means that we must get data from server for it
// Match "sub.host.com" from cache. Another hash for "host.example" is not
// in the cache, so get data for it from the server.
c.hashToHost = make(map[[32]byte]string)
hash = sha256.Sum256([]byte("sub.host.com"))
c.hashToHost[hash] = "sub.host.com"
hash = sha256.Sum256([]byte("nonexisting.com"))
c.hashToHost[hash] = "nonexisting.com"
hash = sha256.Sum256([]byte("host.example"))
c.hashToHost[hash] = "host.example"
assert.Empty(t, c.getCached())
hash = sha256.Sum256([]byte("sub.host.com"))
_, ok := c.hashToHost[hash]
assert.False(t, ok)
hash = sha256.Sum256([]byte("nonexisting.com"))
hash = sha256.Sum256([]byte("host.example"))
_, ok = c.hashToHost[hash]
assert.True(t, ok)
@@ -111,8 +110,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
ups := &aghtest.TestErrUpstream{}
ups := aghtest.NewErrorUpstream()
d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)
@@ -170,10 +168,16 @@ func TestSBPC(t *testing.T) {
for _, tc := range testCases {
// Prepare the upstream.
ups := &aghtest.TestBlockUpstream{
Hostname: hostname,
Block: tc.block,
ups := aghtest.NewBlockUpstream(hostname, tc.block)
var numReq int
onExchange := ups.OnExchange
ups.OnExchange = func(req *dns.Msg) (resp *dns.Msg, err error) {
numReq++
return onExchange(req)
}
d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)
@@ -196,7 +200,7 @@ func TestSBPC(t *testing.T) {
assert.Equal(t, hits, tc.testCache.Stats().Hit)
// There was one request to an upstream.
assert.Equal(t, 1, ups.RequestsCount())
assert.Equal(t, 1, numReq)
// Now make the same request to check the cache was used.
res, err = tc.testFunc(hostname, dns.TypeA, setts)
@@ -214,7 +218,7 @@ func TestSBPC(t *testing.T) {
assert.Equal(t, hits+1, tc.testCache.Stats().Hit)
// Check that there were no additional requests.
assert.Equal(t, 1, ups.RequestsCount())
assert.Equal(t, 1, numReq)
})
purgeCaches(d)

View File

@@ -2,9 +2,11 @@ package home
import (
"bytes"
"encoding"
"fmt"
"net"
"sort"
"strings"
"sync"
"time"
@@ -53,12 +55,49 @@ type clientSource uint
// Client sources. The order determines the priority.
const (
ClientSourceWHOIS clientSource = iota
ClientSourceRDNS
ClientSourceARP
ClientSourceRDNS
ClientSourceDHCP
ClientSourceHostsFile
)
var _ fmt.Stringer = clientSource(0)
// String returns a human-readable name of cs.
func (cs clientSource) String() (s string) {
switch cs {
case ClientSourceWHOIS:
return "WHOIS"
case ClientSourceARP:
return "ARP"
case ClientSourceRDNS:
return "rDNS"
case ClientSourceDHCP:
return "DHCP"
case ClientSourceHostsFile:
return "etc/hosts"
default:
return ""
}
}
var _ encoding.TextMarshaler = clientSource(0)
// MarshalText implements encoding.TextMarshaler for the clientSource.
func (cs clientSource) MarshalText() (text []byte, err error) {
return []byte(cs.String()), nil
}
// clientSourceConf is used to configure where the runtime clients will be
// obtained from.
type clientSourcesConf struct {
WHOIS bool `yaml:"whois"`
ARP bool `yaml:"arp"`
RDNS bool `yaml:"rdns"`
DHCP bool `yaml:"dhcp"`
HostsFile bool `yaml:"hosts"`
}
// RuntimeClient information
type RuntimeClient struct {
WHOISInfo *RuntimeClientWHOISInfo
@@ -134,14 +173,14 @@ func (clients *clientsContainer) Init(
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
}
go clients.handleHostsUpdates()
if clients.etcHosts != nil {
go clients.handleHostsUpdates()
}
}
func (clients *clientsContainer) handleHostsUpdates() {
if clients.etcHosts != nil {
for upd := range clients.etcHosts.Upd() {
clients.addFromHostsFile(upd)
}
for upd := range clients.etcHosts.Upd() {
clients.addFromHostsFile(upd)
}
}
@@ -158,7 +197,9 @@ func (clients *clientsContainer) Start() {
// Reload reloads runtime clients.
func (clients *clientsContainer) Reload() {
clients.addFromSystemARP()
if clients.arpdb != nil {
clients.addFromSystemARP()
}
}
type clientObject struct {
@@ -257,6 +298,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
}
func (clients *clientsContainer) periodicUpdate() {
defer log.OnPanic("clients container")
for {
clients.Reload()
time.Sleep(clientsUpdatePeriod)
@@ -382,6 +425,7 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
c.Tags = stringutil.CloneSlice(c.Tags)
c.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
c.Upstreams = stringutil.CloneSlice(c.Upstreams)
return c, true
}
@@ -478,7 +522,7 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
// findRuntimeClientLocked finds a runtime client by their IP address. For
// internal use only.
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
var v interface{}
var v any
v, ok = clients.ipToRC.Get(ip)
if !ok {
return nil, false
@@ -532,7 +576,7 @@ func (clients *clientsContainer) check(c *Client) (err error) {
} else if mac, err = net.ParseMAC(id); err == nil {
c.IDs[i] = mac.String()
} else if err = dnsforward.ValidateClientID(id); err == nil {
c.IDs[i] = id
c.IDs[i] = strings.ToLower(id)
} else {
return fmt.Errorf("invalid clientid at index %d: %q", i, id)
}
@@ -723,20 +767,18 @@ func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSourc
clients.lock.Lock()
defer clients.lock.Unlock()
ok = clients.addHostLocked(ip, host, src)
return ok, nil
return clients.addHostLocked(ip, host, src), nil
}
// addHostLocked adds a new IP-hostname pairing. For internal use only.
func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) {
var rc *RuntimeClient
rc, ok = clients.findRuntimeClientLocked(ip)
rc, ok := clients.findRuntimeClientLocked(ip)
if ok {
if rc.Source > src {
return false
}
rc.Host = host
rc.Source = src
} else {
rc = &RuntimeClient{
@@ -756,7 +798,7 @@ func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clien
// rmHostsBySrc removes all entries that match the specified source.
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
n := 0
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
clients.ipToRC.Range(func(ip net.IP, v any) (cont bool) {
rc, ok := v.(*RuntimeClient)
if !ok {
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
@@ -784,26 +826,21 @@ func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) {
clients.rmHostsBySrc(ClientSourceHostsFile)
n := 0
hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
hosts, ok := v.(*stringutil.Set)
hosts.Range(func(ip net.IP, v any) (cont bool) {
rec, ok := v.(*aghnet.HostsRecord)
if !ok {
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
return true
}
hosts.Range(func(name string) (cont bool) {
if clients.addHostLocked(ip, name, ClientSourceHostsFile) {
n++
}
return true
})
clients.addHostLocked(ip, rec.Canonical, ClientSourceHostsFile)
n++
return true
})
log.Debug("clients: added %d client aliases from system hosts-file", n)
log.Debug("clients: added %d client aliases from system hosts file", n)
}
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
@@ -831,7 +868,7 @@ func (clients *clientsContainer) addFromSystemARP() {
added := 0
for _, n := range ns {
if clients.addHostLocked(n.IP, "", ClientSourceARP) {
if clients.addHostLocked(n.IP, n.Name, ClientSourceARP) {
added++
}
}
@@ -842,7 +879,7 @@ func (clients *clientsContainer) addFromSystemARP() {
// updateFromDHCP adds the clients that have a non-empty hostname from the DHCP
// server.
func (clients *clientsContainer) updateFromDHCP(add bool) {
if clients.dhcpServer == nil {
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
return
}

View File

@@ -47,9 +47,9 @@ type clientJSON struct {
type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
Name string `json:"name"`
Source string `json:"source"`
IP net.IP `json:"ip"`
Name string `json:"name"`
Source clientSource `json:"source"`
IP net.IP `json:"ip"`
}
type clientListJSON struct {
@@ -70,7 +70,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
data.Clients = append(data.Clients, cj)
}
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
clients.ipToRC.Range(func(ip net.IP, v any) (cont bool) {
rc, ok := v.(*RuntimeClient)
if !ok {
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
@@ -81,20 +81,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
cj := runtimeClientJSON{
WHOISInfo: rc.WHOISInfo,
Name: rc.Host,
IP: ip,
}
cj.Source = "etc/hosts"
switch rc.Source {
case ClientSourceDHCP:
cj.Source = "DHCP"
case ClientSourceRDNS:
cj.Source = "rDNS"
case ClientSourceARP:
cj.Source = "ARP"
case ClientSourceWHOIS:
cj.Source = "WHOIS"
Name: rc.Host,
Source: rc.Source,
IP: ip,
}
data.RuntimeClients = append(data.RuntimeClients, cj)
@@ -107,13 +96,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data)
if e != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to encode to json: %v",
e,
)
aghhttp.Error(r, w, http.StatusInternalServerError, "failed to encode to json: %v", e)
return
}
@@ -279,9 +262,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
rc, ok := clients.FindRuntimeClient(ip)
if !ok {
// It is still possible that the IP used to be in the runtime
// clients list, but then the server was reloaded. So, check
// the DNS server's blocked IP list.
// It is still possible that the IP used to be in the runtime clients
// list, but then the server was reloaded. So, check the DNS server's
// blocked IP list.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)

View File

@@ -1,6 +1,7 @@
package home
import (
"bytes"
"fmt"
"net"
"os"
@@ -19,7 +20,7 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/maybe"
yaml "gopkg.in/yaml.v2"
yaml "gopkg.in/yaml.v3"
)
const (
@@ -27,15 +28,36 @@ const (
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
)
// logSettings
// logSettings are the logging settings part of the configuration file.
//
// TODO(a.garipov): Put them into a separate object.
type logSettings struct {
LogCompress bool `yaml:"log_compress"` // Compress determines if the rotated log files should be compressed using gzip (default: false)
LogLocalTime bool `yaml:"log_localtime"` // If the time used for formatting the timestamps in is the computer's local time (default: false [UTC])
LogMaxBackups int `yaml:"log_max_backups"` // Maximum number of old log files to retain (MaxAge may still cause them to get deleted)
LogMaxSize int `yaml:"log_max_size"` // Maximum size in megabytes of the log file before it gets rotated (default 100 MB)
LogMaxAge int `yaml:"log_max_age"` // MaxAge is the maximum number of days to retain old log files
LogFile string `yaml:"log_file"` // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
// File is the path to the log file. If empty, logs are written to stdout.
// If "syslog", logs are written to syslog.
File string `yaml:"log_file"`
// MaxBackups is the maximum number of old log files to retain.
//
// NOTE: MaxAge may still cause them to get deleted.
MaxBackups int `yaml:"log_max_backups"`
// MaxSize is the maximum size of the log file before it gets rotated, in
// megabytes. The default value is 100 MB.
MaxSize int `yaml:"log_max_size"`
// MaxAge is the maximum duration for retaining old log files, in days.
MaxAge int `yaml:"log_max_age"`
// Compress determines, if the rotated log files should be compressed using
// gzip.
Compress bool `yaml:"log_compress"`
// LocalTime determines, if the time used for formatting the timestamps in
// is the computer's local time.
LocalTime bool `yaml:"log_localtime"`
// Verbose determines, if verbose (aka debug) logging is enabled.
Verbose bool `yaml:"verbose"`
}
// osConfig contains OS-related configuration.
@@ -51,6 +73,13 @@ type osConfig struct {
RlimitNoFile uint64 `yaml:"rlimit_nofile"`
}
type clientsConfig struct {
// Sources defines the set of sources to fetch the runtime clients from.
Sources *clientSourcesConf `yaml:"runtime_sources"`
// Persistent are the configured clients.
Persistent []*clientObject `yaml:"persistent"`
}
// configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here
type configuration struct {
@@ -88,7 +117,7 @@ type configuration struct {
// Clients contains the YAML representations of the persistent clients.
// This field is only used for reading and writing persistent client data.
// Keep this field sorted to ensure consistent ordering.
Clients []*clientObject `yaml:"clients"`
Clients *clientsConfig `yaml:"clients"`
logSettings `yaml:",inline"`
@@ -123,8 +152,9 @@ type dnsConfig struct {
// UpstreamTimeout is the timeout for querying upstream servers.
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
// ResolveClients enables and disables resolving clients with RDNS.
ResolveClients bool `yaml:"resolve_clients"`
// PrivateNets is the set of IP networks for which the private reverse DNS
// resolver should be used.
PrivateNets []string `yaml:"private_networks"`
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
// locally-served networks should be resolved via private PTR resolvers.
@@ -179,6 +209,7 @@ var config = &configuration{
Ratelimit: 20,
RefuseAny: true,
AllServers: false,
HandleDDR: true,
FastestTimeout: timeutil.Duration{
Duration: fastip.DefaultPingWaitTimeout,
},
@@ -194,7 +225,6 @@ var config = &configuration{
FilteringEnabled: true, // whether or not use filter lists
FiltersUpdateIntervalHours: 24,
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
ResolveClients: true,
UsePrivateRDNS: true,
},
TLS: tlsConfigSettings{
@@ -205,12 +235,21 @@ var config = &configuration{
DHCP: &dhcpd.ServerConfig{
LocalDomainName: "lan",
},
Clients: &clientsConfig{
Sources: &clientSourcesConf{
WHOIS: true,
ARP: true,
RDNS: true,
DHCP: true,
HostsFile: true,
},
},
logSettings: logSettings{
LogCompress: false,
LogLocalTime: false,
LogMaxBackups: 0,
LogMaxSize: 100,
LogMaxAge: 3,
Compress: false,
LocalTime: false,
MaxBackups: 0,
MaxSize: 100,
MaxAge: 3,
},
OSConfig: &osConfig{},
SchemaVersion: currentSchemaVersion,
@@ -285,25 +324,28 @@ func parseConfig() (err error) {
return err
}
uc := aghalg.UniqChecker{}
addPorts(
uc,
config.BindPort,
config.BetaBindPort,
config.DNS.Port,
)
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(tcpPorts, tcpPort(config.BindPort), tcpPort(config.BetaBindPort))
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(config.DNS.Port))
if config.TLS.Enabled {
addPorts(
uc,
config.TLS.PortHTTPS,
config.TLS.PortDNSOverTLS,
config.TLS.PortDNSOverQUIC,
config.TLS.PortDNSCrypt,
tcpPorts,
tcpPort(config.TLS.PortHTTPS),
tcpPort(config.TLS.PortDNSOverTLS),
tcpPort(config.TLS.PortDNSCrypt),
)
// TODO(e.burkov): Consider adding a udpPort with the same value when
// we add support for HTTP/3 for web admin interface.
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
}
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
return fmt.Errorf("validating ports: %w", err)
if err = tcpPorts.Validate(); err != nil {
return fmt.Errorf("validating tcp ports: %w", err)
} else if err = udpPorts.Validate(); err != nil {
return fmt.Errorf("validating udp ports: %w", err)
}
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
@@ -317,8 +359,14 @@ func parseConfig() (err error) {
return nil
}
// addPorts is a helper for ports validation. It skips zero ports.
func addPorts(uc aghalg.UniqChecker, ports ...int) {
// udpPort is the port number for UDP protocol.
type udpPort int
// tcpPort is the port number for TCP protocol.
type tcpPort int
// addPorts is a helper for ports validation that skips zero ports.
func addPorts[T tcpPort | udpPort](uc aghalg.UniqChecker[T], ports ...T) {
for _, p := range ports {
if p != 0 {
uc.Add(p)
@@ -340,13 +388,14 @@ func readConfigFile() (fileData []byte, err error) {
}
// Saves configuration to the YAML file and also saves the user filter contents to a file
func (c *configuration) write() error {
func (c *configuration) write() (err error) {
c.Lock()
defer c.Unlock()
if Context.auth != nil {
config.Users = Context.auth.GetUsers()
}
if Context.tls != nil {
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
@@ -380,9 +429,7 @@ func (c *configuration) write() error {
s.WriteDiskConfig(&c)
dns := &config.DNS
dns.FilteringConfig = c
dns.LocalPTRResolvers,
dns.ResolveClients,
dns.UsePrivateRDNS = s.RDNSSettings()
dns.LocalPTRResolvers, config.Clients.Sources.RDNS, dns.UsePrivateRDNS = s.RDNSSettings()
}
if Context.dhcpServer != nil {
@@ -391,22 +438,23 @@ func (c *configuration) write() error {
config.DHCP = c
}
config.Clients = Context.clients.forConfig()
config.Clients.Persistent = Context.clients.forConfig()
configFile := config.getConfigFilename()
log.Debug("Writing YAML file: %s", configFile)
yamlText, err := yaml.Marshal(&config)
if err != nil {
log.Error("Couldn't generate YAML file: %s", err)
log.Debug("writing config file %q", configFile)
return err
buf := &bytes.Buffer{}
enc := yaml.NewEncoder(buf)
enc.SetIndent(2)
err = enc.Encode(config)
if err != nil {
return fmt.Errorf("generating config file: %w", err)
}
err = maybe.WriteFile(configFile, yamlText, 0o644)
err = maybe.WriteFile(configFile, buf.Bytes(), 0o644)
if err != nil {
log.Error("Couldn't save YAML config: %s", err)
return err
return fmt.Errorf("writing config file: %w", err)
}
return nil

View File

@@ -189,7 +189,7 @@ func registerControlHandlers() {
RegisterAuthHandlers()
}
func httpRegister(method, url string, handler func(http.ResponseWriter, *http.Request)) {
func httpRegister(method, url string, handler http.HandlerFunc) {
if method == "" {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
Context.mux.HandleFunc(url, postInstall(handler))

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
@@ -57,8 +58,8 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
err = validateFilterURL(fj.URL)
if err != nil {
msg := fmt.Sprintf("invalid url: %s", err)
http.Error(w, msg, http.StatusBadRequest)
err = fmt.Errorf("invalid url: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
@@ -178,16 +179,16 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
}
}
type filterURLJSON struct {
type filterURLReqData struct {
Name string `json:"name"`
URL string `json:"url"`
Enabled bool `json:"enabled"`
}
type filterURLReq struct {
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
Data filterURLJSON `json:"data"`
Data *filterURLReqData `json:"data"`
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
}
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
@@ -199,10 +200,17 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
return
}
if fj.Data == nil {
err = errors.Error("data cannot be null")
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
err = validateFilterURL(fj.Data.URL)
if err != nil {
msg := fmt.Sprintf("invalid url: %s", err)
http.Error(w, msg, http.StatusBadRequest)
err = fmt.Errorf("invalid url: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
@@ -223,11 +231,8 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
}
onConfigModified()
restart := false
if (status & statusEnabledChanged) != 0 {
// we must add or remove filter rules
restart = true
}
restart := (status & statusEnabledChanged) != 0
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
// download new filter and apply its rules
flags := filterRefreshBlocklists
@@ -242,6 +247,7 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
restart = true
}
}
if restart {
enableFilters(true)
}
@@ -311,20 +317,20 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
}
type filterJSON struct {
ID int64 `json:"id"`
Enabled bool `json:"enabled"`
URL string `json:"url"`
Name string `json:"name"`
LastUpdated string `json:"last_updated,omitempty"`
ID int64 `json:"id"`
RulesCount uint32 `json:"rules_count"`
LastUpdated string `json:"last_updated"`
Enabled bool `json:"enabled"`
}
type filteringConfig struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"` // in hours
Filters []filterJSON `json:"filters"`
WhitelistFilters []filterJSON `json:"whitelist_filters"`
UserRules []string `json:"user_rules"`
Interval uint32 `json:"interval"` // in hours
Enabled bool `json:"enabled"`
}
func filterToJSON(f filter) filterJSON {
@@ -402,16 +408,12 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request
}
type checkHostRespRule struct {
FilterListID int64 `json:"filter_list_id"`
Text string `json:"text"`
FilterListID int64 `json:"filter_list_id"`
}
type checkHostResp struct {
Reason string `json:"reason"`
// FilterID is the ID of the rule's filter list.
//
// Deprecated: Use Rules[*].FilterListID.
FilterID int64 `json:"filter_id"`
// Rule is the text of the matched rule.
//
@@ -426,6 +428,11 @@ type checkHostResp struct {
// for Rewrite:
CanonName string `json:"cname"` // CNAME value
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
// FilterID is the ID of the rule's filter list.
//
// Deprecated: Use Rules[*].FilterListID.
FilterID int64 `json:"filter_id"`
}
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {

View File

@@ -105,19 +105,22 @@ type checkConfResp struct {
// validateWeb returns error is the web part if the initial configuration can't
// be set.
func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
port := req.Web.Port
addPorts(uc, config.BetaBindPort, port)
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
// Avoid duplicating the error into the status of DNS.
uc[port] = 1
portInt := req.Web.Port
port := tcpPort(portInt)
addPorts(tcpPorts, tcpPort(config.BetaBindPort), port)
if err = tcpPorts.Validate(); err != nil {
// Reset the value for the port to 1 to make sure that validateDNS
// doesn't throw the same error, unless the same TCP port is set there
// as well.
tcpPorts[port] = 1
return err
}
switch port {
switch portInt {
case 0, config.BindPort:
return nil
default:
@@ -125,21 +128,18 @@ func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
// unbound after install.
}
return aghnet.CheckPort("tcp", req.Web.IP, port)
return aghnet.CheckPort("tcp", req.Web.IP, portInt)
}
// validateDNS returns error if the DNS part of the initial configuration can't
// be set. canAutofix is true if the port can be unbound by AdGuard Home
// automatically.
func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, err error) {
func (req *checkConfReq) validateDNS(
tcpPorts aghalg.UniqChecker[tcpPort],
) (canAutofix bool, err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
port := req.DNS.Port
addPorts(uc, port)
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
return false, err
}
switch port {
case 0:
return false, nil
@@ -148,6 +148,11 @@ func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, er
// by AdGuard Home for web interface.
default:
// Check TCP as well.
addPorts(tcpPorts, tcpPort(port))
if err = tcpPorts.Validate(); err != nil {
return false, err
}
err = aghnet.CheckPort("tcp", req.DNS.IP, port)
if err != nil {
return false, err
@@ -185,13 +190,12 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
}
resp := &checkConfResp{}
uc := aghalg.UniqChecker{}
if err = req.validateWeb(uc); err != nil {
tcpPorts := aghalg.UniqChecker[tcpPort]{}
if err = req.validateWeb(tcpPorts); err != nil {
resp.Web.Status = err.Error()
}
if resp.DNS.CanAutofix, err = req.validateDNS(uc); err != nil {
if resp.DNS.CanAutofix, err = req.validateDNS(tcpPorts); err != nil {
resp.DNS.Status = err.Error()
} else if !req.DNS.IP.IsUnspecified() {
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
@@ -212,7 +216,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
func handleStaticIP(ip net.IP, set bool) staticIPJSON {
resp := staticIPJSON{}
interfaceName := aghnet.GetInterfaceByIP(ip)
interfaceName := aghnet.InterfaceByIP(ip)
resp.Static = "no"
if len(interfaceName) == 0 {

View File

@@ -3,14 +3,15 @@ package home
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
@@ -27,12 +28,16 @@ type temporaryError interface {
// Get the latest available version from the Internet
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
resp := &versionResponse{}
if Context.disableUpdate {
// w.Header().Set("Content-Type", "application/json")
resp.Disabled = true
_ = json.NewEncoder(w).Encode(resp)
// TODO(e.burkov): Add error handling and deal with headers.
err := json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
}
return
}
@@ -44,30 +49,48 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
if r.ContentLength != 0 {
err = json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "JSON parse: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "parsing request: %s", err)
return
}
}
err = requestVersionInfo(resp, req.Recheck)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.Error(r, w, http.StatusBadGateway, "%s", err)
return
}
err = resp.setAllowedToAutoUpdate()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
}
}
// requestVersionInfo sets the VersionInfo field of resp if it can reach the
// update server.
func requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
for i := 0; i != 3; i++ {
func() {
Context.controlLock.Lock()
defer Context.controlLock.Unlock()
resp.VersionInfo, err = Context.updater.VersionInfo(req.Recheck)
}()
resp.VersionInfo, err = Context.updater.VersionInfo(recheck)
if err != nil {
var terr temporaryError
if errors.As(err, &terr) && terr.Temporary() {
// Temporary network error. This case may happen while
// we're restarting our DNS server. Log and sleep for
// some time.
// Temporary network error. This case may happen while we're
// restarting our DNS server. Log and sleep for some time.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/934.
d := time.Duration(i) * time.Second
log.Info("temp net error: %q; sleeping for %s and retrying", err, d)
log.Info("update: temp net error: %q; sleeping for %s and retrying", err, d)
time.Sleep(d)
continue
@@ -76,29 +99,14 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
break
}
if err != nil {
vcu := Context.updater.VersionCheckURL()
// TODO(a.garipov): Figure out the purpose of %T verb.
aghhttp.Error(
r,
w,
http.StatusBadGateway,
"Couldn't get version check json from %s: %T %s\n",
vcu,
err,
err,
)
return
return fmt.Errorf("getting version info from %s: %s", vcu, err)
}
resp.confirmAutoUpdate()
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
return nil
}
// handleUpdate performs an update to the latest available version procedure.
@@ -109,7 +117,18 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
return
}
err := Context.updater.Update()
// Retain the current absolute path of the executable, since the updater is
// likely to change the position current one to the backup directory.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/4735.
execPath, err := os.Executable()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "getting path: %s", err)
return
}
err = Context.updater.Update()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
@@ -121,85 +140,88 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
f.Flush()
}
// The background context is used because the underlying functions wrap
// it with timeout and shut down the server, which handles current
// request. It also should be done in a separate goroutine due to the
// same reason.
go func() {
finishUpdate(context.Background())
}()
// The background context is used because the underlying functions wrap it
// with timeout and shut down the server, which handles current request. It
// also should be done in a separate goroutine for the same reason.
go finishUpdate(context.Background(), execPath)
}
// versionResponse is the response for /control/version.json endpoint.
type versionResponse struct {
Disabled bool `json:"disabled"`
updater.VersionInfo
Disabled bool `json:"disabled"`
}
// confirmAutoUpdate checks the real possibility of auto update.
func (vr *versionResponse) confirmAutoUpdate() {
if vr.CanAutoUpdate != nil && *vr.CanAutoUpdate {
canUpdate := true
var tlsConf *tlsConfigSettings
if runtime.GOOS != "windows" {
tlsConf = &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
}
if tlsConf != nil &&
((tlsConf.Enabled && (tlsConf.PortHTTPS < 1024 ||
tlsConf.PortDNSOverTLS < 1024 ||
tlsConf.PortDNSOverQUIC < 1024)) ||
config.BindPort < 1024 ||
config.DNS.Port < 1024) {
canUpdate, _ = aghnet.CanBindPrivilegedPorts()
}
vr.CanAutoUpdate = &canUpdate
// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
// allowed to perform an automatic update by the OS.
func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
if vr.CanAutoUpdate != aghalg.NBTrue {
return nil
}
tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
canUpdate := true
if tlsConfUsesPrivilegedPorts(tlsConf) || config.BindPort < 1024 || config.DNS.Port < 1024 {
canUpdate, err = aghnet.CanBindPrivilegedPorts()
if err != nil {
return fmt.Errorf("checking ability to bind privileged ports: %w", err)
}
}
vr.CanAutoUpdate = aghalg.BoolToNullBool(canUpdate)
return nil
}
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
// indicates that privileged ports are used.
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
}
// finishUpdate completes an update procedure.
func finishUpdate(ctx context.Context) {
log.Info("Stopping all tasks")
func finishUpdate(ctx context.Context, execPath string) {
var err error
log.Info("stopping all tasks")
cleanup(ctx)
cleanupAlways()
exeName := "AdGuardHome"
if runtime.GOOS == "windows" {
exeName = "AdGuardHome.exe"
}
curBinName := filepath.Join(Context.workDir, exeName)
if runtime.GOOS == "windows" {
if Context.runningAsService {
// Note:
// we can't restart the service via "kardianos/service" package - it kills the process first
// we can't start a new instance - Windows doesn't allow it
// NOTE: We can't restart the service via "kardianos/service"
// package, because it kills the process first we can't start a new
// instance, because Windows doesn't allow it.
//
// TODO(a.garipov): Recheck the claim above.
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
err := cmd.Start()
err = cmd.Start()
if err != nil {
log.Fatalf("exec.Command() failed: %s", err)
log.Fatalf("restarting: stopping: %s", err)
}
os.Exit(0)
}
cmd := exec.Command(curBinName, os.Args[1:]...)
log.Info("Restarting: %v", cmd.Args)
cmd := exec.Command(execPath, os.Args[1:]...)
log.Info("restarting: %q %q", execPath, os.Args[1:])
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
err = cmd.Start()
if err != nil {
log.Fatalf("exec.Command() failed: %s", err)
log.Fatalf("restarting:: %s", err)
}
os.Exit(0)
} else {
log.Info("Restarting: %v", os.Args)
err := syscall.Exec(curBinName, os.Args, os.Environ())
if err != nil {
log.Fatalf("syscall.Exec() failed: %s", err)
}
// Unreachable code
}
log.Info("restarting: %q %q", execPath, os.Args[1:])
err = syscall.Exec(execPath, os.Args, os.Environ())
if err != nil {
log.Fatalf("restarting: %s", err)
}
}

View File

@@ -17,7 +17,7 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/ameshkov/dnscrypt/v2"
yaml "gopkg.in/yaml.v2"
yaml "gopkg.in/yaml.v3"
)
// Default ports.
@@ -25,7 +25,7 @@ const (
defaultPortDNS = 53
defaultPortHTTP = 80
defaultPortHTTPS = 443
defaultPortQUIC = 784
defaultPortQUIC = 853
defaultPortTLS = 853
)
@@ -58,6 +58,7 @@ func initDNSServer() (err error) {
}
conf := querylog.Config{
Anonymizer: anonymizer,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
FindClient: Context.clients.findMultiple,
@@ -67,7 +68,6 @@ func initDNSServer() (err error) {
Enabled: config.DNS.QueryLogEnabled,
FileEnabled: config.DNS.QueryLogFileEnabled,
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
Anonymizer: anonymizer,
}
Context.queryLog = querylog.New(conf)
@@ -77,13 +77,36 @@ func initDNSServer() (err error) {
filterConf.HTTPRegister = httpRegister
Context.dnsFilter = filtering.New(&filterConf, nil)
var privateNets netutil.SubnetSet
switch len(config.DNS.PrivateNets) {
case 0:
// Use an optimized locally-served matcher.
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
case 1:
var n *net.IPNet
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = n
default:
var nets []*net.IPNet
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = netutil.SliceSubnetSet(nets)
}
p := dnsforward.DNSCreateParams{
DNSFilter: Context.dnsFilter,
Stats: Context.stats,
QueryLog: Context.queryLog,
SubnetDetector: Context.subnetDetector,
Anonymizer: anonymizer,
LocalDomain: config.DHCP.LocalDomainName,
DNSFilter: Context.dnsFilter,
Stats: Context.stats,
QueryLog: Context.queryLog,
PrivateNets: privateNets,
Anonymizer: anonymizer,
LocalDomain: config.DHCP.LocalDomainName,
}
if Context.dhcpServer != nil {
p.DHCPServer = Context.dhcpServer
@@ -112,8 +135,13 @@ func initDNSServer() (err error) {
return fmt.Errorf("dnsServer.Prepare: %w", err)
}
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
Context.whois = initWHOIS(&Context.clients)
if config.Clients.Sources.RDNS {
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
}
if config.Clients.Sources.WHOIS {
Context.whois = initWHOIS(&Context.clients)
}
Context.filters.Init()
return nil
@@ -130,10 +158,11 @@ func onDNSRequest(pctx *proxy.DNSContext) {
return
}
if config.DNS.ResolveClients && !ip.IsLoopback() {
srcs := config.Clients.Sources
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if !Context.subnetDetector.IsSpecialNetwork(ip) {
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
Context.whois.Begin(ip)
}
}
@@ -192,6 +221,10 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
newConf.TLSConfig = tlsConf.TLSConfig
newConf.TLSConfig.ServerName = tlsConf.ServerName
if tlsConf.PortHTTPS != 0 {
newConf.HTTPSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortHTTPS)
}
if tlsConf.PortDNSOverTLS != 0 {
newConf.TLSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortDNSOverTLS)
}
@@ -216,7 +249,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
newConf.FilterHandler = applyAdditionalFiltering
newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams
newConf.ResolveClients = dnsConf.ResolveClients
newConf.ResolveClients = config.Clients.Sources.RDNS
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers
newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration
@@ -301,24 +334,28 @@ func getDNSEncryption() (de dnsEncryption) {
// applyAdditionalFiltering adds additional client information and settings if
// the client has them.
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *filtering.Settings) {
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
if clientAddr == nil {
log.Debug("looking up settings for client with ip %s and clientid %q", clientIP, clientID)
if clientIP == nil {
return
}
setts.ClientIP = clientAddr
setts.ClientIP = clientIP
c, ok := Context.clients.Find(clientID)
if !ok {
c, ok = Context.clients.Find(clientAddr.String())
c, ok = Context.clients.Find(clientIP.String())
if !ok {
log.Debug("client with ip %s and clientid %q not found", clientIP, clientID)
return
}
}
log.Debug("using settings for client %s with ip %s and clientid %q", c.Name, clientAddr, clientID)
log.Debug("using settings for client %q with ip %s and clientid %q", c.Name, clientIP, clientID)
if c.UseOwnBlockedServices {
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
@@ -359,11 +396,16 @@ func startDNSServer() error {
Context.queryLog.Start()
const topClientsNumber = 100 // the number of clients to get
for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) {
if config.DNS.ResolveClients && !ip.IsLoopback() {
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
if ip == nil {
continue
}
srcs := config.Clients.Sources
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if !Context.subnetDetector.IsSpecialNetwork(ip) {
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
Context.whois.Begin(ip)
}
}
@@ -413,7 +455,12 @@ func closeDNSServer() {
}
if Context.stats != nil {
Context.stats.Close()
err := Context.stats.Close()
if err != nil {
log.Debug("closing stats: %s", err)
}
// TODO(e.burkov): Find out if it's safe.
Context.stats = nil
}

View File

@@ -47,7 +47,7 @@ type homeContext struct {
// --
clients clientsContainer // per-client-settings module
stats stats.Stats // statistics module
stats stats.Interface // statistics module
queryLog querylog.QueryLog // query log module
dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module
@@ -66,8 +66,6 @@ type homeContext struct {
updater *updater.Updater
subnetDetector *aghnet.SubnetDetector
// mux is our custom http.ServeMux.
mux *http.ServeMux
@@ -175,6 +173,11 @@ func setupContext(args options) {
os.Exit(0)
}
if !args.noEtcHosts && config.Clients.Sources.HostsFile {
err = setupHostsContainer()
fatalOnError(err)
}
}
Context.mux = http.NewServeMux()
@@ -287,41 +290,35 @@ func setupConfig(args options) (err error) {
ConfName: config.getConfigFilename(),
})
if !args.noEtcHosts {
if err = setupHostsContainer(); err != nil {
return err
}
}
var arpdb aghnet.ARPDB
arpdb, err = aghnet.NewARPDB()
if err != nil {
log.Info("warning: creating arpdb: %s; using stub", err)
arpdb = aghnet.EmptyARPDB{}
if config.Clients.Sources.ARP {
arpdb = aghnet.NewARPDB()
}
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts, arpdb)
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb)
if args.bindPort != 0 {
uc := aghalg.UniqChecker{}
addPorts(
uc,
args.bindPort,
config.BetaBindPort,
config.DNS.Port,
)
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(tcpPorts, tcpPort(args.bindPort), tcpPort(config.BetaBindPort))
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(config.DNS.Port))
if config.TLS.Enabled {
addPorts(
uc,
config.TLS.PortHTTPS,
config.TLS.PortDNSOverTLS,
config.TLS.PortDNSOverQUIC,
config.TLS.PortDNSCrypt,
tcpPorts,
tcpPort(config.TLS.PortHTTPS),
tcpPort(config.TLS.PortDNSOverTLS),
tcpPort(config.TLS.PortDNSCrypt),
)
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
}
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
return fmt.Errorf("validating ports: %w", err)
if err = tcpPorts.Validate(); err != nil {
return fmt.Errorf("validating tcp ports: %w", err)
} else if err = udpPorts.Validate(); err != nil {
return fmt.Errorf("validating udp ports: %w", err)
}
config.BindPort = args.bindPort
@@ -398,9 +395,6 @@ func run(args options, clientBuildFS fs.FS) {
// configure log level and output
configureLogger(args)
// Go memory hacks
memoryUsage(args)
// Print the first message after logger is configured.
log.Println(version.Full())
log.Debug("current working directory is %s", Context.workDir)
@@ -477,9 +471,6 @@ func run(args options, clientBuildFS fs.FS) {
Context.web, err = initWeb(args, clientBuildFS)
fatalOnError(err)
Context.subnetDetector, err = aghnet.NewSubnetDetector()
fatalOnError(err)
if !Context.firstRun {
err = initDNSServer()
fatalOnError(err)
@@ -529,27 +520,15 @@ func StartMods() error {
func checkPermissions() {
log.Info("Checking if AdGuard Home has necessary permissions")
if runtime.GOOS == "windows" {
// On Windows we need to have admin rights to run properly
admin, _ := aghos.HaveAdminRights()
if admin {
return
}
if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil {
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
}
// We should check if AdGuard Home is able to bind to port 53
ok, err := aghnet.CanBindPort(53)
if ok {
log.Info("AdGuard Home can bind to port 53")
return
}
if errors.Is(err, os.ErrPermission) {
msg := `Permission check failed.
err := aghnet.CheckPort("tcp", net.IP{127, 0, 0, 1}, defaultPortDNS)
if err != nil {
if errors.Is(err, os.ErrPermission) {
log.Fatal(`Permission check failed.
AdGuard Home is not allowed to bind to privileged ports (for instance, port 53).
Please note, that this is crucial for a server to be able to use privileged ports.
@@ -557,16 +536,17 @@ Please note, that this is crucial for a server to be able to use privileged port
You have two options:
1. Run AdGuard Home with root privileges
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`)
}
log.Fatal(msg)
log.Info(
"AdGuard failed to bind to port 53: %s\n\n"+
"Please note, that this is crucial for a DNS server to be able to use that port.",
err,
)
}
msg := fmt.Sprintf(`AdGuard failed to bind to port 53 due to %v
Please note, that this is crucial for a DNS server to be able to use that port.`, err)
log.Info(msg)
log.Info("AdGuard Home can bind to port 53")
}
// Write PID to a file
@@ -622,17 +602,17 @@ func configureLogger(args options) {
ls.Verbose = true
}
if args.logFile != "" {
ls.LogFile = args.logFile
} else if config.LogFile != "" {
ls.LogFile = config.LogFile
ls.File = args.logFile
} else if config.File != "" {
ls.File = config.File
}
// Handle default log settings overrides
ls.LogCompress = config.LogCompress
ls.LogLocalTime = config.LogLocalTime
ls.LogMaxBackups = config.LogMaxBackups
ls.LogMaxSize = config.LogMaxSize
ls.LogMaxAge = config.LogMaxAge
ls.Compress = config.Compress
ls.LocalTime = config.LocalTime
ls.MaxBackups = config.MaxBackups
ls.MaxSize = config.MaxSize
ls.MaxAge = config.MaxAge
// log.SetLevel(log.INFO) - default
if ls.Verbose {
@@ -643,27 +623,27 @@ func configureLogger(args options) {
// happen pretty quickly.
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
if args.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if nothing
// else is configured. Otherwise, we'll simply lose the log output.
ls.LogFile = configSyslog
ls.File = configSyslog
}
// logs are written to stdout (default)
if ls.LogFile == "" {
if ls.File == "" {
return
}
if ls.LogFile == configSyslog {
if ls.File == configSyslog {
// Use syslog where it is possible and eventlog on Windows
err := aghos.ConfigureSyslog(serviceName)
if err != nil {
log.Fatalf("cannot initialize syslog: %s", err)
}
} else {
logFilePath := filepath.Join(Context.workDir, ls.LogFile)
if filepath.IsAbs(ls.LogFile) {
logFilePath = ls.LogFile
logFilePath := filepath.Join(Context.workDir, ls.File)
if filepath.IsAbs(ls.File) {
logFilePath = ls.File
}
_, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
@@ -673,11 +653,11 @@ func configureLogger(args options) {
log.SetOutput(&lumberjack.Logger{
Filename: logFilePath,
Compress: ls.LogCompress, // disabled by default
LocalTime: ls.LogLocalTime,
MaxBackups: ls.LogMaxBackups,
MaxSize: ls.LogMaxSize, // megabytes
MaxAge: ls.LogMaxAge, // days
Compress: ls.Compress, // disabled by default
LocalTime: ls.LocalTime,
MaxBackups: ls.MaxBackups,
MaxSize: ls.MaxSize, // megabytes
MaxAge: ls.MaxAge, // days
})
}
}

View File

@@ -13,6 +13,7 @@ import (
// TODO(a.garipov): Get rid of a global or generate from .twosky.json.
var allowedLanguages = stringutil.NewSet(
"ar",
"be",
"bg",
"cs",
@@ -50,7 +51,7 @@ var allowedLanguages = stringutil.NewSet(
"zh-tw",
)
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
func handleI18nCurrentLanguage(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain")
log.Printf("config.Language is %s", config.Language)
_, err := fmt.Fprintf(w, "%s\n", config.Language)
@@ -58,6 +59,7 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
msg := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}
}
@@ -69,6 +71,7 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
msg := fmt.Sprintf("failed to read request body: %s", err)
log.Println(msg)
http.Error(w, msg, http.StatusBadRequest)
return
}

View File

@@ -1,40 +0,0 @@
package home
import (
"os"
"runtime/debug"
"time"
"github.com/AdguardTeam/golibs/log"
)
// memoryUsage implements a couple of not really beautiful hacks which purpose is to
// make OS reclaim the memory freed by AdGuard Home as soon as possible.
// See this for the details on the performance hits & gains:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/2044#issuecomment-687042211
func memoryUsage(args options) {
if args.disableMemoryOptimization {
log.Info("Memory optimization is disabled")
return
}
// Makes Go allocate heap at a slower pace
// By default we keep it at 50%
debug.SetGCPercent(50)
// madvdontneed: setting madvdontneed=1 will use MADV_DONTNEED
// instead of MADV_FREE on Linux when returning memory to the
// kernel. This is less efficient, but causes RSS numbers to drop
// more quickly.
_ = os.Setenv("GODEBUG", "madvdontneed=1")
// periodically call "debug.FreeOSMemory" so
// that the OS could reclaim the free memory
go func() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
log.Debug("free os memory")
debug.FreeOSMemory()
}
}()
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
uuid "github.com/satori/go.uuid"
"github.com/google/uuid"
"howett.net/plist"
)
@@ -47,9 +47,9 @@ type payloadContent struct {
PayloadType string
PayloadIdentifier string
PayloadUUID string
PayloadDisplayName string
PayloadDescription string
PayloadUUID uuid.UUID
PayloadVersion int
}
@@ -63,18 +63,14 @@ const dnsSettingsPayloadType = "com.apple.dnsSettings.managed"
type mobileConfig struct {
PayloadDescription string
PayloadDisplayName string
PayloadIdentifier string
PayloadType string
PayloadUUID string
PayloadContent []*payloadContent
PayloadIdentifier uuid.UUID
PayloadUUID uuid.UUID
PayloadVersion int
PayloadRemovalDisallowed bool
}
func genUUIDv4() string {
return uuid.NewV4().String()
}
const (
dnsProtoHTTPS = "HTTPS"
dnsProtoTLS = "TLS"
@@ -104,23 +100,23 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
return nil, fmt.Errorf("bad dns protocol %q", proto)
}
payloadID := fmt.Sprintf("%s.%s", dnsSettingsPayloadType, genUUIDv4())
payloadID := fmt.Sprintf("%s.%s", dnsSettingsPayloadType, uuid.New())
data := &mobileConfig{
PayloadDescription: "Adds AdGuard Home to macOS Big Sur " +
"and iOS 14 or newer systems",
PayloadDescription: "Adds AdGuard Home to macOS Big Sur and iOS 14 or newer systems",
PayloadDisplayName: dspName,
PayloadIdentifier: genUUIDv4(),
PayloadType: "Configuration",
PayloadUUID: genUUIDv4(),
PayloadContent: []*payloadContent{{
DNSSettings: d,
PayloadType: dnsSettingsPayloadType,
PayloadIdentifier: payloadID,
PayloadUUID: genUUIDv4(),
PayloadDisplayName: dspName,
PayloadDescription: "Configures device to use AdGuard Home",
PayloadUUID: uuid.New(),
PayloadVersion: 1,
DNSSettings: d,
}},
PayloadIdentifier: uuid.New(),
PayloadUUID: uuid.New(),
PayloadVersion: 1,
PayloadRemovalDisallowed: false,
}

View File

@@ -7,6 +7,7 @@ import (
"strconv"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log"
)
// options passed from command-line arguments
@@ -27,10 +28,6 @@ type options struct {
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
// disableMemoryOptimization - disables memory optimization hacks
// see memoryUsage() function for the details
disableMemoryOptimization bool
glinetMode bool // Activate GL-Inet compatibility mode
// noEtcHosts flag should be provided when /etc/hosts file shouldn't be
@@ -178,10 +175,14 @@ var noCheckUpdateArg = arg{
}
var disableMemoryOptimizationArg = arg{
"Disable memory optimization.",
"Deprecated. Disable memory optimization.",
"no-mem-optimization", "",
nil, func(o options) (options, error) { o.disableMemoryOptimization = true; return o, nil }, nil,
func(o options) []string { return boolSliceOrNil(o.disableMemoryOptimization) },
nil, nil, func(_ options, _ string) (f effect, err error) {
log.Info("warning: using --no-mem-optimization flag has no effect and is deprecated")
return nil, nil
},
func(o options) []string { return nil },
}
var verboseArg = arg{
@@ -229,13 +230,19 @@ var helpArg = arg{
}
var noEtcHostsArg = arg{
description: "Do not use the OS-provided hosts.",
description: "Deprecated. Do not use the OS-provided hosts.",
longName: "no-etc-hosts",
shortName: "",
updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.noEtcHosts = true; return o, nil },
effect: nil,
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) },
effect: func(_ options, _ string) (f effect, err error) {
log.Info(
"warning: --no-etc-hosts flag is deprecated and will be removed in the future versions",
)
return nil, nil
},
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) },
}
var localFrontendArg = arg{

View File

@@ -101,9 +101,13 @@ func TestParseDisableUpdate(t *testing.T) {
assert.True(t, testParseOK(t, "--no-check-update").disableUpdate, "--no-check-update is disable update")
}
// TODO(e.burkov): Remove after v0.108.0.
func TestParseDisableMemoryOptimization(t *testing.T) {
assert.False(t, testParseOK(t).disableMemoryOptimization, "empty is not disable update")
assert.True(t, testParseOK(t, "--no-mem-optimization").disableMemoryOptimization, "--no-mem-optimization is disable update")
o, eff, err := parse("", []string{"--no-mem-optimization"})
require.NoError(t, err)
assert.Nil(t, eff)
assert.Zero(t, o)
}
func TestParseService(t *testing.T) {
@@ -127,8 +131,6 @@ func TestParseUnknown(t *testing.T) {
}
func TestSerialize(t *testing.T) {
const reportFmt = "expected %s but got %s"
testCases := []struct {
name string
opts options
@@ -173,19 +175,14 @@ func TestSerialize(t *testing.T) {
name: "glinet_mode",
opts: options{glinetMode: true},
ss: []string{"--glinet"},
}, {
name: "disable_mem_opt",
opts: options{disableMemoryOptimization: true},
ss: []string{"--no-mem-optimization"},
}, {
name: "multiple",
opts: options{
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
disableMemoryOptimization: true,
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
},
ss: []string{
"-c", "config",
@@ -193,18 +190,13 @@ func TestSerialize(t *testing.T) {
"-s", "run",
"--pidfile", "pid",
"--no-check-update",
"--no-mem-optimization",
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := serialize(tc.opts)
require.Lenf(t, result, len(tc.ss), reportFmt, tc.ss, result)
for i, r := range result {
assert.Equalf(t, tc.ss[i], r, reportFmt, tc.ss, result)
}
assert.ElementsMatch(t, tc.ss, result)
})
}
}

View File

@@ -16,18 +16,17 @@ type RDNS struct {
exchanger dnsforward.RDNSExchanger
clients *clientsContainer
// usePrivate is used to store the state of current private RDNS
// resolving settings and to react to it's changes.
// usePrivate is used to store the state of current private RDNS resolving
// settings and to react to it's changes.
usePrivate uint32
// ipCh used to pass client's IP to rDNS workerLoop.
ipCh chan net.IP
// ipCache caches the IP addresses to be resolved by rDNS. The resolved
// address stays here while it's inside clients. After leaving clients
// the address will be resolved once again. If the address couldn't be
// resolved, cache prevents further attempts to resolve it for some
// time.
// address stays here while it's inside clients. After leaving clients the
// address will be resolved once again. If the address couldn't be
// resolved, cache prevents further attempts to resolve it for some time.
ipCache cache.Cache
}
@@ -125,14 +124,12 @@ func (r *RDNS) workerLoop() {
log.Debug("rdns: resolving %q: %s", ip, err)
continue
}
if host == "" {
} else if host == "" {
continue
}
// Don't handle any errors since AddHost doesn't return non-nil
// errors for now.
// Don't handle any errors since AddHost doesn't return non-nil errors
// for now.
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
}
}

Some files were not shown because too many files have changed in this diff Show More