Merge branch 'master' into 4728-cap-check
This commit is contained in:
@@ -1,32 +1,64 @@
|
||||
// Package aghalg contains common generic algorithms and data structures.
|
||||
//
|
||||
// TODO(a.garipov): Update to use type parameters in Go 1.18.
|
||||
// TODO(a.garipov): Move parts of this into golibs.
|
||||
package aghalg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// comparable is an alias for interface{}. Values passed as arguments of this
|
||||
// type alias must be comparable.
|
||||
// Coalesce returns the first non-zero value. It is named after function
|
||||
// COALESCE in SQL. If values or all its elements are empty, it returns a zero
|
||||
// value.
|
||||
//
|
||||
// TODO(a.garipov): Remove in Go 1.18.
|
||||
type comparable = interface{}
|
||||
// T is comparable, because Go currently doesn't have a comparableWithZeroValue
|
||||
// constraint.
|
||||
//
|
||||
// TODO(a.garipov): Think of ways to merge with [CoalesceSlice].
|
||||
func Coalesce[T comparable](values ...T) (res T) {
|
||||
var zero T
|
||||
for _, v := range values {
|
||||
if v != zero {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return zero
|
||||
}
|
||||
|
||||
// CoalesceSlice returns the first non-zero value. It is named after function
|
||||
// COALESCE in SQL. If values or all its elements are empty, it returns nil.
|
||||
//
|
||||
// TODO(a.garipov): Think of ways to merge with [Coalesce].
|
||||
func CoalesceSlice[E any, S []E](values ...S) (res S) {
|
||||
for _, v := range values {
|
||||
if v != nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UniqChecker allows validating uniqueness of comparable items.
|
||||
type UniqChecker map[comparable]int64
|
||||
//
|
||||
// TODO(a.garipov): The Ordered constraint is only really necessary in Validate.
|
||||
// Consider ways of making this constraint comparable instead.
|
||||
type UniqChecker[T constraints.Ordered] map[T]int64
|
||||
|
||||
// Add adds a value to the validator. v must not be nil.
|
||||
func (uc UniqChecker) Add(elems ...comparable) {
|
||||
func (uc UniqChecker[T]) Add(elems ...T) {
|
||||
for _, e := range elems {
|
||||
uc[e]++
|
||||
}
|
||||
}
|
||||
|
||||
// Merge returns a checker containing data from both uc and other.
|
||||
func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
|
||||
merged = make(UniqChecker, len(uc)+len(other))
|
||||
func (uc UniqChecker[T]) Merge(other UniqChecker[T]) (merged UniqChecker[T]) {
|
||||
merged = make(UniqChecker[T], len(uc)+len(other))
|
||||
for elem, num := range uc {
|
||||
merged[elem] += num
|
||||
}
|
||||
@@ -39,10 +71,8 @@ func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
|
||||
}
|
||||
|
||||
// Validate returns an error enumerating all elements that aren't unique.
|
||||
// isBefore is an optional sorting function to make the error message
|
||||
// deterministic.
|
||||
func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err error) {
|
||||
var dup []comparable
|
||||
func (uc UniqChecker[T]) Validate() (err error) {
|
||||
var dup []T
|
||||
for elem, num := range uc {
|
||||
if num > 1 {
|
||||
dup = append(dup, elem)
|
||||
@@ -53,23 +83,7 @@ func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err
|
||||
return nil
|
||||
}
|
||||
|
||||
if isBefore != nil {
|
||||
sort.Slice(dup, func(i, j int) (less bool) {
|
||||
return isBefore(dup[i], dup[j])
|
||||
})
|
||||
}
|
||||
slices.Sort(dup)
|
||||
|
||||
return fmt.Errorf("duplicated values: %v", dup)
|
||||
}
|
||||
|
||||
// IntIsBefore is a helper sort function for UniqChecker.Validate.
|
||||
// a and b must be of type int.
|
||||
func IntIsBefore(a, b comparable) (less bool) {
|
||||
return a.(int) < b.(int)
|
||||
}
|
||||
|
||||
// StringIsBefore is a helper sort function for UniqChecker.Validate.
|
||||
// a and b must be of type string.
|
||||
func StringIsBefore(a, b comparable) (less bool) {
|
||||
return a.(string) < b.(string)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/AdguardTeam/golibs/mathutil"
|
||||
)
|
||||
|
||||
// NullBool is a nullable boolean. Use these in JSON requests and responses
|
||||
@@ -33,11 +35,15 @@ func (nb NullBool) String() (s string) {
|
||||
|
||||
// BoolToNullBool converts a bool into a NullBool.
|
||||
func BoolToNullBool(cond bool) (nb NullBool) {
|
||||
if cond {
|
||||
return NBTrue
|
||||
}
|
||||
return NBFalse - mathutil.BoolToNumber[NullBool](cond)
|
||||
}
|
||||
|
||||
return NBFalse
|
||||
// type check
|
||||
var _ json.Marshaler = NBNull
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for NullBool.
|
||||
func (nb NullBool) MarshalJSON() (b []byte, err error) {
|
||||
return []byte(nb.String()), nil
|
||||
}
|
||||
|
||||
// type check
|
||||
|
||||
@@ -10,6 +10,52 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNullBool_MarshalJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
want []byte
|
||||
in aghalg.NullBool
|
||||
}{{
|
||||
name: "null",
|
||||
wantErrMsg: "",
|
||||
want: []byte("null"),
|
||||
in: aghalg.NBNull,
|
||||
}, {
|
||||
name: "true",
|
||||
wantErrMsg: "",
|
||||
want: []byte("true"),
|
||||
in: aghalg.NBTrue,
|
||||
}, {
|
||||
name: "false",
|
||||
wantErrMsg: "",
|
||||
want: []byte("false"),
|
||||
in: aghalg.NBFalse,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := tc.in.MarshalJSON()
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
in := &struct {
|
||||
A aghalg.NullBool
|
||||
}{
|
||||
A: aghalg.NBTrue,
|
||||
}
|
||||
|
||||
got, err := json.Marshal(in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []byte(`{"A":true}`), got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNullBool_UnmarshalJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
33
internal/aghchan/aghchan.go
Normal file
33
internal/aghchan/aghchan.go
Normal file
@@ -0,0 +1,33 @@
|
||||
// Package aghchan contains channel utilities.
|
||||
package aghchan
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Receive returns an error if it cannot receive a value form c before timeout
|
||||
// runs out.
|
||||
func Receive[T any](c <-chan T, timeout time.Duration) (v T, ok bool, err error) {
|
||||
var zero T
|
||||
timeoutCh := time.After(timeout)
|
||||
select {
|
||||
case <-timeoutCh:
|
||||
// TODO(a.garipov): Consider implementing [errors.Aser] for
|
||||
// os.ErrTimeout.
|
||||
return zero, false, fmt.Errorf("did not receive after %s", timeout)
|
||||
case v, ok = <-c:
|
||||
return v, ok, nil
|
||||
}
|
||||
}
|
||||
|
||||
// MustReceive panics if it cannot receive a value form c before timeout runs
|
||||
// out.
|
||||
func MustReceive[T any](c <-chan T, timeout time.Duration) (v T, ok bool) {
|
||||
v, ok, err := Receive(c, timeout)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return v, ok
|
||||
}
|
||||
@@ -2,13 +2,27 @@
|
||||
package aghhttp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// HTTP scheme constants.
|
||||
const (
|
||||
SchemeHTTP = "http"
|
||||
SchemeHTTPS = "https"
|
||||
)
|
||||
|
||||
// RegisterFunc is the function that sets the handler to handle the URL for the
|
||||
// method.
|
||||
//
|
||||
// TODO(e.burkov, a.garipov): Get rid of it.
|
||||
type RegisterFunc func(method, url string, handler http.HandlerFunc)
|
||||
|
||||
// OK responds with word OK.
|
||||
func OK(w http.ResponseWriter) {
|
||||
if _, err := io.WriteString(w, "OK\n"); err != nil {
|
||||
@@ -17,8 +31,52 @@ func OK(w http.ResponseWriter) {
|
||||
}
|
||||
|
||||
// Error writes formatted message to w and also logs it.
|
||||
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...any) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Error("%s %s: %s", r.Method, r.URL, text)
|
||||
log.Error("%s %s %s: %s", r.Method, r.Host, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
// UserAgent returns the ID of the service as a User-Agent string. It can also
|
||||
// be used as the value of the Server HTTP header.
|
||||
func UserAgent() (ua string) {
|
||||
return fmt.Sprintf("AdGuardHome/%s", version.Version())
|
||||
}
|
||||
|
||||
// textPlainDeprMsg is the message returned to API users when they try to use
|
||||
// an API that used to accept "text/plain" but doesn't anymore.
|
||||
const textPlainDeprMsg = `using this api with the text/plain content-type is deprecated; ` +
|
||||
`use application/json`
|
||||
|
||||
// WriteTextPlainDeprecated responds to the request with a message about
|
||||
// deprecation and removal of a plain-text API if the request is made with the
|
||||
// "text/plain" content-type.
|
||||
func WriteTextPlainDeprecated(w http.ResponseWriter, r *http.Request) (isPlainText bool) {
|
||||
if r.Header.Get(HdrNameContentType) != HdrValTextPlain {
|
||||
return false
|
||||
}
|
||||
|
||||
Error(r, w, http.StatusUnsupportedMediaType, textPlainDeprMsg)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// WriteJSONResponse sets the content-type header in w.Header() to
|
||||
// "application/json", writes a header with a "200 OK" status, encodes resp to
|
||||
// w, calls [Error] on any returned error, and returns it as well.
|
||||
func WriteJSONResponse(w http.ResponseWriter, r *http.Request, resp any) (err error) {
|
||||
return WriteJSONResponseCode(w, r, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// WriteJSONResponseCode is like [WriteJSONResponse] but adds the ability to
|
||||
// redefine the status code.
|
||||
func WriteJSONResponseCode(w http.ResponseWriter, r *http.Request, code int, resp any) (err error) {
|
||||
w.WriteHeader(code)
|
||||
w.Header().Set(HdrNameContentType, HdrValApplicationJSON)
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
Error(r, w, http.StatusInternalServerError, "encoding resp: %s", err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
25
internal/aghhttp/header.go
Normal file
25
internal/aghhttp/header.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package aghhttp
|
||||
|
||||
// HTTP Headers
|
||||
|
||||
// HTTP header name constants.
|
||||
//
|
||||
// TODO(a.garipov): Remove unused.
|
||||
const (
|
||||
HdrNameAcceptEncoding = "Accept-Encoding"
|
||||
HdrNameAccessControlAllowOrigin = "Access-Control-Allow-Origin"
|
||||
HdrNameAltSvc = "Alt-Svc"
|
||||
HdrNameContentEncoding = "Content-Encoding"
|
||||
HdrNameContentType = "Content-Type"
|
||||
HdrNameOrigin = "Origin"
|
||||
HdrNameServer = "Server"
|
||||
HdrNameTrailer = "Trailer"
|
||||
HdrNameUserAgent = "User-Agent"
|
||||
HdrNameVary = "Vary"
|
||||
)
|
||||
|
||||
// HTTP header value constants.
|
||||
const (
|
||||
HdrValApplicationJSON = "application/json"
|
||||
HdrValTextPlain = "text/plain"
|
||||
)
|
||||
@@ -4,6 +4,9 @@ package aghio
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/mathutil"
|
||||
)
|
||||
|
||||
// LimitReachedError records the limit and the operation that caused it.
|
||||
@@ -11,22 +14,22 @@ type LimitReachedError struct {
|
||||
Limit int64
|
||||
}
|
||||
|
||||
// Error implements the error interface for LimitReachedError.
|
||||
// Error implements the [error] interface for *LimitReachedError.
|
||||
//
|
||||
// TODO(a.garipov): Think about error string format.
|
||||
func (lre *LimitReachedError) Error() string {
|
||||
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
|
||||
}
|
||||
|
||||
// limitedReader is a wrapper for io.Reader with limited reader and dealing with
|
||||
// errors package.
|
||||
// limitedReader is a wrapper for [io.Reader] limiting the input and dealing
|
||||
// with errors package.
|
||||
type limitedReader struct {
|
||||
r io.Reader
|
||||
limit int64
|
||||
n int64
|
||||
}
|
||||
|
||||
// Read implements Reader interface.
|
||||
// Read implements the [io.Reader] interface.
|
||||
func (lr *limitedReader) Read(p []byte) (n int, err error) {
|
||||
if lr.n == 0 {
|
||||
return 0, &LimitReachedError{
|
||||
@@ -34,9 +37,7 @@ func (lr *limitedReader) Read(p []byte) (n int, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if int64(len(p)) > lr.n {
|
||||
p = p[:lr.n]
|
||||
}
|
||||
p = p[:mathutil.Min(lr.n, int64(len(p)))]
|
||||
|
||||
n, err = lr.r.Read(p)
|
||||
lr.n -= int64(n)
|
||||
@@ -48,7 +49,7 @@ func (lr *limitedReader) Read(p []byte) (n int, err error) {
|
||||
// n bytes read.
|
||||
func LimitReader(r io.Reader, n int64) (limited io.Reader, err error) {
|
||||
if n < 0 {
|
||||
return nil, fmt.Errorf("aghio: invalid n in LimitReader: %d", n)
|
||||
return nil, errors.Error("limit must be non-negative")
|
||||
}
|
||||
|
||||
return &limitedReader{
|
||||
|
||||
@@ -24,7 +24,7 @@ func TestLimitReader(t *testing.T) {
|
||||
name: "zero",
|
||||
n: 0,
|
||||
}, {
|
||||
wantErrMsg: "aghio: invalid n in LimitReader: -1",
|
||||
wantErrMsg: "limit must be non-negative",
|
||||
name: "negative",
|
||||
n: -1,
|
||||
}}
|
||||
|
||||
@@ -5,10 +5,11 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// ARPDB: The Network Neighborhood Database
|
||||
@@ -54,7 +55,7 @@ type Neighbor struct {
|
||||
Name string
|
||||
|
||||
// IP contains either IPv4 or IPv6.
|
||||
IP net.IP
|
||||
IP netip.Addr
|
||||
|
||||
// MAC contains the hardware address.
|
||||
MAC net.HardwareAddr
|
||||
@@ -64,8 +65,8 @@ type Neighbor struct {
|
||||
func (n Neighbor) Clone() (clone Neighbor) {
|
||||
return Neighbor{
|
||||
Name: n.Name,
|
||||
IP: netutil.CloneIP(n.IP),
|
||||
MAC: netutil.CloneMAC(n.MAC),
|
||||
IP: n.IP,
|
||||
MAC: slices.Clone(n.MAC),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
//go:build darwin || freebsd
|
||||
// +build darwin freebsd
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -33,8 +33,7 @@ func newARPDB() (arp *cmdARPDB) {
|
||||
// parseArpA parses the output of the "arp -a -n" command on macOS and FreeBSD.
|
||||
// The expected input format:
|
||||
//
|
||||
// host.name (192.168.0.1) at ff:ff:ff:ff:ff:ff on en0 ifscope [ethernet]
|
||||
//
|
||||
// host.name (192.168.0.1) at ff:ff:ff:ff:ff:ff on en0 ifscope [ethernet]
|
||||
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
ns = make([]Neighbor, 0, lenHint)
|
||||
for sc.Scan() {
|
||||
@@ -49,22 +48,28 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
|
||||
if ipStr := fields[1]; len(ipStr) < 2 {
|
||||
continue
|
||||
} else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil {
|
||||
} else if ip, err := netip.ParseAddr(ipStr[1 : len(ipStr)-1]); err != nil {
|
||||
log.Debug("arpdb: parsing arp output: ip: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
n.IP = ip
|
||||
}
|
||||
|
||||
hwStr := fields[3]
|
||||
if mac, err := net.ParseMAC(hwStr); err != nil {
|
||||
mac, err := net.ParseMAC(hwStr)
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: mac: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
n.MAC = mac
|
||||
}
|
||||
|
||||
host := fields[0]
|
||||
if err := netutil.ValidateDomainName(host); err != nil {
|
||||
log.Debug("parsing arp output: %s", err)
|
||||
err = netutil.ValidateDomainName(host)
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: host: %s", err)
|
||||
} else {
|
||||
n.Name = host
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
//go:build darwin || freebsd
|
||||
// +build darwin freebsd
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
const arpAOutput = `
|
||||
@@ -18,14 +18,14 @@ hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [
|
||||
|
||||
var wantNeighs = []Neighbor{{
|
||||
Name: "hostname.one",
|
||||
IP: net.IPv4(192, 168, 1, 2),
|
||||
IP: netip.MustParseAddr("192.168.1.2"),
|
||||
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
|
||||
}, {
|
||||
Name: "hostname.two",
|
||||
IP: net.ParseIP("::ffff:ffff"),
|
||||
IP: netip.MustParseAddr("::ffff:ffff"),
|
||||
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
|
||||
}, {
|
||||
Name: "",
|
||||
IP: net.ParseIP("::1234"),
|
||||
IP: netip.MustParseAddr("::1234"),
|
||||
MAC: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
|
||||
}}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
@@ -8,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -95,7 +95,8 @@ func (arp *fsysARPDB) Refresh() (err error) {
|
||||
}
|
||||
|
||||
n := Neighbor{}
|
||||
if n.IP = net.ParseIP(fields[0]); n.IP == nil || n.IP.IsUnspecified() {
|
||||
n.IP, err = netip.ParseAddr(fields[0])
|
||||
if err != nil || n.IP.IsUnspecified() {
|
||||
continue
|
||||
} else if n.MAC, err = net.ParseMAC(fields[3]); err != nil {
|
||||
continue
|
||||
@@ -117,9 +118,8 @@ func (arp *fsysARPDB) Neighbors() (ns []Neighbor) {
|
||||
// parseArpAWrt parses the output of the "arp -a -n" command on OpenWrt. The
|
||||
// expected input format:
|
||||
//
|
||||
// IP address HW type Flags HW address Mask Device
|
||||
// 192.168.11.98 0x1 0x2 5a:92:df:a9:7e:28 * wan
|
||||
//
|
||||
// IP address HW type Flags HW address Mask Device
|
||||
// 192.168.11.98 0x1 0x2 5a:92:df:a9:7e:28 * wan
|
||||
func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
if !sc.Scan() {
|
||||
// Skip the header.
|
||||
@@ -137,15 +137,19 @@ func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
|
||||
n := Neighbor{}
|
||||
|
||||
if ip := net.ParseIP(fields[0]); ip == nil || n.IP.IsUnspecified() {
|
||||
ip, err := netip.ParseAddr(fields[0])
|
||||
if err != nil || n.IP.IsUnspecified() {
|
||||
log.Debug("arpdb: parsing arp output: ip: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
n.IP = ip
|
||||
}
|
||||
|
||||
hwStr := fields[3]
|
||||
if mac, err := net.ParseMAC(hwStr); err != nil {
|
||||
log.Debug("parsing arp output: %s", err)
|
||||
mac, err := net.ParseMAC(hwStr)
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: mac: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
@@ -161,8 +165,7 @@ func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
// parseArpA parses the output of the "arp -a -n" command on Linux. The
|
||||
// expected input format:
|
||||
//
|
||||
// hostname (192.168.1.1) at ab:cd:ef:ab:cd:ef [ether] on enp0s3
|
||||
//
|
||||
// hostname (192.168.1.1) at ab:cd:ef:ab:cd:ef [ether] on enp0s3
|
||||
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
ns = make([]Neighbor, 0, lenHint)
|
||||
for sc.Scan() {
|
||||
@@ -177,7 +180,9 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
|
||||
if ipStr := fields[1]; len(ipStr) < 2 {
|
||||
continue
|
||||
} else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil {
|
||||
} else if ip, err := netip.ParseAddr(ipStr[1 : len(ipStr)-1]); err != nil {
|
||||
log.Debug("arpdb: parsing arp output: ip: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
n.IP = ip
|
||||
@@ -185,7 +190,7 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
|
||||
hwStr := fields[3]
|
||||
if mac, err := net.ParseMAC(hwStr); err != nil {
|
||||
log.Debug("parsing arp output: %s", err)
|
||||
log.Debug("arpdb: parsing arp output: mac: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
@@ -194,7 +199,7 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
|
||||
host := fields[0]
|
||||
if verr := netutil.ValidateDomainName(host); verr != nil {
|
||||
log.Debug("parsing arp output: %s", verr)
|
||||
log.Debug("arpdb: parsing arp output: host: %s", verr)
|
||||
} else {
|
||||
n.Name = host
|
||||
}
|
||||
@@ -208,8 +213,7 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
// parseIPNeigh parses the output of the "ip neigh" command on Linux. The
|
||||
// expected input format:
|
||||
//
|
||||
// 192.168.1.1 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef REACHABLE
|
||||
//
|
||||
// 192.168.1.1 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef REACHABLE
|
||||
func parseIPNeigh(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
ns = make([]Neighbor, 0, lenHint)
|
||||
for sc.Scan() {
|
||||
@@ -222,14 +226,18 @@ func parseIPNeigh(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
|
||||
n := Neighbor{}
|
||||
|
||||
if ip := net.ParseIP(fields[0]); ip == nil {
|
||||
ip, err := netip.ParseAddr(fields[0])
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: ip: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
n.IP = ip
|
||||
}
|
||||
|
||||
if mac, err := net.ParseMAC(fields[4]); err != nil {
|
||||
log.Debug("parsing arp output: %s", err)
|
||||
mac, err := net.ParseMAC(fields[4])
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: mac: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
@@ -34,10 +34,10 @@ const ipNeighOutput = `
|
||||
::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE`
|
||||
|
||||
var wantNeighs = []Neighbor{{
|
||||
IP: net.IPv4(192, 168, 1, 2),
|
||||
IP: netip.MustParseAddr("192.168.1.2"),
|
||||
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
|
||||
}, {
|
||||
IP: net.ParseIP("::ffff:ffff"),
|
||||
IP: netip.MustParseAddr("::ffff:ffff"),
|
||||
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
|
||||
}}
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
//go:build openbsd
|
||||
// +build openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -32,9 +32,8 @@ func newARPDB() (arp *cmdARPDB) {
|
||||
// parseArpA parses the output of the "arp -a -n" command on OpenBSD. The
|
||||
// expected input format:
|
||||
//
|
||||
// Host Ethernet Address Netif Expire Flags
|
||||
// 192.168.1.1 ab:cd:ef:ab:cd:ef em0 19m59s
|
||||
//
|
||||
// Host Ethernet Address Netif Expire Flags
|
||||
// 192.168.1.1 ab:cd:ef:ab:cd:ef em0 19m59s
|
||||
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
// Skip the header.
|
||||
if !sc.Scan() {
|
||||
@@ -52,14 +51,18 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
|
||||
n := Neighbor{}
|
||||
|
||||
if ip := net.ParseIP(fields[0]); ip == nil {
|
||||
ip, err := netip.ParseAddr(fields[0])
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: ip: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
n.IP = ip
|
||||
}
|
||||
|
||||
if mac, err := net.ParseMAC(fields[1]); err != nil {
|
||||
log.Debug("parsing arp output: %s", err)
|
||||
mac, err := net.ParseMAC(fields[1])
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: mac: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
//go:build openbsd
|
||||
// +build openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
const arpAOutput = `
|
||||
@@ -16,9 +16,9 @@ Host Ethernet Address Netif Expire Flags
|
||||
`
|
||||
|
||||
var wantNeighs = []Neighbor{{
|
||||
IP: net.IPv4(192, 168, 1, 2),
|
||||
IP: netip.MustParseAddr("192.168.1.2"),
|
||||
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
|
||||
}, {
|
||||
IP: net.ParseIP("::ffff:ffff"),
|
||||
IP: netip.MustParseAddr("::ffff:ffff"),
|
||||
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
|
||||
}}
|
||||
|
||||
@@ -2,6 +2,7 @@ package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -35,7 +36,7 @@ func (arp *TestARPDB) Neighbors() (ns []Neighbor) {
|
||||
}
|
||||
|
||||
func TestARPDBS(t *testing.T) {
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||
knownMAC := net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}
|
||||
|
||||
succRefrCount, failRefrCount := 0, 0
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
@@ -25,12 +25,10 @@ func newARPDB() (arp *cmdARPDB) {
|
||||
// parseArpA parses the output of the "arp /a" command on Windows. The expected
|
||||
// input format (the first line is empty):
|
||||
//
|
||||
//
|
||||
// Interface: 192.168.56.16 --- 0x7
|
||||
// Internet Address Physical Address Type
|
||||
// 192.168.56.1 0a-00-27-00-00-00 dynamic
|
||||
// 192.168.56.255 ff-ff-ff-ff-ff-ff static
|
||||
//
|
||||
// Interface: 192.168.56.16 --- 0x7
|
||||
// Internet Address Physical Address Type
|
||||
// 192.168.56.1 0a-00-27-00-00-00 dynamic
|
||||
// 192.168.56.255 ff-ff-ff-ff-ff-ff static
|
||||
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
ns = make([]Neighbor, 0, lenHint)
|
||||
for sc.Scan() {
|
||||
@@ -46,13 +44,15 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
|
||||
n := Neighbor{}
|
||||
|
||||
if ip := net.ParseIP(fields[0]); ip == nil {
|
||||
ip, err := netip.ParseAddr(fields[0])
|
||||
if err != nil {
|
||||
continue
|
||||
} else {
|
||||
n.IP = ip
|
||||
}
|
||||
|
||||
if mac, err := net.ParseMAC(fields[1]); err != nil {
|
||||
mac, err := net.ParseMAC(fields[1])
|
||||
if err != nil {
|
||||
continue
|
||||
} else {
|
||||
n.MAC = mac
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
const arpAOutput = `
|
||||
@@ -15,9 +15,9 @@ Interface: 192.168.1.1 --- 0x7
|
||||
::ffff:ffff ef-cd-ab-ef-cd-ab static`
|
||||
|
||||
var wantNeighs = []Neighbor{{
|
||||
IP: net.IPv4(192, 168, 1, 2),
|
||||
IP: netip.MustParseAddr("192.168.1.2"),
|
||||
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
|
||||
}, {
|
||||
IP: net.ParseIP("::ffff:ffff"),
|
||||
IP: netip.MustParseAddr("::ffff:ffff"),
|
||||
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
|
||||
}}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
@@ -7,6 +6,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -39,48 +39,44 @@ func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
|
||||
}
|
||||
|
||||
// ifaceIPv4Subnet returns the first suitable IPv4 subnetwork iface has.
|
||||
func ifaceIPv4Subnet(iface *net.Interface) (subnet *net.IPNet, err error) {
|
||||
func ifaceIPv4Subnet(iface *net.Interface) (subnet netip.Prefix, err error) {
|
||||
var addrs []net.Addr
|
||||
if addrs, err = iface.Addrs(); err != nil {
|
||||
return nil, err
|
||||
return netip.Prefix{}, err
|
||||
}
|
||||
|
||||
for _, a := range addrs {
|
||||
var ip net.IP
|
||||
var maskLen int
|
||||
switch a := a.(type) {
|
||||
case *net.IPAddr:
|
||||
subnet = &net.IPNet{
|
||||
IP: a.IP,
|
||||
Mask: a.IP.DefaultMask(),
|
||||
}
|
||||
ip = a.IP
|
||||
maskLen, _ = ip.DefaultMask().Size()
|
||||
case *net.IPNet:
|
||||
subnet = a
|
||||
ip = a.IP
|
||||
maskLen, _ = a.Mask.Size()
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if ip4 := subnet.IP.To4(); ip4 != nil {
|
||||
subnet.IP = ip4
|
||||
|
||||
return subnet, nil
|
||||
if ip = ip.To4(); ip != nil {
|
||||
return netip.PrefixFrom(netip.AddrFrom4(*(*[4]byte)(ip)), maskLen), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("interface %s has no ipv4 addresses", iface.Name)
|
||||
return netip.Prefix{}, fmt.Errorf("interface %s has no ipv4 addresses", iface.Name)
|
||||
}
|
||||
|
||||
// checkOtherDHCPv4 sends a DHCP request to the specified network interface, and
|
||||
// waits for a response for a period defined by defaultDiscoverTime.
|
||||
func checkOtherDHCPv4(iface *net.Interface) (ok bool, err error) {
|
||||
var subnet *net.IPNet
|
||||
var subnet netip.Prefix
|
||||
if subnet, err = ifaceIPv4Subnet(iface); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Resolve broadcast addr.
|
||||
dst := netutil.IPPort{
|
||||
IP: BroadcastFromIPNet(subnet),
|
||||
Port: 67,
|
||||
}.String()
|
||||
dst := netip.AddrPortFrom(BroadcastFromPref(subnet), 67).String()
|
||||
var dstAddr *net.UDPAddr
|
||||
if dstAddr, err = net.ResolveUDPAddr("udp4", dst); err != nil {
|
||||
return false, fmt.Errorf("couldn't resolve UDP address %s: %w", dst, err)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -45,11 +45,11 @@ func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
|
||||
// GenerateHostname generates the hostname from ip. In case of using IPv4 the
|
||||
// result should be like:
|
||||
//
|
||||
// 192-168-10-1
|
||||
// 192-168-10-1
|
||||
//
|
||||
// In case of using IPv6, the result is like:
|
||||
//
|
||||
// ff80-f076-0000-0000-0000-0000-0000-0010
|
||||
// ff80-f076-0000-0000-0000-0000-0000-0010
|
||||
//
|
||||
// ip must be either an IPv4 or an IPv6.
|
||||
func GenerateHostname(ip net.IP) (hostname string) {
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// DefaultHostsPaths returns the slice of paths default for the operating system
|
||||
@@ -55,7 +56,7 @@ func (rm *requestMatcher) MatchRequest(
|
||||
) (res *urlfilter.DNSResult, ok bool) {
|
||||
switch req.DNSType {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
|
||||
log.Debug("%s: handling the request", hostsContainerPref)
|
||||
log.Debug("%s: handling the request for %s", hostsContainerPref, req.Hostname)
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
@@ -70,8 +71,7 @@ func (rm *requestMatcher) MatchRequest(
|
||||
// rule or an empty string if the last doesn't exist. The returned rules are in
|
||||
// a processed format like:
|
||||
//
|
||||
// ip host1 host2 ...
|
||||
//
|
||||
// ip host1 host2 ...
|
||||
func (rm *requestMatcher) Translate(rule string) (hostRule string) {
|
||||
rm.stateLock.RLock()
|
||||
defer rm.stateLock.RUnlock()
|
||||
@@ -107,10 +107,10 @@ type HostsContainer struct {
|
||||
done chan struct{}
|
||||
|
||||
// updates is the channel for receiving updated hosts.
|
||||
updates chan *netutil.IPMap
|
||||
updates chan HostsRecords
|
||||
|
||||
// last is the set of hosts that was cached within last detected change.
|
||||
last *netutil.IPMap
|
||||
last HostsRecords
|
||||
|
||||
// fsys is the working file system to read hosts files from.
|
||||
fsys fs.FS
|
||||
@@ -125,6 +125,27 @@ type HostsContainer struct {
|
||||
listID int
|
||||
}
|
||||
|
||||
// HostsRecords is a mapping of an IP address to its hosts data.
|
||||
type HostsRecords map[netip.Addr]*HostsRecord
|
||||
|
||||
// HostsRecord represents a single hosts file record.
|
||||
type HostsRecord struct {
|
||||
Aliases *stringutil.Set
|
||||
Canonical string
|
||||
}
|
||||
|
||||
// equal returns true if all fields of rec are equal to field in other or they
|
||||
// both are nil.
|
||||
func (rec *HostsRecord) equal(other *HostsRecord) (ok bool) {
|
||||
if rec == nil {
|
||||
return other == nil
|
||||
} else if other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return rec.Canonical == other.Canonical && rec.Aliases.Equal(other.Aliases)
|
||||
}
|
||||
|
||||
// ErrNoHostsPaths is returned when there are no valid paths to watch passed to
|
||||
// the HostsContainer.
|
||||
const ErrNoHostsPaths errors.Error = "no valid paths to hosts files provided"
|
||||
@@ -159,7 +180,7 @@ func NewHostsContainer(
|
||||
},
|
||||
listID: listID,
|
||||
done: make(chan struct{}, 1),
|
||||
updates: make(chan *netutil.IPMap, 1),
|
||||
updates: make(chan HostsRecords, 1),
|
||||
fsys: fsys,
|
||||
w: w,
|
||||
patterns: patterns,
|
||||
@@ -197,9 +218,8 @@ func (hc *HostsContainer) Close() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Upd returns the channel into which the updates are sent. The receivable
|
||||
// map's values are guaranteed to be of type of *stringutil.Set.
|
||||
func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) {
|
||||
// Upd returns the channel into which the updates are sent.
|
||||
func (hc *HostsContainer) Upd() (updates <-chan HostsRecords) {
|
||||
return hc.updates
|
||||
}
|
||||
|
||||
@@ -265,7 +285,7 @@ type hostsParser struct {
|
||||
|
||||
// table stores only the unique IP-hostname pairs. It's also sent to the
|
||||
// updates channel afterwards.
|
||||
table *netutil.IPMap
|
||||
table HostsRecords
|
||||
}
|
||||
|
||||
// newHostsParser creates a new *hostsParser with buffers of size taken from the
|
||||
@@ -274,7 +294,7 @@ func (hc *HostsContainer) newHostsParser() (hp *hostsParser) {
|
||||
return &hostsParser{
|
||||
rulesBuilder: &strings.Builder{},
|
||||
translations: map[string]string{},
|
||||
table: netutil.NewIPMap(hc.last.Len()),
|
||||
table: make(HostsRecords, len(hc.last)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,25 +306,26 @@ func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
ip, hosts := hp.parseLine(s.Text())
|
||||
if ip == nil || len(hosts) == 0 {
|
||||
if ip == (netip.Addr{}) || len(hosts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
hp.addPairs(ip, hosts)
|
||||
hp.addRecord(ip, hosts)
|
||||
}
|
||||
|
||||
return nil, true, s.Err()
|
||||
}
|
||||
|
||||
// parseLine parses the line having the hosts syntax ignoring invalid ones.
|
||||
func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
|
||||
func (hp *hostsParser) parseLine(line string) (ip netip.Addr, hosts []string) {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
return nil, nil
|
||||
return netip.Addr{}, nil
|
||||
}
|
||||
|
||||
if ip = net.ParseIP(fields[0]); ip == nil {
|
||||
return nil, nil
|
||||
ip, err := netip.ParseAddr(fields[0])
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil
|
||||
}
|
||||
|
||||
for _, f := range fields[1:] {
|
||||
@@ -322,7 +343,7 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/3946.
|
||||
//
|
||||
// TODO(e.burkov): Investigate if hosts may contain DNS-SD domains.
|
||||
err := netutil.ValidateDomainName(f)
|
||||
err = netutil.ValidateDomainName(f)
|
||||
if err != nil {
|
||||
log.Error("%s: host %q is invalid, ignoring", hostsContainerPref, f)
|
||||
|
||||
@@ -335,43 +356,47 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
|
||||
return ip, hosts
|
||||
}
|
||||
|
||||
// addPair puts the pair of ip and host to the rules builder if needed. For
|
||||
// each ip the first member of hosts will become the main one.
|
||||
func (hp *hostsParser) addPairs(ip net.IP, hosts []string) {
|
||||
v, ok := hp.table.Get(ip)
|
||||
// addRecord puts the record for the IP address to the rules builder if needed.
|
||||
// The first host is considered to be the canonical name for the IP address.
|
||||
// hosts must have at least one name.
|
||||
func (hp *hostsParser) addRecord(ip netip.Addr, hosts []string) {
|
||||
line := strings.Join(append([]string{ip.String()}, hosts...), " ")
|
||||
|
||||
rec, ok := hp.table[ip]
|
||||
if !ok {
|
||||
// This ip is added at the first time.
|
||||
v = stringutil.NewSet()
|
||||
hp.table.Set(ip, v)
|
||||
rec = &HostsRecord{
|
||||
Aliases: stringutil.NewSet(),
|
||||
}
|
||||
|
||||
rec.Canonical, hosts = hosts[0], hosts[1:]
|
||||
hp.addRules(ip, rec.Canonical, line)
|
||||
hp.table[ip] = rec
|
||||
}
|
||||
|
||||
var set *stringutil.Set
|
||||
set, ok = v.(*stringutil.Set)
|
||||
if !ok {
|
||||
log.Debug("%s: adding pairs: unexpected value type %T", hostsContainerPref, v)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
processed := strings.Join(append([]string{ip.String()}, hosts...), " ")
|
||||
for _, h := range hosts {
|
||||
if set.Has(h) {
|
||||
for _, host := range hosts {
|
||||
if rec.Canonical == host || rec.Aliases.Has(host) {
|
||||
continue
|
||||
}
|
||||
|
||||
set.Add(h)
|
||||
rec.Aliases.Add(host)
|
||||
|
||||
rule, rulePtr := hp.writeRules(h, ip)
|
||||
hp.translations[rule], hp.translations[rulePtr] = processed, processed
|
||||
|
||||
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, h)
|
||||
hp.addRules(ip, host, line)
|
||||
}
|
||||
}
|
||||
|
||||
// addRules adds rules and rule translations for the line.
|
||||
func (hp *hostsParser) addRules(ip netip.Addr, host, line string) {
|
||||
rule, rulePtr := hp.writeRules(host, ip)
|
||||
hp.translations[rule], hp.translations[rulePtr] = line, line
|
||||
|
||||
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, host)
|
||||
}
|
||||
|
||||
// writeRules writes the actual rule for the qtype and the PTR for the host-ip
|
||||
// pair into internal builders.
|
||||
func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) {
|
||||
arpa, err := netutil.IPToReversedAddr(ip)
|
||||
func (hp *hostsParser) writeRules(host string, ip netip.Addr) (rule, rulePtr string) {
|
||||
// TODO(a.garipov): Add a netip.Addr version to netutil.
|
||||
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
@@ -389,7 +414,7 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string)
|
||||
var qtype string
|
||||
// The validation of the IP address has been performed earlier so it is
|
||||
// guaranteed to be either an IPv4 or an IPv6.
|
||||
if ip.To4() != nil {
|
||||
if ip.Is4() {
|
||||
qtype = "A"
|
||||
} else {
|
||||
qtype = "AAAA"
|
||||
@@ -416,37 +441,8 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string)
|
||||
return rule, rulePtr
|
||||
}
|
||||
|
||||
// equalSet returns true if the internal hosts table just parsed equals target.
|
||||
func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) {
|
||||
if target == nil {
|
||||
// hp.table shouldn't appear nil since it's initialized on each refresh.
|
||||
return target == hp.table
|
||||
}
|
||||
|
||||
if hp.table.Len() != target.Len() {
|
||||
return false
|
||||
}
|
||||
|
||||
hp.table.Range(func(ip net.IP, b interface{}) (cont bool) {
|
||||
// ok is set to true if the target doesn't contain ip or if the
|
||||
// appropriate hosts set isn't equal to the checked one.
|
||||
if a, hasIP := target.Get(ip); !hasIP {
|
||||
ok = true
|
||||
} else if hosts, aok := a.(*stringutil.Set); aok {
|
||||
ok = !hosts.Equal(b.(*stringutil.Set))
|
||||
}
|
||||
|
||||
// Continue only if maps has no discrepancies.
|
||||
return !ok
|
||||
})
|
||||
|
||||
// Return true if every value from the IP map has no discrepancies with the
|
||||
// appropriate one from the target.
|
||||
return !ok
|
||||
}
|
||||
|
||||
// sendUpd tries to send the parsed data to the ch.
|
||||
func (hp *hostsParser) sendUpd(ch chan *netutil.IPMap) {
|
||||
func (hp *hostsParser) sendUpd(ch chan HostsRecords) {
|
||||
log.Debug("%s: sending upd", hostsContainerPref)
|
||||
|
||||
upd := hp.table
|
||||
@@ -484,14 +480,15 @@ func (hc *HostsContainer) refresh() (err error) {
|
||||
return fmt.Errorf("refreshing : %w", err)
|
||||
}
|
||||
|
||||
if hp.equalSet(hc.last) {
|
||||
// hc.last is nil on the first refresh, so let that one through.
|
||||
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) {
|
||||
log.Debug("%s: no changes detected", hostsContainerPref)
|
||||
|
||||
return nil
|
||||
}
|
||||
defer hp.sendUpd(hc.updates)
|
||||
|
||||
hc.last = hp.table.ShallowClone()
|
||||
hc.last = maps.Clone(hp.table)
|
||||
|
||||
var rulesStrg *filterlist.RuleStorage
|
||||
if rulesStrg, err = hp.newStrg(hc.listID); err != nil {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !(windows || linux)
|
||||
// +build !windows,!linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package aghnet
|
||||
import (
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"path"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -10,8 +11,10 @@ import (
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
@@ -134,7 +137,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||
func TestHostsContainer_refresh(t *testing.T) {
|
||||
// TODO(e.burkov): Test the case with no actual updates.
|
||||
|
||||
ip := net.IP{127, 0, 0, 1}
|
||||
ip := netutil.IPv4Localhost()
|
||||
ipStr := ip.String()
|
||||
|
||||
testFS := fstest.MapFS{"dir/file1": &fstest.MapFile{Data: []byte(ipStr + ` hostname` + nl)}}
|
||||
@@ -159,31 +162,37 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||
|
||||
checkRefresh := func(t *testing.T, wantHosts *stringutil.Set) {
|
||||
upd, ok := <-hc.Upd()
|
||||
checkRefresh := func(t *testing.T, want *HostsRecord) {
|
||||
t.Helper()
|
||||
|
||||
upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upd)
|
||||
|
||||
assert.Equal(t, 1, upd.Len())
|
||||
assert.Len(t, upd, 1)
|
||||
|
||||
v, ok := upd.Get(ip)
|
||||
rec, ok := upd[ip]
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, rec)
|
||||
|
||||
var set *stringutil.Set
|
||||
set, ok = v.(*stringutil.Set)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.True(t, set.Equal(wantHosts))
|
||||
assert.Truef(t, rec.equal(want), "%+v != %+v", rec, want)
|
||||
}
|
||||
|
||||
t.Run("initial_refresh", func(t *testing.T) {
|
||||
checkRefresh(t, stringutil.NewSet("hostname"))
|
||||
checkRefresh(t, &HostsRecord{
|
||||
Aliases: stringutil.NewSet(),
|
||||
Canonical: "hostname",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("second_refresh", func(t *testing.T) {
|
||||
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
|
||||
eventsCh <- event{}
|
||||
checkRefresh(t, stringutil.NewSet("hostname", "alias"))
|
||||
|
||||
checkRefresh(t, &HostsRecord{
|
||||
Aliases: stringutil.NewSet("alias"),
|
||||
Canonical: "hostname",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("double_refresh", func(t *testing.T) {
|
||||
@@ -363,10 +372,15 @@ func TestHostsContainer(t *testing.T) {
|
||||
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
|
||||
|
||||
testCases := []struct {
|
||||
want []*rules.DNSRewrite
|
||||
name string
|
||||
req *urlfilter.DNSRequest
|
||||
name string
|
||||
want []*rules.DNSRewrite
|
||||
}{{
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "simplehost",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "simple",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 1),
|
||||
@@ -376,27 +390,12 @@ func TestHostsContainer(t *testing.T) {
|
||||
Value: net.ParseIP("::1"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "simple",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "simplehost",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 0),
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.ParseIP("::"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "hello_alias",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "hello.world",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
name: "hello_alias",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 0),
|
||||
@@ -406,26 +405,41 @@ func TestHostsContainer(t *testing.T) {
|
||||
Value: net.ParseIP("::"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "other_line_alias",
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "hello.world.again",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "other_line_alias",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 0),
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.ParseIP("::"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "hello_subdomain",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "say.hello",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
name: "hello_subdomain",
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "hello_alias_subdomain",
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "say.hello.world",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "hello_alias_subdomain",
|
||||
want: []*rules.DNSRewrite{},
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "for.testing",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "lots_of_aliases",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
@@ -435,37 +449,37 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::2"),
|
||||
}},
|
||||
name: "lots_of_aliases",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "for.testing",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "1.0.0.1.in-addr.arpa",
|
||||
DNSType: dns.TypePTR,
|
||||
},
|
||||
name: "reverse",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypePTR,
|
||||
Value: "simplehost.",
|
||||
}},
|
||||
name: "reverse",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "1.0.0.1.in-addr.arpa",
|
||||
DNSType: dns.TypePTR,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "non-existing",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "nonexisting",
|
||||
Hostname: "nonexistent.example",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "non-existing",
|
||||
want: []*rules.DNSRewrite{},
|
||||
}, {
|
||||
want: nil,
|
||||
name: "bad_type",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "1.0.0.1.in-addr.arpa",
|
||||
DNSType: dns.TypeSRV,
|
||||
},
|
||||
name: "bad_type",
|
||||
want: nil,
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "issue_4216_4_6",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
@@ -475,12 +489,12 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::42"),
|
||||
}},
|
||||
name: "issue_4216_4_6",
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain",
|
||||
Hostname: "domain4",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
name: "issue_4216_4",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
@@ -490,12 +504,12 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeA,
|
||||
Value: net.IPv4(1, 3, 5, 7),
|
||||
}},
|
||||
name: "issue_4216_4",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain4",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain6",
|
||||
DNSType: dns.TypeAAAA,
|
||||
},
|
||||
name: "issue_4216_6",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeAAAA,
|
||||
@@ -505,11 +519,6 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::31"),
|
||||
}},
|
||||
name: "issue_4216_6",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain6",
|
||||
DNSType: dns.TypeAAAA,
|
||||
},
|
||||
}}
|
||||
|
||||
stubWatcher := aghtest.FSWatcher{
|
||||
@@ -551,13 +560,13 @@ func TestHostsContainer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUniqueRules_ParseLine(t *testing.T) {
|
||||
ip := net.IP{127, 0, 0, 1}
|
||||
ip := netutil.IPv4Localhost()
|
||||
ipStr := ip.String()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
line string
|
||||
wantIP net.IP
|
||||
wantIP netip.Addr
|
||||
wantHosts []string
|
||||
}{{
|
||||
name: "simple",
|
||||
@@ -572,7 +581,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
|
||||
}, {
|
||||
name: "invalid_line",
|
||||
line: ipStr,
|
||||
wantIP: nil,
|
||||
wantIP: netip.Addr{},
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "invalid_line_hostname",
|
||||
@@ -587,7 +596,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
|
||||
}, {
|
||||
name: "whole_comment",
|
||||
line: `# ` + ipStr + ` hostname`,
|
||||
wantIP: nil,
|
||||
wantIP: netip.Addr{},
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "partial_comment",
|
||||
@@ -597,7 +606,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
|
||||
}, {
|
||||
name: "empty",
|
||||
line: ``,
|
||||
wantIP: nil,
|
||||
wantIP: netip.Addr{},
|
||||
wantHosts: nil,
|
||||
}}
|
||||
|
||||
@@ -605,7 +614,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
|
||||
hp := hostsParser{}
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, hosts := hp.parseLine(tc.line)
|
||||
assert.True(t, tc.wantIP.Equal(got))
|
||||
assert.Equal(t, tc.wantIP, got)
|
||||
assert.Equal(t, tc.wantHosts, hosts)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
@@ -16,7 +15,7 @@ import (
|
||||
func defaultHostsPaths() (paths []string) {
|
||||
sysDir, err := windows.GetSystemDirectory()
|
||||
if err != nil {
|
||||
log.Error("getting system directory: %s", err)
|
||||
log.Error("aghnet: getting system directory: %s", err)
|
||||
|
||||
return []string{}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ type IpsetManager interface {
|
||||
//
|
||||
// The syntax of the ipsetConf is:
|
||||
//
|
||||
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
|
||||
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
|
||||
//
|
||||
// If ipsetConf is empty, msg and err are nil. The error is of type
|
||||
// *aghos.UnsupportedError if the OS is not supported.
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
@@ -19,27 +18,18 @@ import (
|
||||
|
||||
// How to test on a real Linux machine:
|
||||
//
|
||||
// 1. Run:
|
||||
// 1. Run "sudo ipset create example_set hash:ip family ipv4".
|
||||
//
|
||||
// sudo ipset create example_set hash:ip family ipv4
|
||||
// 2. Run "sudo ipset list example_set". The Members field should be empty.
|
||||
//
|
||||
// 2. Run:
|
||||
// 3. Add the line "example.com/example_set" to your AdGuardHome.yaml.
|
||||
//
|
||||
// sudo ipset list example_set
|
||||
// 4. Start AdGuardHome.
|
||||
//
|
||||
// The Members field should be empty.
|
||||
// 5. Make requests to example.com and its subdomains.
|
||||
//
|
||||
// 3. Add the line "example.com/example_set" to your AdGuardHome.yaml.
|
||||
//
|
||||
// 4. Start AdGuardHome.
|
||||
//
|
||||
// 5. Make requests to example.com and its subdomains.
|
||||
//
|
||||
// 6. Run:
|
||||
//
|
||||
// sudo ipset list example_set
|
||||
//
|
||||
// The Members field should contain the resolved IP addresses.
|
||||
// 6. Run "sudo ipset list example_set". The Members field should contain the
|
||||
// resolved IP addresses.
|
||||
|
||||
// newIpsetMgr returns a new Linux ipset manager.
|
||||
func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
// Variables and functions to substitute in tests.
|
||||
@@ -47,26 +47,31 @@ func IfaceSetStaticIP(ifaceName string) (err error) {
|
||||
//
|
||||
// TODO(e.burkov): Investigate if the gateway address may be fetched in another
|
||||
// way since not every machine has the software installed.
|
||||
func GatewayIP(ifaceName string) (ip net.IP) {
|
||||
func GatewayIP(ifaceName string) (ip netip.Addr) {
|
||||
code, out, err := aghosRunCommand("ip", "route", "show", "dev", ifaceName)
|
||||
if err != nil {
|
||||
log.Debug("%s", err)
|
||||
|
||||
return nil
|
||||
return netip.Addr{}
|
||||
} else if code != 0 {
|
||||
log.Debug("fetching gateway ip: unexpected exit code: %d", code)
|
||||
|
||||
return nil
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
fields := bytes.Fields(out)
|
||||
// The meaningful "ip route" command output should contain the word
|
||||
// "default" at first field and default gateway IP address at third field.
|
||||
if len(fields) < 3 || string(fields[0]) != "default" {
|
||||
return nil
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
return net.ParseIP(string(fields[2]))
|
||||
ip, err = netip.ParseAddr(string(fields[2]))
|
||||
if err != nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
return ip
|
||||
}
|
||||
|
||||
// CanBindPrivilegedPorts checks if current process can bind to privileged
|
||||
@@ -78,9 +83,9 @@ func CanBindPrivilegedPorts() (can bool, err error) {
|
||||
// NetInterface represents an entry of network interfaces map.
|
||||
type NetInterface struct {
|
||||
// Addresses are the network interface addresses.
|
||||
Addresses []net.IP `json:"ip_addresses,omitempty"`
|
||||
Addresses []netip.Addr `json:"ip_addresses,omitempty"`
|
||||
// Subnets are the IP networks for this network interface.
|
||||
Subnets []*net.IPNet `json:"-"`
|
||||
Subnets []netip.Prefix `json:"-"`
|
||||
Name string `json:"name"`
|
||||
HardwareAddr net.HardwareAddr `json:"hardware_address"`
|
||||
Flags net.Flags `json:"flags"`
|
||||
@@ -101,63 +106,88 @@ func (iface NetInterface) MarshalJSON() ([]byte, error) {
|
||||
})
|
||||
}
|
||||
|
||||
func NetInterfaceFrom(iface *net.Interface) (niface *NetInterface, err error) {
|
||||
niface = &NetInterface{
|
||||
Name: iface.Name,
|
||||
HardwareAddr: iface.HardwareAddr,
|
||||
Flags: iface.Flags,
|
||||
MTU: iface.MTU,
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get addresses for interface %s: %w", iface.Name, err)
|
||||
}
|
||||
|
||||
// Collect network interface addresses.
|
||||
for _, addr := range addrs {
|
||||
n, ok := addr.(*net.IPNet)
|
||||
if !ok {
|
||||
// Should be *net.IPNet, this is weird.
|
||||
return nil, fmt.Errorf("expected %[2]s to be %[1]T, got %[2]T", n, addr)
|
||||
} else if ip4 := n.IP.To4(); ip4 != nil {
|
||||
n.IP = ip4
|
||||
}
|
||||
|
||||
ip, ok := netip.AddrFromSlice(n.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("bad address %s", n.IP)
|
||||
}
|
||||
|
||||
ip = ip.Unmap()
|
||||
if ip.IsLinkLocalUnicast() {
|
||||
// Ignore link-local IPv4.
|
||||
if ip.Is4() {
|
||||
continue
|
||||
}
|
||||
|
||||
ip = ip.WithZone(iface.Name)
|
||||
}
|
||||
|
||||
ones, _ := n.Mask.Size()
|
||||
p := netip.PrefixFrom(ip, ones)
|
||||
|
||||
niface.Addresses = append(niface.Addresses, ip)
|
||||
niface.Subnets = append(niface.Subnets, p)
|
||||
}
|
||||
|
||||
return niface, nil
|
||||
}
|
||||
|
||||
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and
|
||||
// WEB only we do not return link-local addresses here.
|
||||
//
|
||||
// TODO(e.burkov): Can't properly test the function since it's nontrivial to
|
||||
// substitute net.Interface.Addrs and the net.InterfaceAddrs can't be used.
|
||||
func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) {
|
||||
func GetValidNetInterfacesForWeb() (nifaces []*NetInterface, err error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't get interfaces: %w", err)
|
||||
return nil, fmt.Errorf("getting interfaces: %w", err)
|
||||
} else if len(ifaces) == 0 {
|
||||
return nil, errors.Error("couldn't find any legible interface")
|
||||
return nil, errors.Error("no legible interfaces")
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
var addrs []net.Addr
|
||||
addrs, err = iface.Addrs()
|
||||
for i := range ifaces {
|
||||
var niface *NetInterface
|
||||
niface, err = NetInterfaceFrom(&ifaces[i])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get addresses for interface %s: %w", iface.Name, err)
|
||||
}
|
||||
|
||||
netIface := &NetInterface{
|
||||
MTU: iface.MTU,
|
||||
Name: iface.Name,
|
||||
HardwareAddr: iface.HardwareAddr,
|
||||
Flags: iface.Flags,
|
||||
}
|
||||
|
||||
// Collect network interface addresses.
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if !ok {
|
||||
// Should be net.IPNet, this is weird.
|
||||
return nil, fmt.Errorf("got %s that is not net.IPNet, it is %T", addr, addr)
|
||||
}
|
||||
|
||||
// Ignore link-local.
|
||||
if ipNet.IP.IsLinkLocalUnicast() {
|
||||
continue
|
||||
}
|
||||
|
||||
netIface.Addresses = append(netIface.Addresses, ipNet.IP)
|
||||
netIface.Subnets = append(netIface.Subnets, ipNet)
|
||||
}
|
||||
|
||||
// Discard interfaces with no addresses.
|
||||
if len(netIface.Addresses) != 0 {
|
||||
netIfaces = append(netIfaces, netIface)
|
||||
return nil, err
|
||||
} else if len(niface.Addresses) != 0 {
|
||||
// Discard interfaces with no addresses.
|
||||
nifaces = append(nifaces, niface)
|
||||
}
|
||||
}
|
||||
|
||||
return netIfaces, nil
|
||||
return nifaces, nil
|
||||
}
|
||||
|
||||
// GetInterfaceByIP returns the name of interface containing provided ip.
|
||||
// InterfaceByIP returns the name of the interface bound to ip.
|
||||
//
|
||||
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb.
|
||||
func GetInterfaceByIP(ip net.IP) string {
|
||||
// TODO(a.garipov, e.burkov): This function is technically incorrect, since one
|
||||
// IP address can be shared by multiple interfaces in some configurations.
|
||||
//
|
||||
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
|
||||
func InterfaceByIP(ip netip.Addr) (ifaceName string) {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
return ""
|
||||
@@ -165,7 +195,7 @@ func GetInterfaceByIP(ip net.IP) string {
|
||||
|
||||
for _, iface := range ifaces {
|
||||
for _, addr := range iface.Addresses {
|
||||
if ip.Equal(addr) {
|
||||
if ip == addr {
|
||||
return iface.Name
|
||||
}
|
||||
}
|
||||
@@ -174,15 +204,16 @@ func GetInterfaceByIP(ip net.IP) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if
|
||||
// GetSubnet returns the subnet corresponding to the interface of zero prefix if
|
||||
// the search fails.
|
||||
//
|
||||
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb.
|
||||
func GetSubnet(ifaceName string) *net.IPNet {
|
||||
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
|
||||
func GetSubnet(ifaceName string) (p netip.Prefix) {
|
||||
netIfaces, err := GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
log.Error("Could not get network interfaces info: %v", err)
|
||||
return nil
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
for _, netIface := range netIfaces {
|
||||
@@ -191,14 +222,14 @@ func GetSubnet(ifaceName string) *net.IPNet {
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return p
|
||||
}
|
||||
|
||||
// CheckPort checks if the port is available for binding. network is expected
|
||||
// to be one of "udp" and "tcp".
|
||||
func CheckPort(network string, ip net.IP, port int) (err error) {
|
||||
func CheckPort(network string, ipp netip.AddrPort) (err error) {
|
||||
var c io.Closer
|
||||
addr := netutil.IPPort{IP: ip, Port: port}.String()
|
||||
addr := ipp.String()
|
||||
switch network {
|
||||
case "tcp":
|
||||
c, err = net.Listen(network, addr)
|
||||
@@ -248,18 +279,23 @@ func CollectAllIfacesAddrs() (addrs []string, err error) {
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// BroadcastFromIPNet calculates the broadcast IP address for n.
|
||||
func BroadcastFromIPNet(n *net.IPNet) (dc net.IP) {
|
||||
dc = netutil.CloneIP(n.IP)
|
||||
|
||||
mask := n.Mask
|
||||
if mask == nil {
|
||||
mask = dc.DefaultMask()
|
||||
// BroadcastFromPref calculates the broadcast IP address for p.
|
||||
func BroadcastFromPref(p netip.Prefix) (bc netip.Addr) {
|
||||
bc = p.Addr().Unmap()
|
||||
if !bc.IsValid() {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
for i, b := range mask {
|
||||
dc[i] |= ^b
|
||||
maskLen, addrLen := p.Bits(), bc.BitLen()
|
||||
if maskLen == addrLen {
|
||||
return bc
|
||||
}
|
||||
|
||||
return dc
|
||||
ipBytes := bc.AsSlice()
|
||||
for i := maskLen; i < addrLen; i++ {
|
||||
ipBytes[i/8] |= 1 << (7 - (i % 8))
|
||||
}
|
||||
bc, _ = netip.AddrFromSlice(ipBytes)
|
||||
|
||||
return bc
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin || freebsd || openbsd
|
||||
// +build darwin freebsd openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin
|
||||
// +build darwin
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build freebsd
|
||||
// +build freebsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build freebsd
|
||||
// +build freebsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
@@ -7,12 +6,13 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
"golang.org/x/sys/unix"
|
||||
@@ -22,17 +22,27 @@ import (
|
||||
const dhcpcdConf = "etc/dhcpcd.conf"
|
||||
|
||||
func canBindPrivilegedPorts() (can bool, err error) {
|
||||
cnbs, err := unix.PrctlRetInt(
|
||||
res, err := unix.PrctlRetInt(
|
||||
unix.PR_CAP_AMBIENT,
|
||||
unix.PR_CAP_AMBIENT_RAISE,
|
||||
unix.CAP_NET_BIND_SERVICE,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINVAL) {
|
||||
// Older versions of Linux kernel do not support this. Print a
|
||||
// warning and check admin rights.
|
||||
log.Info("warning: cannot check capability cap_net_bind_service: %s", err)
|
||||
} else {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Don't check the error because it's always nil on Linux.
|
||||
adm, _ := aghos.HaveAdminRights()
|
||||
|
||||
return cnbs == 1 || adm, err
|
||||
return res == 1 || adm, nil
|
||||
}
|
||||
|
||||
// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
|
||||
@@ -141,7 +151,7 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
|
||||
// interface through dhcpcd.conf.
|
||||
func ifaceSetStaticIP(ifaceName string) (err error) {
|
||||
ipNet := GetSubnet(ifaceName)
|
||||
if ipNet.IP == nil {
|
||||
if !ipNet.Addr().IsValid() {
|
||||
return errors.Error("can't get IP address")
|
||||
}
|
||||
|
||||
@@ -164,7 +174,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
|
||||
|
||||
// dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that
|
||||
// configure the interface to have a static IP.
|
||||
func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf string) {
|
||||
func dhcpcdConfIface(ifaceName string, subnet netip.Prefix, gateway netip.Addr) (conf string) {
|
||||
b := &strings.Builder{}
|
||||
stringutil.WriteToBuilder(
|
||||
b,
|
||||
@@ -173,15 +183,15 @@ func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf stri
|
||||
" added by AdGuard Home.\ninterface ",
|
||||
ifaceName,
|
||||
"\nstatic ip_address=",
|
||||
ipNet.String(),
|
||||
subnet.String(),
|
||||
"\n",
|
||||
)
|
||||
|
||||
if gwIP != nil {
|
||||
stringutil.WriteToBuilder(b, "static routers=", gwIP.String(), "\n")
|
||||
if gateway != (netip.Addr{}) {
|
||||
stringutil.WriteToBuilder(b, "static routers=", gateway.String(), "\n")
|
||||
}
|
||||
|
||||
stringutil.WriteToBuilder(b, "static domain_name_servers=", ipNet.IP.String(), "\n\n")
|
||||
stringutil.WriteToBuilder(b, "static domain_name_servers=", subnet.Addr().String(), "\n\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build openbsd
|
||||
// +build openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build openbsd
|
||||
// +build openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testdata is the filesystem containing data for testing the package.
|
||||
@@ -93,34 +93,29 @@ func TestGatewayIP(t *testing.T) {
|
||||
const cmd = "ip route show dev " + ifaceName
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
shell mapShell
|
||||
want net.IP
|
||||
want netip.Addr
|
||||
name string
|
||||
}{{
|
||||
name: "success_v4",
|
||||
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
|
||||
want: net.IP{1, 2, 3, 4}.To16(),
|
||||
want: netip.MustParseAddr("1.2.3.4"),
|
||||
name: "success_v4",
|
||||
}, {
|
||||
name: "success_v6",
|
||||
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
|
||||
want: net.IP{
|
||||
0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0xFF, 0xFF,
|
||||
},
|
||||
want: netip.MustParseAddr("::ffff"),
|
||||
name: "success_v6",
|
||||
}, {
|
||||
name: "bad_output",
|
||||
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
|
||||
want: nil,
|
||||
want: netip.Addr{},
|
||||
name: "bad_output",
|
||||
}, {
|
||||
name: "err_runcmd",
|
||||
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
|
||||
want: nil,
|
||||
want: netip.Addr{},
|
||||
name: "err_runcmd",
|
||||
}, {
|
||||
name: "bad_code",
|
||||
shell: theOnlyCmd(cmd, 1, "", nil),
|
||||
want: nil,
|
||||
want: netip.Addr{},
|
||||
name: "bad_code",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -132,7 +127,7 @@ func TestGatewayIP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInterfaceByIP(t *testing.T) {
|
||||
func TestInterfaceByIP(t *testing.T) {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ifaces)
|
||||
@@ -142,7 +137,7 @@ func TestGetInterfaceByIP(t *testing.T) {
|
||||
require.NotEmpty(t, iface.Addresses)
|
||||
|
||||
for _, ip := range iface.Addresses {
|
||||
ifaceName := GetInterfaceByIP(ip)
|
||||
ifaceName := InterfaceByIP(ip)
|
||||
require.Equal(t, iface.Name, ifaceName)
|
||||
}
|
||||
})
|
||||
@@ -150,65 +145,61 @@ func TestGetInterfaceByIP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBroadcastFromIPNet(t *testing.T) {
|
||||
known6 := net.IP{
|
||||
1, 2, 3, 4,
|
||||
5, 6, 7, 8,
|
||||
9, 10, 11, 12,
|
||||
13, 14, 15, 16,
|
||||
}
|
||||
known4 := netip.MustParseAddr("192.168.0.1")
|
||||
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
|
||||
|
||||
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
subnet *net.IPNet
|
||||
want net.IP
|
||||
pref netip.Prefix
|
||||
want netip.Addr
|
||||
name string
|
||||
}{{
|
||||
pref: netip.PrefixFrom(known4, 0),
|
||||
want: fullBroadcast4,
|
||||
name: "full",
|
||||
subnet: &net.IPNet{
|
||||
IP: net.IP{192, 168, 0, 1},
|
||||
Mask: net.IPMask{255, 255, 15, 0},
|
||||
},
|
||||
want: net.IP{192, 168, 240, 255},
|
||||
}, {
|
||||
name: "ipv6_no_mask",
|
||||
subnet: &net.IPNet{
|
||||
IP: known6,
|
||||
},
|
||||
pref: netip.PrefixFrom(known4, 20),
|
||||
want: netip.MustParseAddr("192.168.15.255"),
|
||||
name: "full",
|
||||
}, {
|
||||
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
|
||||
want: known6,
|
||||
name: "ipv6_no_mask",
|
||||
}, {
|
||||
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
|
||||
want: known4,
|
||||
name: "ipv4_no_mask",
|
||||
subnet: &net.IPNet{
|
||||
IP: net.IP{192, 168, 1, 2},
|
||||
},
|
||||
want: net.IP{192, 168, 1, 255},
|
||||
}, {
|
||||
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||
want: fullBroadcast4,
|
||||
name: "unspecified",
|
||||
subnet: &net.IPNet{
|
||||
IP: net.IP{0, 0, 0, 0},
|
||||
Mask: net.IPMask{0, 0, 0, 0},
|
||||
},
|
||||
want: net.IPv4bcast,
|
||||
}, {
|
||||
pref: netip.Prefix{},
|
||||
want: netip.Addr{},
|
||||
name: "invalid",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
bc := BroadcastFromIPNet(tc.subnet)
|
||||
assert.True(t, bc.Equal(tc.want), bc)
|
||||
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPort(t *testing.T) {
|
||||
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
|
||||
|
||||
t.Run("tcp_bound", func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:")
|
||||
l, err := net.Listen("tcp", laddr.String())
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
ipp := netutil.IPPortFromAddr(l.Addr())
|
||||
require.NotNil(t, ipp)
|
||||
require.NotNil(t, ipp.IP)
|
||||
require.NotZero(t, ipp.Port)
|
||||
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
|
||||
require.Equal(t, laddr.Addr(), ipp.Addr())
|
||||
require.NotZero(t, ipp.Port())
|
||||
|
||||
err = CheckPort("tcp", ipp.IP, ipp.Port)
|
||||
err = CheckPort("tcp", ipp)
|
||||
target := &net.OpError{}
|
||||
require.ErrorAs(t, err, &target)
|
||||
|
||||
@@ -216,16 +207,15 @@ func TestCheckPort(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("udp_bound", func(t *testing.T) {
|
||||
conn, err := net.ListenPacket("udp", "127.0.0.1:")
|
||||
conn, err := net.ListenPacket("udp", laddr.String())
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, conn.Close)
|
||||
|
||||
ipp := netutil.IPPortFromAddr(conn.LocalAddr())
|
||||
require.NotNil(t, ipp)
|
||||
require.NotNil(t, ipp.IP)
|
||||
require.NotZero(t, ipp.Port)
|
||||
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
|
||||
require.Equal(t, laddr.Addr(), ipp.Addr())
|
||||
require.NotZero(t, ipp.Port())
|
||||
|
||||
err = CheckPort("udp", ipp.IP, ipp.Port)
|
||||
err = CheckPort("udp", ipp)
|
||||
target := &net.OpError{}
|
||||
require.ErrorAs(t, err, &target)
|
||||
|
||||
@@ -233,12 +223,12 @@ func TestCheckPort(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("bad_network", func(t *testing.T) {
|
||||
err := CheckPort("bad_network", nil, 0)
|
||||
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("can_bind", func(t *testing.T) {
|
||||
err := CheckPort("udp", net.IP{0, 0, 0, 0}, 0)
|
||||
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -322,18 +312,18 @@ func TestNetInterface_MarshalJSON(t *testing.T) {
|
||||
`"mtu":1500` +
|
||||
`}` + "\n"
|
||||
|
||||
ip4, ip6 := net.IP{1, 2, 3, 4}, net.IP{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
mask4, mask6 := net.CIDRMask(24, netutil.IPv4BitLen), net.CIDRMask(8, netutil.IPv6BitLen)
|
||||
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
|
||||
require.True(t, ok)
|
||||
|
||||
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
|
||||
require.True(t, ok)
|
||||
|
||||
net4 := netip.PrefixFrom(ip4, 24)
|
||||
net6 := netip.PrefixFrom(ip6, 8)
|
||||
|
||||
iface := &NetInterface{
|
||||
Addresses: []net.IP{ip4, ip6},
|
||||
Subnets: []*net.IPNet{{
|
||||
IP: ip4.Mask(mask4),
|
||||
Mask: mask4,
|
||||
}, {
|
||||
IP: ip6.Mask(mask6),
|
||||
Mask: mask6,
|
||||
}},
|
||||
Addresses: []netip.Addr{ip4, ip6},
|
||||
Subnets: []netip.Prefix{net4, net6},
|
||||
Name: "iface0",
|
||||
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
|
||||
Flags: net.FlagUp | net.FlagMulticast,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build openbsd || freebsd || linux || darwin
|
||||
// +build openbsd freebsd linux darwin
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ type SystemResolvers interface {
|
||||
}
|
||||
|
||||
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
|
||||
// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If
|
||||
// defined by refreshIvl. It disables auto-refreshing if refreshIvl is 0. If
|
||||
// nil is passed for hostGenFunc, the default generator will be used.
|
||||
func NewSystemResolvers(
|
||||
hostGenFunc HostGenFunc,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package aghnet
|
||||
|
||||
@@ -7,6 +6,7 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -18,9 +18,8 @@ func createTestSystemResolversImpl(
|
||||
t.Helper()
|
||||
|
||||
sr := createTestSystemResolvers(t, hostGenFunc)
|
||||
require.IsType(t, (*systemResolvers)(nil), sr)
|
||||
|
||||
return sr.(*systemResolvers)
|
||||
return testutil.RequireTypeAssert[*systemResolvers](t, sr)
|
||||
}
|
||||
|
||||
func TestSystemResolvers_Refresh(t *testing.T) {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
@@ -62,9 +61,8 @@ func writeExit(w io.WriteCloser) {
|
||||
// scanAddrs scans the DNS addresses from nslookup's output. The expected
|
||||
// output of nslookup looks like this:
|
||||
//
|
||||
// Default Server: 192-168-1-1.qualified.domain.ru
|
||||
// Address: 192.168.1.1
|
||||
//
|
||||
// Default Server: 192-168-1-1.qualified.domain.ru
|
||||
// Address: 192.168.1.1
|
||||
func scanAddrs(s *bufio.Scanner) (addrs []string) {
|
||||
for s.Scan() {
|
||||
line := strings.TrimSpace(s.Text())
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package aghos
|
||||
package aghos_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build mips || mips64
|
||||
// +build mips mips64
|
||||
|
||||
// This file is an adapted version of github.com/josharian/native.
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build amd64 || 386 || arm || arm64 || mipsle || mips64le || ppc64le
|
||||
// +build amd64 386 arm arm64 mipsle mips64le ppc64le
|
||||
|
||||
// This file is an adapted version of github.com/josharian/native.
|
||||
|
||||
|
||||
57
internal/aghos/filewalker_internal_test.go
Normal file
57
internal/aghos/filewalker_internal_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"path"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// errFS is an fs.FS implementation, method Open of which always returns
|
||||
// errFSOpen.
|
||||
type errFS struct{}
|
||||
|
||||
// errFSOpen is returned from errFS.Open.
|
||||
const errFSOpen errors.Error = "test open error"
|
||||
|
||||
// Open implements the fs.FS interface for *errFS. fsys is always nil and err
|
||||
// is always errFSOpen.
|
||||
func (efs *errFS) Open(name string) (fsys fs.File, err error) {
|
||||
return nil, errFSOpen
|
||||
}
|
||||
|
||||
func TestWalkerFunc_CheckFile(t *testing.T) {
|
||||
emptyFS := fstest.MapFS{}
|
||||
|
||||
t.Run("non-existing", func(t *testing.T) {
|
||||
_, ok, err := checkFile(emptyFS, nil, "lol")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("invalid_argument", func(t *testing.T) {
|
||||
_, ok, err := checkFile(&errFS{}, nil, "")
|
||||
require.ErrorIs(t, err, errFSOpen)
|
||||
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("ignore_dirs", func(t *testing.T) {
|
||||
const dirName = "dir"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
|
||||
}
|
||||
|
||||
patterns, ok, err := checkFile(testFS, nil, dirName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, patterns)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
@@ -1,13 +1,13 @@
|
||||
package aghos
|
||||
package aghos_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"io/fs"
|
||||
"path"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
func TestFileWalker_Walk(t *testing.T) {
|
||||
const attribute = `000`
|
||||
|
||||
makeFileWalker := func(_ string) (fw FileWalker) {
|
||||
makeFileWalker := func(_ string) (fw aghos.FileWalker) {
|
||||
return func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
@@ -113,7 +113,7 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
f := fstest.MapFS{
|
||||
filename: &fstest.MapFile{Data: []byte("[]")},
|
||||
}
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
patterns = append(patterns, s.Text())
|
||||
@@ -134,7 +134,7 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
|
||||
}
|
||||
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
||||
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
||||
return nil, true, rerr
|
||||
}).Walk(f, "*")
|
||||
require.ErrorIs(t, err, rerr)
|
||||
@@ -142,45 +142,3 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
type errFS struct {
|
||||
fs.GlobFS
|
||||
}
|
||||
|
||||
const errErrFSOpen errors.Error = "this error is always returned"
|
||||
|
||||
func (efs *errFS) Open(name string) (fs.File, error) {
|
||||
return nil, errErrFSOpen
|
||||
}
|
||||
|
||||
func TestWalkerFunc_CheckFile(t *testing.T) {
|
||||
emptyFS := fstest.MapFS{}
|
||||
|
||||
t.Run("non-existing", func(t *testing.T) {
|
||||
_, ok, err := checkFile(emptyFS, nil, "lol")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("invalid_argument", func(t *testing.T) {
|
||||
_, ok, err := checkFile(&errFS{}, nil, "")
|
||||
require.ErrorIs(t, err, errErrFSOpen)
|
||||
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("ignore_dirs", func(t *testing.T) {
|
||||
const dirName = "dir"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
|
||||
}
|
||||
|
||||
patterns, ok, err := checkFile(testFS, nil, dirName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, patterns)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/mathutil"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// UnsupportedError is returned by functions and methods when a particular
|
||||
@@ -60,9 +62,8 @@ const MaxCmdOutputSize = 64 * 1024
|
||||
func RunCommand(command string, arguments ...string) (code int, output []byte, err error) {
|
||||
cmd := exec.Command(command, arguments...)
|
||||
out, err := cmd.Output()
|
||||
if len(out) > MaxCmdOutputSize {
|
||||
out = out[:MaxCmdOutputSize]
|
||||
}
|
||||
|
||||
out = out[:mathutil.Min(len(out), MaxCmdOutputSize)]
|
||||
|
||||
if err != nil {
|
||||
if eerr := new(exec.ExitError); errors.As(err, &eerr) {
|
||||
@@ -121,13 +122,12 @@ func PIDByCommand(command string, except ...int) (pid int, err error) {
|
||||
}
|
||||
|
||||
// parsePSOutput scans the output of ps searching the largest PID of the process
|
||||
// associated with cmdName ignoring PIDs from ignore. A valid line from
|
||||
// r should look like these:
|
||||
//
|
||||
// 123 ./example-cmd
|
||||
// 1230 some/base/path/example-cmd
|
||||
// 3210 example-cmd
|
||||
// associated with cmdName ignoring PIDs from ignore. A valid line from r
|
||||
// should look like these:
|
||||
//
|
||||
// 123 ./example-cmd
|
||||
// 1230 some/base/path/example-cmd
|
||||
// 3210 example-cmd
|
||||
func parsePSOutput(r io.Reader, cmdName string, ignore []int) (largest, instNum int, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
@@ -137,14 +137,12 @@ func parsePSOutput(r io.Reader, cmdName string, ignore []int) (largest, instNum
|
||||
}
|
||||
|
||||
cur, aerr := strconv.Atoi(fields[0])
|
||||
if aerr != nil || cur < 0 || intIn(cur, ignore) {
|
||||
if aerr != nil || cur < 0 || slices.Contains(ignore, cur) {
|
||||
continue
|
||||
}
|
||||
|
||||
instNum++
|
||||
if cur > largest {
|
||||
largest = cur
|
||||
}
|
||||
largest = mathutil.Max(largest, cur)
|
||||
}
|
||||
if err = s.Err(); err != nil {
|
||||
return 0, 0, fmt.Errorf("scanning stdout: %w", err)
|
||||
@@ -153,27 +151,21 @@ func parsePSOutput(r io.Reader, cmdName string, ignore []int) (largest, instNum
|
||||
return largest, instNum, nil
|
||||
}
|
||||
|
||||
// intIn returns true if nums contains n.
|
||||
func intIn(n int, nums []int) (ok bool) {
|
||||
for _, nn := range nums {
|
||||
if n == nn {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOpenWrt returns true if host OS is OpenWrt.
|
||||
func IsOpenWrt() (ok bool) {
|
||||
return isOpenWrt()
|
||||
}
|
||||
|
||||
// RootDirFS returns the fs.FS rooted at the operating system's root.
|
||||
// RootDirFS returns the [fs.FS] rooted at the operating system's root. On
|
||||
// Windows it returns the fs.FS rooted at the volume of the system directory
|
||||
// (usually, C:).
|
||||
func RootDirFS() (fsys fs.FS) {
|
||||
// Use empty string since os.DirFS implicitly prepends a slash to it. This
|
||||
// behavior is undocumented but it currently works.
|
||||
return os.DirFS("")
|
||||
return rootDirFS()
|
||||
}
|
||||
|
||||
// NotifyReconfigureSignal notifies c on receiving reconfigure signals.
|
||||
func NotifyReconfigureSignal(c chan<- os.Signal) {
|
||||
notifyReconfigureSignal(c)
|
||||
}
|
||||
|
||||
// NotifyShutdownSignal notifies c on receiving shutdown signals.
|
||||
@@ -181,6 +173,11 @@ func NotifyShutdownSignal(c chan<- os.Signal) {
|
||||
notifyShutdownSignal(c)
|
||||
}
|
||||
|
||||
// IsReconfigureSignal returns true if sig is a reconfigure signal.
|
||||
func IsReconfigureSignal(sig os.Signal) (ok bool) {
|
||||
return isReconfigureSignal(sig)
|
||||
}
|
||||
|
||||
// IsShutdownSignal returns true if sig is a shutdown signal.
|
||||
func IsShutdownSignal(sig os.Signal) (ok bool) {
|
||||
return isShutdownSignal(sig)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin || netbsd || openbsd
|
||||
// +build darwin netbsd openbsd
|
||||
//go:build darwin || openbsd
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build freebsd
|
||||
// +build freebsd
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,19 +1,31 @@
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
// +build darwin freebsd linux openbsd
|
||||
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func rootDirFS() (fsys fs.FS) {
|
||||
return os.DirFS("/")
|
||||
}
|
||||
|
||||
func notifyReconfigureSignal(c chan<- os.Signal) {
|
||||
signal.Notify(c, unix.SIGHUP)
|
||||
}
|
||||
|
||||
func notifyShutdownSignal(c chan<- os.Signal) {
|
||||
signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
|
||||
}
|
||||
|
||||
func isReconfigureSignal(sig os.Signal) (ok bool) {
|
||||
return sig == unix.SIGHUP
|
||||
}
|
||||
|
||||
func isShutdownSignal(sig os.Signal) (ok bool) {
|
||||
switch sig {
|
||||
case
|
||||
|
||||
@@ -1,16 +1,31 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func rootDirFS() (fsys fs.FS) {
|
||||
// TODO(a.garipov): Use a better way if golang/go#44279 is ever resolved.
|
||||
sysDir, err := windows.GetSystemDirectory()
|
||||
if err != nil {
|
||||
log.Error("aghos: getting root filesystem: %s; using C:", err)
|
||||
|
||||
// Assume that C: is the safe default.
|
||||
return os.DirFS("C:")
|
||||
}
|
||||
|
||||
return os.DirFS(filepath.VolumeName(sysDir))
|
||||
}
|
||||
|
||||
func setRlimit(val uint64) (err error) {
|
||||
return Unsupported("setrlimit")
|
||||
}
|
||||
@@ -40,17 +55,23 @@ func isOpenWrt() (ok bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
func notifyReconfigureSignal(c chan<- os.Signal) {
|
||||
signal.Notify(c, windows.SIGHUP)
|
||||
}
|
||||
|
||||
func notifyShutdownSignal(c chan<- os.Signal) {
|
||||
// syscall.SIGTERM is processed automatically. See go doc os/signal,
|
||||
// section Windows.
|
||||
signal.Notify(c, os.Interrupt)
|
||||
}
|
||||
|
||||
func isReconfigureSignal(sig os.Signal) (ok bool) {
|
||||
return sig == windows.SIGHUP
|
||||
}
|
||||
|
||||
func isShutdownSignal(sig os.Signal) (ok bool) {
|
||||
switch sig {
|
||||
case
|
||||
os.Interrupt,
|
||||
syscall.SIGTERM:
|
||||
case os.Interrupt, syscall.SIGTERM:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !(windows || plan9)
|
||||
// +build !windows,!plan9
|
||||
//go:build !windows
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows || plan9
|
||||
// +build windows plan9
|
||||
//go:build windows
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin || freebsd || linux || netbsd || openbsd
|
||||
// +build darwin freebsd linux netbsd openbsd
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -3,21 +3,11 @@ package aghtest
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// DiscardLogOutput runs tests with discarded logger output.
|
||||
func DiscardLogOutput(m *testing.M) {
|
||||
// TODO(e.burkov): Refactor code and tests to not use the global mutable
|
||||
// logger.
|
||||
log.SetOutput(io.Discard)
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// ReplaceLogWriter moves logger output to w and uses Cleanup method of t to
|
||||
// revert changes.
|
||||
func ReplaceLogWriter(t testing.TB, w io.Writer) {
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Exchanger is a mock aghnet.Exchanger implementation for tests.
|
||||
type Exchanger struct {
|
||||
Ups upstream.Upstream
|
||||
}
|
||||
|
||||
// Exchange implements aghnet.Exchanger interface for *Exchanger.
|
||||
func (e *Exchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
if e.Ups == nil {
|
||||
e.Ups = &TestErrUpstream{}
|
||||
}
|
||||
|
||||
return e.Ups.Exchange(req)
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package aghtest
|
||||
|
||||
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
|
||||
type FSWatcher struct {
|
||||
OnEvents func() (e <-chan struct{})
|
||||
OnAdd func(name string) (err error)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Events implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||
return w.OnEvents()
|
||||
}
|
||||
|
||||
// Add implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Add(name string) (err error) {
|
||||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Close implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Close() (err error) {
|
||||
return w.OnClose()
|
||||
}
|
||||
181
internal/aghtest/interface.go
Normal file
181
internal/aghtest/interface.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/fs"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Interface Mocks
|
||||
//
|
||||
// Keep entities in this file in alphabetic order.
|
||||
|
||||
// Standard Library
|
||||
|
||||
// Package fs
|
||||
|
||||
// type check
|
||||
var _ fs.FS = &FS{}
|
||||
|
||||
// FS is a mock [fs.FS] implementation for tests.
|
||||
type FS struct {
|
||||
OnOpen func(name string) (fs.File, error)
|
||||
}
|
||||
|
||||
// Open implements the [fs.FS] interface for *FS.
|
||||
func (fsys *FS) Open(name string) (fs.File, error) {
|
||||
return fsys.OnOpen(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.GlobFS = &GlobFS{}
|
||||
|
||||
// GlobFS is a mock [fs.GlobFS] implementation for tests.
|
||||
type GlobFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnGlob func(pattern string) ([]string, error)
|
||||
}
|
||||
|
||||
// Glob implements the [fs.GlobFS] interface for *GlobFS.
|
||||
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
|
||||
return fsys.OnGlob(pattern)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.StatFS = &StatFS{}
|
||||
|
||||
// StatFS is a mock [fs.StatFS] implementation for tests.
|
||||
type StatFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnStat func(name string) (fs.FileInfo, error)
|
||||
}
|
||||
|
||||
// Stat implements the [fs.StatFS] interface for *StatFS.
|
||||
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
|
||||
return fsys.OnStat(name)
|
||||
}
|
||||
|
||||
// Package net
|
||||
|
||||
// type check
|
||||
var _ net.Listener = (*Listener)(nil)
|
||||
|
||||
// Listener is a mock [net.Listener] implementation for tests.
|
||||
type Listener struct {
|
||||
OnAccept func() (conn net.Conn, err error)
|
||||
OnAddr func() (addr net.Addr)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Accept implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
return l.OnAccept()
|
||||
}
|
||||
|
||||
// Addr implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Addr() (addr net.Addr) {
|
||||
return l.OnAddr()
|
||||
}
|
||||
|
||||
// Close implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Close() (err error) {
|
||||
return l.OnClose()
|
||||
}
|
||||
|
||||
// Module adguard-home
|
||||
|
||||
// Package aghos
|
||||
|
||||
// type check
|
||||
var _ aghos.FSWatcher = (*FSWatcher)(nil)
|
||||
|
||||
// FSWatcher is a mock [aghos.FSWatcher] implementation for tests.
|
||||
type FSWatcher struct {
|
||||
OnEvents func() (e <-chan struct{})
|
||||
OnAdd func(name string) (err error)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||
return w.OnEvents()
|
||||
}
|
||||
|
||||
// Add implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Add(name string) (err error) {
|
||||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Close() (err error) {
|
||||
return w.OnClose()
|
||||
}
|
||||
|
||||
// Package agh
|
||||
|
||||
// type check
|
||||
var _ agh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil)
|
||||
|
||||
// ServiceWithConfig is a mock [agh.ServiceWithConfig] implementation for tests.
|
||||
type ServiceWithConfig[ConfigType any] struct {
|
||||
OnStart func() (err error)
|
||||
OnShutdown func(ctx context.Context) (err error)
|
||||
OnConfig func() (c ConfigType)
|
||||
}
|
||||
|
||||
// Start implements the [agh.ServiceWithConfig] interface for
|
||||
// *ServiceWithConfig.
|
||||
func (s *ServiceWithConfig[_]) Start() (err error) {
|
||||
return s.OnStart()
|
||||
}
|
||||
|
||||
// Shutdown implements the [agh.ServiceWithConfig] interface for
|
||||
// *ServiceWithConfig.
|
||||
func (s *ServiceWithConfig[_]) Shutdown(ctx context.Context) (err error) {
|
||||
return s.OnShutdown(ctx)
|
||||
}
|
||||
|
||||
// Config implements the [agh.ServiceWithConfig] interface for
|
||||
// *ServiceWithConfig.
|
||||
func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) {
|
||||
return s.OnConfig()
|
||||
}
|
||||
|
||||
// Module dnsproxy
|
||||
|
||||
// Package upstream
|
||||
|
||||
// type check
|
||||
var _ upstream.Upstream = (*UpstreamMock)(nil)
|
||||
|
||||
// UpstreamMock is a mock [upstream.Upstream] implementation for tests.
|
||||
//
|
||||
// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and
|
||||
// rename it to just Upstream.
|
||||
type UpstreamMock struct {
|
||||
OnAddress func() (addr string)
|
||||
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Address implements the [upstream.Upstream] interface for *UpstreamMock.
|
||||
func (u *UpstreamMock) Address() (addr string) {
|
||||
return u.OnAddress()
|
||||
}
|
||||
|
||||
// Exchange implements the [upstream.Upstream] interface for *UpstreamMock.
|
||||
func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return u.OnExchange(req)
|
||||
}
|
||||
|
||||
// Close implements the [upstream.Upstream] interface for *UpstreamMock.
|
||||
func (u *UpstreamMock) Close() (err error) {
|
||||
return u.OnClose()
|
||||
}
|
||||
3
internal/aghtest/interface_test.go
Normal file
3
internal/aghtest/interface_test.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package aghtest_test
|
||||
|
||||
// Put interface checks that cause import cycles here.
|
||||
@@ -1,46 +0,0 @@
|
||||
package aghtest
|
||||
|
||||
import "io/fs"
|
||||
|
||||
// type check
|
||||
var _ fs.FS = &FS{}
|
||||
|
||||
// FS is a mock fs.FS implementation to use in tests.
|
||||
type FS struct {
|
||||
OnOpen func(name string) (fs.File, error)
|
||||
}
|
||||
|
||||
// Open implements the fs.FS interface for *FS.
|
||||
func (fsys *FS) Open(name string) (fs.File, error) {
|
||||
return fsys.OnOpen(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.StatFS = &StatFS{}
|
||||
|
||||
// StatFS is a mock fs.StatFS implementation to use in tests.
|
||||
type StatFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnStat func(name string) (fs.FileInfo, error)
|
||||
}
|
||||
|
||||
// Stat implements the fs.StatFS interface for *StatFS.
|
||||
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
|
||||
return fsys.OnStat(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.GlobFS = &GlobFS{}
|
||||
|
||||
// GlobFS is a mock fs.GlobFS implementation to use in tests.
|
||||
type GlobFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnGlob func(pattern string) ([]string, error)
|
||||
}
|
||||
|
||||
// Glob implements the fs.GlobFS interface for *GlobFS.
|
||||
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
|
||||
return fsys.OnGlob(pattern)
|
||||
}
|
||||
@@ -5,13 +5,19 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Additional Upstream Testing Utilities
|
||||
|
||||
// Upstream is a mock implementation of upstream.Upstream.
|
||||
//
|
||||
// TODO(a.garipov): Replace with UpstreamMock and rename it to just Upstream.
|
||||
type Upstream struct {
|
||||
// CName is a map of hostname to canonical name.
|
||||
CName map[string][]string
|
||||
@@ -19,13 +25,11 @@ type Upstream struct {
|
||||
IPv4 map[string][]net.IP
|
||||
// IPv6 is a map of hostname to IPv6.
|
||||
IPv6 map[string][]net.IP
|
||||
// Reverse is a map of address to domain name.
|
||||
Reverse map[string][]string
|
||||
// Addr is the address for Address method.
|
||||
Addr string
|
||||
}
|
||||
|
||||
// Exchange implements the upstream.Upstream interface for *Upstream.
|
||||
var _ upstream.Upstream = (*Upstream)(nil)
|
||||
|
||||
// Exchange implements the [upstream.Upstream] interface for *Upstream.
|
||||
//
|
||||
// TODO(a.garipov): Split further into handlers.
|
||||
func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
||||
@@ -59,10 +63,6 @@ func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
||||
for _, ip := range u.IPv6[name] {
|
||||
resp.Answer = append(resp.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
|
||||
}
|
||||
case dns.TypePTR:
|
||||
for _, name := range u.Reverse[name] {
|
||||
resp.Answer = append(resp.Answer, &dns.PTR{Hdr: hdr, Ptr: name})
|
||||
}
|
||||
}
|
||||
if len(resp.Answer) == 0 {
|
||||
resp.SetRcode(m, dns.RcodeNameError)
|
||||
@@ -71,79 +71,157 @@ func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Address implements upstream.Upstream interface for *Upstream.
|
||||
// Address implements [upstream.Upstream] interface for *Upstream.
|
||||
func (u *Upstream) Address() string {
|
||||
return u.Addr
|
||||
return "todo.upstream.example"
|
||||
}
|
||||
|
||||
// TestBlockUpstream implements upstream.Upstream interface for replacing real
|
||||
// upstream in tests.
|
||||
type TestBlockUpstream struct {
|
||||
Hostname string
|
||||
|
||||
// lock protects reqNum.
|
||||
lock sync.RWMutex
|
||||
reqNum int
|
||||
|
||||
Block bool
|
||||
// Close implements [upstream.Upstream] interface for *Upstream.
|
||||
func (u *Upstream) Close() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
|
||||
// pair.
|
||||
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
u.reqNum++
|
||||
|
||||
hash := sha256.Sum256([]byte(u.Hostname))
|
||||
hashToReturn := hex.EncodeToString(hash[:])
|
||||
if !u.Block {
|
||||
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
|
||||
// MatchedResponse is a test helper that returns a response with answer if req
|
||||
// has question type qt, and target targ. Otherwise, it returns nil.
|
||||
//
|
||||
// req must not be nil and req.Question must have a length of 1. Answer is
|
||||
// interpreted in the following ways:
|
||||
//
|
||||
// - For A and AAAA queries, answer must be an IP address of the corresponding
|
||||
// protocol version.
|
||||
//
|
||||
// - For PTR queries, answer should be a domain name in the response.
|
||||
//
|
||||
// If the answer does not correspond to the question type, MatchedResponse panics.
|
||||
// Panics are used instead of [testing.TB], because the helper is intended to
|
||||
// use in [UpstreamMock.OnExchange] callbacks, which are usually called in a
|
||||
// separate goroutine.
|
||||
//
|
||||
// TODO(a.garipov): Consider adding version with DNS class as well.
|
||||
func MatchedResponse(req *dns.Msg, qt uint16, targ, answer string) (resp *dns.Msg) {
|
||||
if req == nil || len(req.Question) != 1 {
|
||||
panic(fmt.Errorf("bad req: %+v", req))
|
||||
}
|
||||
|
||||
m := &dns.Msg{}
|
||||
m.SetReply(r)
|
||||
m.Answer = []dns.RR{
|
||||
&dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: r.Question[0].Name,
|
||||
},
|
||||
Txt: []string{
|
||||
hashToReturn,
|
||||
},
|
||||
q := req.Question[0]
|
||||
targ = dns.Fqdn(targ)
|
||||
if q.Qclass != dns.ClassINET || q.Qtype != qt || q.Name != targ {
|
||||
return nil
|
||||
}
|
||||
|
||||
respHdr := dns.RR_Header{
|
||||
Name: targ,
|
||||
Rrtype: qt,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
}
|
||||
|
||||
resp = new(dns.Msg).SetReply(req)
|
||||
switch qt {
|
||||
case dns.TypeA:
|
||||
resp.Answer = mustAnsA(respHdr, answer)
|
||||
case dns.TypeAAAA:
|
||||
resp.Answer = mustAnsAAAA(respHdr, answer)
|
||||
case dns.TypePTR:
|
||||
resp.Answer = []dns.RR{&dns.PTR{
|
||||
Hdr: respHdr,
|
||||
Ptr: answer,
|
||||
}}
|
||||
default:
|
||||
panic(fmt.Errorf("aghtest: bad question type: %s", dns.Type(qt)))
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// mustAnsA returns valid answer records if s is a valid IPv4 address.
|
||||
// Otherwise, mustAnsA panics.
|
||||
func mustAnsA(respHdr dns.RR_Header, s string) (ans []dns.RR) {
|
||||
ip, err := netip.ParseAddr(s)
|
||||
if err != nil || !ip.Is4() {
|
||||
panic(fmt.Errorf("aghtest: bad A answer: %+v", s))
|
||||
}
|
||||
|
||||
return []dns.RR{&dns.A{
|
||||
Hdr: respHdr,
|
||||
A: ip.AsSlice(),
|
||||
}}
|
||||
}
|
||||
|
||||
// mustAnsAAAA returns valid answer records if s is a valid IPv6 address.
|
||||
// Otherwise, mustAnsAAAA panics.
|
||||
func mustAnsAAAA(respHdr dns.RR_Header, s string) (ans []dns.RR) {
|
||||
ip, err := netip.ParseAddr(s)
|
||||
if err != nil || !ip.Is6() {
|
||||
panic(fmt.Errorf("aghtest: bad AAAA answer: %+v", s))
|
||||
}
|
||||
|
||||
return []dns.RR{&dns.AAAA{
|
||||
Hdr: respHdr,
|
||||
AAAA: ip.AsSlice(),
|
||||
}}
|
||||
}
|
||||
|
||||
// NewUpstreamMock returns an [*UpstreamMock], fields OnAddress and OnClose of
|
||||
// which are set to stubs that return "upstream.example" and nil respectively.
|
||||
// The field OnExchange is set to onExc.
|
||||
func NewUpstreamMock(onExc func(req *dns.Msg) (resp *dns.Msg, err error)) (u *UpstreamMock) {
|
||||
return &UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "upstream.example" },
|
||||
OnExchange: onExc,
|
||||
OnClose: func() (err error) { return nil },
|
||||
}
|
||||
}
|
||||
|
||||
// NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that
|
||||
// supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is
|
||||
// true, hostname's actual hash is returned, blocking it. Otherwise, it returns
|
||||
// a different hash.
|
||||
func NewBlockUpstream(hostname string, shouldBlock bool) (u *UpstreamMock) {
|
||||
hash := sha256.Sum256([]byte(hostname))
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
if !shouldBlock {
|
||||
hashStr = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
|
||||
}
|
||||
|
||||
ans := &dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "",
|
||||
Rrtype: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
Txt: []string{hashStr},
|
||||
}
|
||||
respTmpl := &dns.Msg{
|
||||
Answer: []dns.RR{ans},
|
||||
}
|
||||
|
||||
return m, nil
|
||||
return &UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "sbpc.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = respTmpl.Copy()
|
||||
resp.SetReply(req)
|
||||
resp.Answer[0].(*dns.TXT).Hdr.Name = req.Question[0].Name
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
OnClose: func() (err error) { return nil },
|
||||
}
|
||||
}
|
||||
|
||||
// Address always returns an empty string.
|
||||
func (u *TestBlockUpstream) Address() string {
|
||||
return ""
|
||||
}
|
||||
// ErrUpstream is the error returned from the [*UpstreamMock] created by
|
||||
// [NewErrorUpstream].
|
||||
const ErrUpstream errors.Error = "test upstream error"
|
||||
|
||||
// RequestsCount returns the number of handled requests. It's safe for
|
||||
// concurrent use.
|
||||
func (u *TestBlockUpstream) RequestsCount() int {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
return u.reqNum
|
||||
}
|
||||
|
||||
// TestErrUpstream implements upstream.Upstream interface for replacing real
|
||||
// upstream in tests.
|
||||
type TestErrUpstream struct {
|
||||
// The error returned by Exchange may be unwrapped to the Err.
|
||||
Err error
|
||||
}
|
||||
|
||||
// Exchange always returns nil Msg and non-nil error.
|
||||
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
|
||||
return nil, fmt.Errorf("errupstream: %w", u.Err)
|
||||
}
|
||||
|
||||
// Address always returns an empty string.
|
||||
func (u *TestErrUpstream) Address() string {
|
||||
return ""
|
||||
// NewErrorUpstream returns an [*UpstreamMock] that returns [ErrUpstream] from
|
||||
// its Exchange method.
|
||||
func NewErrorUpstream() (u *UpstreamMock) {
|
||||
return &UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "error.upstream.example" },
|
||||
OnExchange: func(_ *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return nil, errors.Error("test upstream error")
|
||||
},
|
||||
OnClose: func() (err error) { return nil },
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,50 @@
|
||||
// Package aghtls contains utilities for work with TLS.
|
||||
package aghtls
|
||||
|
||||
import "crypto/tls"
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// init makes sure that the cipher name map is filled.
|
||||
//
|
||||
// TODO(a.garipov): Propose a similar API to crypto/tls.
|
||||
func init() {
|
||||
suites := tls.CipherSuites()
|
||||
cipherSuites = make(map[string]uint16, len(suites))
|
||||
for _, s := range suites {
|
||||
cipherSuites[s.Name] = s.ID
|
||||
}
|
||||
|
||||
log.Debug("tls: known ciphers: %q", cipherSuites)
|
||||
}
|
||||
|
||||
// cipherSuites are a name-to-ID mapping of cipher suites from crypto/tls. It
|
||||
// is filled by init. It must not be modified.
|
||||
var cipherSuites map[string]uint16
|
||||
|
||||
// ParseCiphers parses a slice of cipher suites from cipher names.
|
||||
func ParseCiphers(cipherNames []string) (cipherIDs []uint16, err error) {
|
||||
if cipherNames == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cipherIDs = make([]uint16, 0, len(cipherNames))
|
||||
for _, name := range cipherNames {
|
||||
id, ok := cipherSuites[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown cipher %q", name)
|
||||
}
|
||||
|
||||
cipherIDs = append(cipherIDs, id)
|
||||
}
|
||||
|
||||
return cipherIDs, nil
|
||||
}
|
||||
|
||||
// SaferCipherSuites returns a set of default cipher suites with vulnerable and
|
||||
// weak cipher suites removed.
|
||||
@@ -28,3 +71,19 @@ func SaferCipherSuites() (safe []uint16) {
|
||||
|
||||
return safe
|
||||
}
|
||||
|
||||
// CertificateHasIP returns true if cert has at least a single IP address among
|
||||
// its subjectAltNames.
|
||||
func CertificateHasIP(cert *x509.Certificate) (ok bool) {
|
||||
if len(cert.IPAddresses) > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, name := range cert.DNSNames {
|
||||
if _, err := netip.ParseAddr(name); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
56
internal/aghtls/aghtls_test.go
Normal file
56
internal/aghtls/aghtls_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package aghtls_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func TestParseCiphers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
want []uint16
|
||||
in []string
|
||||
}{{
|
||||
name: "nil",
|
||||
wantErrMsg: "",
|
||||
want: nil,
|
||||
in: nil,
|
||||
}, {
|
||||
name: "empty",
|
||||
wantErrMsg: "",
|
||||
want: []uint16{},
|
||||
in: []string{},
|
||||
}, {}, {
|
||||
name: "one",
|
||||
wantErrMsg: "",
|
||||
want: []uint16{tls.TLS_AES_128_GCM_SHA256},
|
||||
in: []string{"TLS_AES_128_GCM_SHA256"},
|
||||
}, {
|
||||
name: "several",
|
||||
wantErrMsg: "",
|
||||
want: []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384},
|
||||
in: []string{"TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384"},
|
||||
}, {
|
||||
name: "bad",
|
||||
wantErrMsg: `unknown cipher "bad_cipher"`,
|
||||
want: nil,
|
||||
in: []string{"bad_cipher"},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := aghtls.ParseCiphers(tc.in)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
14
internal/aghtls/root.go
Normal file
14
internal/aghtls/root.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package aghtls
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
// SystemRootCAs tries to load root certificates from the operating system. It
|
||||
// returns nil in case nothing is found so that Go' crypto/x509 can use its
|
||||
// default algorithm to find system root CA list.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/1311.
|
||||
func SystemRootCAs() (roots *x509.CertPool) {
|
||||
return rootCAs()
|
||||
}
|
||||
56
internal/aghtls/root_linux.go
Normal file
56
internal/aghtls/root_linux.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build linux
|
||||
|
||||
package aghtls
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
func rootCAs() (roots *x509.CertPool) {
|
||||
// Directories with the system root certificates, which aren't supported by
|
||||
// Go's crypto/x509.
|
||||
dirs := []string{
|
||||
// Entware.
|
||||
"/opt/etc/ssl/certs",
|
||||
}
|
||||
|
||||
roots = x509.NewCertPool()
|
||||
for _, dir := range dirs {
|
||||
dirEnts, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Improve error handling here and in other places.
|
||||
log.Error("aghtls: opening directory %q: %s", dir, err)
|
||||
}
|
||||
|
||||
var rootsAdded bool
|
||||
for _, de := range dirEnts {
|
||||
var certData []byte
|
||||
rootFile := filepath.Join(dir, de.Name())
|
||||
certData, err = os.ReadFile(rootFile)
|
||||
if err != nil {
|
||||
log.Error("aghtls: reading root cert: %s", err)
|
||||
} else {
|
||||
if roots.AppendCertsFromPEM(certData) {
|
||||
rootsAdded = true
|
||||
} else {
|
||||
log.Error("aghtls: could not add root from %q", rootFile)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rootsAdded {
|
||||
return roots
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
9
internal/aghtls/root_others.go
Normal file
9
internal/aghtls/root_others.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !linux
|
||||
|
||||
package aghtls
|
||||
|
||||
import "crypto/x509"
|
||||
|
||||
func rootCAs() (roots *x509.CertPool) {
|
||||
return nil
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build freebsd || openbsd
|
||||
// +build freebsd openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -10,11 +9,10 @@ import (
|
||||
// broadcast sends resp to the broadcast address specific for network interface.
|
||||
func (c *dhcpConn) broadcast(respData []byte, peer *net.UDPAddr) (n int, err error) {
|
||||
// Despite the fact that server4.NewIPv4UDPConn explicitly sets socket
|
||||
// options to allow broadcasting, it also binds the connection to a
|
||||
// specific interface. On FreeBSD and OpenBSD net.UDPConn.WriteTo
|
||||
// causes errors while writing to the addresses that belong to another
|
||||
// interface. So, use the broadcast address specific for the interface
|
||||
// bound.
|
||||
// options to allow broadcasting, it also binds the connection to a specific
|
||||
// interface. On FreeBSD and OpenBSD net.UDPConn.WriteTo causes errors
|
||||
// while writing to the addresses that belong to another interface. So, use
|
||||
// the broadcast address specific for the interface bound.
|
||||
peer.IP = c.bcastIP
|
||||
|
||||
return c.udpConn.WriteTo(respData, peer)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build freebsd || openbsd
|
||||
// +build freebsd openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || linux || netbsd || solaris
|
||||
// +build aix darwin dragonfly linux netbsd solaris
|
||||
//go:build darwin || linux
|
||||
|
||||
package dhcpd
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || linux || netbsd || solaris
|
||||
// +build aix darwin dragonfly linux netbsd solaris
|
||||
//go:build darwin || linux
|
||||
|
||||
package dhcpd
|
||||
|
||||
|
||||
222
internal/dhcpd/config.go
Normal file
222
internal/dhcpd/config.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// ServerConfig is the configuration for the DHCP server. The order of YAML
|
||||
// fields is important, since the YAML configuration file follows it.
|
||||
type ServerConfig struct {
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
|
||||
Enabled bool `yaml:"enabled"`
|
||||
InterfaceName string `yaml:"interface_name"`
|
||||
|
||||
// LocalDomainName is the domain name used for DHCP hosts. For example,
|
||||
// a DHCP client with the hostname "myhost" can be addressed as "myhost.lan"
|
||||
// when LocalDomainName is "lan".
|
||||
LocalDomainName string `yaml:"local_domain_name"`
|
||||
|
||||
Conf4 V4ServerConf `yaml:"dhcpv4"`
|
||||
Conf6 V6ServerConf `yaml:"dhcpv6"`
|
||||
|
||||
WorkDir string `yaml:"-"`
|
||||
DBFilePath string `yaml:"-"`
|
||||
}
|
||||
|
||||
// DHCPServer - DHCP server interface
|
||||
type DHCPServer interface {
|
||||
// ResetLeases resets leases.
|
||||
ResetLeases(leases []*Lease) (err error)
|
||||
// GetLeases returns deep clones of the current leases.
|
||||
GetLeases(flags GetLeasesFlags) (leases []*Lease)
|
||||
// AddStaticLease - add a static lease
|
||||
AddStaticLease(l *Lease) (err error)
|
||||
// RemoveStaticLease - remove a static lease
|
||||
RemoveStaticLease(l *Lease) (err error)
|
||||
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
|
||||
FindMACbyIP(ip net.IP) net.HardwareAddr
|
||||
|
||||
// WriteDiskConfig4 - copy disk configuration
|
||||
WriteDiskConfig4(c *V4ServerConf)
|
||||
// WriteDiskConfig6 - copy disk configuration
|
||||
WriteDiskConfig6(c *V6ServerConf)
|
||||
|
||||
// Start - start server
|
||||
Start() (err error)
|
||||
// Stop - stop server
|
||||
Stop() (err error)
|
||||
getLeasesRef() []*Lease
|
||||
}
|
||||
|
||||
// V4ServerConf - server configuration
|
||||
type V4ServerConf struct {
|
||||
Enabled bool `yaml:"-" json:"-"`
|
||||
InterfaceName string `yaml:"-" json:"-"`
|
||||
|
||||
GatewayIP netip.Addr `yaml:"gateway_ip" json:"gateway_ip"`
|
||||
SubnetMask netip.Addr `yaml:"subnet_mask" json:"subnet_mask"`
|
||||
// broadcastIP is the broadcasting address pre-calculated from the
|
||||
// configured gateway IP and subnet mask.
|
||||
broadcastIP netip.Addr
|
||||
|
||||
// The first & the last IP address for dynamic leases
|
||||
// Bytes [0..2] of the last allowed IP address must match the first IP
|
||||
RangeStart netip.Addr `yaml:"range_start" json:"range_start"`
|
||||
RangeEnd netip.Addr `yaml:"range_end" json:"range_end"`
|
||||
|
||||
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
|
||||
|
||||
// IP conflict detector: time (ms) to wait for ICMP reply
|
||||
// 0: disable
|
||||
ICMPTimeout uint32 `yaml:"icmp_timeout_msec" json:"-"`
|
||||
|
||||
// Custom Options.
|
||||
//
|
||||
// Option with arbitrary hexadecimal data:
|
||||
// DEC_CODE hex HEX_DATA
|
||||
// where DEC_CODE is a decimal DHCPv4 option code in range [1..255]
|
||||
//
|
||||
// Option with IP data (only 1 IP is supported):
|
||||
// DEC_CODE ip IP_ADDR
|
||||
Options []string `yaml:"options" json:"-"`
|
||||
|
||||
ipRange *ipRange
|
||||
|
||||
leaseTime time.Duration // the time during which a dynamic lease is considered valid
|
||||
dnsIPAddrs []netip.Addr // IPv4 addresses to return to DHCP clients as DNS server addresses
|
||||
|
||||
// subnet contains the DHCP server's subnet. The IP is the IP of the
|
||||
// gateway.
|
||||
subnet netip.Prefix
|
||||
|
||||
// notify is a way to signal to other components that leases have been
|
||||
// changed. notify must be called outside of locked sections, since the
|
||||
// clients might want to get the new data.
|
||||
//
|
||||
// TODO(a.garipov): This is utter madness and must be refactored. It just
|
||||
// begs for deadlock bugs and other nastiness.
|
||||
notify func(uint32)
|
||||
}
|
||||
|
||||
// errNilConfig is an error returned by validation method if the config is nil.
|
||||
const errNilConfig errors.Error = "nil config"
|
||||
|
||||
// ensureV4 returns an unmapped version of ip. An error is returned if the
|
||||
// passed ip is not an IPv4.
|
||||
func ensureV4(ip netip.Addr, kind string) (ip4 netip.Addr, err error) {
|
||||
ip4 = ip.Unmap()
|
||||
if !ip4.IsValid() || !ip4.Is4() {
|
||||
return netip.Addr{}, fmt.Errorf("%v is not an IPv4 %s", ip, kind)
|
||||
}
|
||||
|
||||
return ip4, nil
|
||||
}
|
||||
|
||||
// Validate returns an error if c is not a valid configuration.
|
||||
//
|
||||
// TODO(e.burkov): Don't set the config fields when the server itself will stop
|
||||
// containing the config.
|
||||
func (c *V4ServerConf) Validate() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
|
||||
|
||||
if c == nil {
|
||||
return errNilConfig
|
||||
}
|
||||
|
||||
gatewayIP, err := ensureV4(c.GatewayIP, "address")
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is and there is
|
||||
// an annotation deferred already.
|
||||
return err
|
||||
}
|
||||
|
||||
subnetMask, err := ensureV4(c.SubnetMask, "subnet mask")
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is and there is
|
||||
// an annotation deferred already.
|
||||
return err
|
||||
}
|
||||
maskLen, _ := net.IPMask(subnetMask.AsSlice()).Size()
|
||||
|
||||
c.subnet = netip.PrefixFrom(gatewayIP, maskLen)
|
||||
c.broadcastIP = aghnet.BroadcastFromPref(c.subnet)
|
||||
|
||||
rangeStart, err := ensureV4(c.RangeStart, "address")
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is and there is
|
||||
// an annotation deferred already.
|
||||
return err
|
||||
}
|
||||
|
||||
rangeEnd, err := ensureV4(c.RangeEnd, "address")
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is and there is
|
||||
// an annotation deferred already.
|
||||
return err
|
||||
}
|
||||
|
||||
c.ipRange, err = newIPRange(rangeStart.AsSlice(), rangeEnd.AsSlice())
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is and there is
|
||||
// an annotation deferred already.
|
||||
return err
|
||||
}
|
||||
|
||||
if c.ipRange.contains(gatewayIP.AsSlice()) {
|
||||
return fmt.Errorf("gateway ip %v in the ip range: %v-%v",
|
||||
gatewayIP,
|
||||
c.RangeStart,
|
||||
c.RangeEnd,
|
||||
)
|
||||
}
|
||||
|
||||
if !c.subnet.Contains(rangeStart) {
|
||||
return fmt.Errorf("range start %v is outside network %v",
|
||||
c.RangeStart,
|
||||
c.subnet,
|
||||
)
|
||||
}
|
||||
|
||||
if !c.subnet.Contains(rangeEnd) {
|
||||
return fmt.Errorf("range end %v is outside network %v",
|
||||
c.RangeEnd,
|
||||
c.subnet,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// V6ServerConf - server configuration
|
||||
type V6ServerConf struct {
|
||||
Enabled bool `yaml:"-" json:"-"`
|
||||
InterfaceName string `yaml:"-" json:"-"`
|
||||
|
||||
// The first IP address for dynamic leases
|
||||
// The last allowed IP address ends with 0xff byte
|
||||
RangeStart net.IP `yaml:"range_start" json:"range_start"`
|
||||
|
||||
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
|
||||
|
||||
RASLAACOnly bool `yaml:"ra_slaac_only" json:"-"` // send ICMPv6.RA packets without MO flags
|
||||
RAAllowSLAAC bool `yaml:"ra_allow_slaac" json:"-"` // send ICMPv6.RA packets with MO flags
|
||||
|
||||
ipStart net.IP // starting IP address for dynamic leases
|
||||
leaseTime time.Duration // the time during which a dynamic lease is considered valid
|
||||
dnsIPAddrs []net.IP // IPv6 addresses to return to DHCP clients as DNS server addresses
|
||||
|
||||
// Server calls this function when leases data changes
|
||||
notify func(uint32)
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -16,6 +15,8 @@ import (
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
"github.com/mdlayher/ethernet"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
@@ -49,16 +50,15 @@ type dhcpConn struct {
|
||||
}
|
||||
|
||||
// newDHCPConn creates the special connection for DHCP server.
|
||||
func (s *v4Server) newDHCPConn(ifi *net.Interface) (c net.PacketConn, err error) {
|
||||
// Create the raw connection.
|
||||
func (s *v4Server) newDHCPConn(iface *net.Interface) (c net.PacketConn, err error) {
|
||||
var ucast net.PacketConn
|
||||
if ucast, err = raw.ListenPacket(ifi, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
|
||||
if ucast, err = raw.ListenPacket(iface, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
|
||||
return nil, fmt.Errorf("creating raw udp connection: %w", err)
|
||||
}
|
||||
|
||||
// Create the UDP connection.
|
||||
var bcast net.PacketConn
|
||||
bcast, err = server4.NewIPv4UDPConn(ifi.Name, &net.UDPAddr{
|
||||
bcast, err = server4.NewIPv4UDPConn(iface.Name, &net.UDPAddr{
|
||||
// TODO(e.burkov): Listening on zeroes makes the server handle
|
||||
// requests from all the interfaces. Inspect the ways to
|
||||
// specify the interface-specific listening addresses.
|
||||
@@ -73,16 +73,16 @@ func (s *v4Server) newDHCPConn(ifi *net.Interface) (c net.PacketConn, err error)
|
||||
|
||||
return &dhcpConn{
|
||||
udpConn: bcast,
|
||||
bcastIP: s.conf.broadcastIP,
|
||||
bcastIP: s.conf.broadcastIP.AsSlice(),
|
||||
rawConn: ucast,
|
||||
srcMAC: ifi.HardwareAddr,
|
||||
srcIP: s.conf.dnsIPAddrs[0],
|
||||
srcMAC: iface.HardwareAddr,
|
||||
srcIP: s.conf.dnsIPAddrs[0].AsSlice(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// wrapErrs is a helper to wrap the errors from two independent underlying
|
||||
// connections.
|
||||
func (c *dhcpConn) wrapErrs(action string, udpConnErr, rawConnErr error) (err error) {
|
||||
func (*dhcpConn) wrapErrs(action string, udpConnErr, rawConnErr error) (err error) {
|
||||
switch {
|
||||
case udpConnErr != nil && rawConnErr != nil:
|
||||
return errors.List(fmt.Sprintf("%s both connections", action), udpConnErr, rawConnErr)
|
||||
@@ -128,7 +128,7 @@ func (c *dhcpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
// connection.
|
||||
return c.udpConn.WriteTo(p, addr)
|
||||
default:
|
||||
return 0, fmt.Errorf("peer is of unexpected type %T", addr)
|
||||
return 0, fmt.Errorf("addr has an unexpected type %T", addr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,32 +187,20 @@ func (c *dhcpConn) SetWriteDeadline(t time.Time) error {
|
||||
)
|
||||
}
|
||||
|
||||
// ipv4DefaultTTL is the default Time to Live value as recommended by
|
||||
// RFC-1700 (https://datatracker.ietf.org/doc/html/rfc1700) in seconds.
|
||||
// ipv4DefaultTTL is the default Time to Live value in seconds as recommended by
|
||||
// RFC-1700.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1700.
|
||||
const ipv4DefaultTTL = 64
|
||||
|
||||
// errInvalidPktDHCP is returned when the provided payload is not a valid DHCP
|
||||
// packet.
|
||||
const errInvalidPktDHCP errors.Error = "packet is not a valid dhcp packet"
|
||||
|
||||
// buildEtherPkt wraps the payload with IPv4, UDP and Ethernet frames. The
|
||||
// payload is expected to be an encoded DHCP packet.
|
||||
// buildEtherPkt wraps the payload with IPv4, UDP and Ethernet frames.
|
||||
// Validation of the payload is a caller's responsibility.
|
||||
func (c *dhcpConn) buildEtherPkt(payload []byte, peer *dhcpUnicastAddr) (pkt []byte, err error) {
|
||||
dhcpLayer := gopacket.NewPacket(payload, layers.LayerTypeDHCPv4, gopacket.DecodeOptions{
|
||||
NoCopy: true,
|
||||
}).Layer(layers.LayerTypeDHCPv4)
|
||||
|
||||
// Check if the decoding succeeded and the resulting layer doesn't
|
||||
// contain any errors. It should guarantee panic-safe converting of the
|
||||
// layer into gopacket.SerializableLayer.
|
||||
if dhcpLayer == nil || dhcpLayer.LayerType() != layers.LayerTypeDHCPv4 {
|
||||
return nil, errInvalidPktDHCP
|
||||
}
|
||||
|
||||
udpLayer := &layers.UDP{
|
||||
SrcPort: dhcpv4.ServerPort,
|
||||
DstPort: dhcpv4.ClientPort,
|
||||
}
|
||||
|
||||
ipv4Layer := &layers.IPv4{
|
||||
Version: uint8(layers.IPProtocolIPv4),
|
||||
Flags: layers.IPv4DontFragment,
|
||||
@@ -225,6 +213,7 @@ func (c *dhcpConn) buildEtherPkt(payload []byte, peer *dhcpUnicastAddr) (pkt []b
|
||||
// Ignore the error since it's only returned for invalid network layer's
|
||||
// type.
|
||||
_ = udpLayer.SetNetworkLayerForChecksum(ipv4Layer)
|
||||
|
||||
ethLayer := &layers.Ethernet{
|
||||
SrcMAC: c.srcMAC,
|
||||
DstMAC: peer.HardwareAddr,
|
||||
@@ -232,10 +221,19 @@ func (c *dhcpConn) buildEtherPkt(payload []byte, peer *dhcpUnicastAddr) (pkt []b
|
||||
}
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
err = gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
|
||||
setts := gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}, ethLayer, ipv4Layer, udpLayer, dhcpLayer.(gopacket.SerializableLayer))
|
||||
}
|
||||
|
||||
err = gopacket.SerializeLayers(
|
||||
buf,
|
||||
setts,
|
||||
ethLayer,
|
||||
ipv4Layer,
|
||||
udpLayer,
|
||||
gopacket.Payload(payload),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serializing layers: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -11,9 +10,11 @@ import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/mdlayher/raw"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
func TestDHCPConn_WriteTo_common(t *testing.T) {
|
||||
@@ -45,7 +46,7 @@ func TestDHCPConn_WriteTo_common(t *testing.T) {
|
||||
n, err := conn.WriteTo(nil, &unexpectedAddrType{})
|
||||
require.Error(t, err)
|
||||
|
||||
testutil.AssertErrorMsg(t, "peer is of unexpected type *dhcpd.unexpectedAddrType", err)
|
||||
testutil.AssertErrorMsg(t, "addr has an unexpected type *dhcpd.unexpectedAddrType", err)
|
||||
assert.Zero(t, n)
|
||||
})
|
||||
}
|
||||
@@ -89,14 +90,13 @@ func TestBuildEtherPkt(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-serializable", func(t *testing.T) {
|
||||
t.Run("bad_payload", func(t *testing.T) {
|
||||
// Create an invalid DHCP packet.
|
||||
invalidPayload := []byte{1, 2, 3, 4}
|
||||
pkt, err := conn.buildEtherPkt(invalidPayload, nil)
|
||||
require.Error(t, err)
|
||||
pkt, err := conn.buildEtherPkt(invalidPayload, peer)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.ErrorIs(t, err, errInvalidPktDHCP)
|
||||
assert.Empty(t, pkt)
|
||||
assert.NotEmpty(t, pkt)
|
||||
})
|
||||
|
||||
t.Run("serializing_error", func(t *testing.T) {
|
||||
|
||||
@@ -32,7 +32,7 @@ func normalizeIP(ip net.IP) net.IP {
|
||||
}
|
||||
|
||||
// Load lease table from DB
|
||||
func (s *Server) dbLoad() (err error) {
|
||||
func (s *server) dbLoad() (err error) {
|
||||
dynLeases := []*Lease{}
|
||||
staticLeases := []*Lease{}
|
||||
v6StaticLeases := []*Lease{}
|
||||
@@ -132,7 +132,7 @@ func normalizeLeases(staticLeases, dynLeases []*Lease) []*Lease {
|
||||
}
|
||||
|
||||
// Store lease table in DB
|
||||
func (s *Server) dbStore() (err error) {
|
||||
func (s *server) dbStore() (err error) {
|
||||
// Use an empty slice here as opposed to nil so that it doesn't write
|
||||
// "null" into the database file if leases are empty.
|
||||
leases := []leaseJSON{}
|
||||
|
||||
@@ -5,13 +5,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,9 +20,19 @@ const (
|
||||
// TODO(e.burkov): Remove it when static leases determining mechanism
|
||||
// will be improved.
|
||||
leaseExpireStatic = 1
|
||||
|
||||
// DefaultDHCPLeaseTTL is the default time-to-live for leases.
|
||||
DefaultDHCPLeaseTTL = uint32(timeutil.Day / time.Second)
|
||||
|
||||
// DefaultDHCPTimeoutICMP is the default timeout for waiting ICMP responses.
|
||||
DefaultDHCPTimeoutICMP = 1000
|
||||
)
|
||||
|
||||
var webHandlersRegistered = false
|
||||
// Currently used defaults for ifaceDNSAddrs.
|
||||
const (
|
||||
defaultMaxAttempts int = 10
|
||||
defaultBackoff time.Duration = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// Lease contains the necessary information about a DHCP lease
|
||||
type Lease struct {
|
||||
@@ -45,8 +54,8 @@ func (l *Lease) Clone() (clone *Lease) {
|
||||
return &Lease{
|
||||
Expiry: l.Expiry,
|
||||
Hostname: l.Hostname,
|
||||
HWAddr: netutil.CloneMAC(l.HWAddr),
|
||||
IP: netutil.CloneIP(l.IP),
|
||||
HWAddr: slices.Clone(l.HWAddr),
|
||||
IP: slices.Clone(l.IP),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,30 +128,6 @@ func (l *Lease) UnmarshalJSON(data []byte) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServerConfig is the configuration for the DHCP server. The order of YAML
|
||||
// fields is important, since the YAML configuration file follows it.
|
||||
type ServerConfig struct {
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
|
||||
|
||||
Enabled bool `yaml:"enabled"`
|
||||
InterfaceName string `yaml:"interface_name"`
|
||||
|
||||
// LocalDomainName is the domain name used for DHCP hosts. For example,
|
||||
// a DHCP client with the hostname "myhost" can be addressed as "myhost.lan"
|
||||
// when LocalDomainName is "lan".
|
||||
LocalDomainName string `yaml:"local_domain_name"`
|
||||
|
||||
Conf4 V4ServerConf `yaml:"dhcpv4"`
|
||||
Conf6 V6ServerConf `yaml:"dhcpv6"`
|
||||
|
||||
WorkDir string `yaml:"-"`
|
||||
DBFilePath string `yaml:"-"`
|
||||
}
|
||||
|
||||
// OnLeaseChangedT is a callback for lease changes.
|
||||
type OnLeaseChangedT func(flags int)
|
||||
|
||||
@@ -156,8 +141,68 @@ const (
|
||||
LeaseChangedDBStore
|
||||
)
|
||||
|
||||
// Server - the current state of the DHCP server
|
||||
type Server struct {
|
||||
// GetLeasesFlags are the flags for GetLeases.
|
||||
type GetLeasesFlags uint8
|
||||
|
||||
// GetLeasesFlags values
|
||||
const (
|
||||
LeasesDynamic GetLeasesFlags = 0b01
|
||||
LeasesStatic GetLeasesFlags = 0b10
|
||||
|
||||
LeasesAll = LeasesDynamic | LeasesStatic
|
||||
)
|
||||
|
||||
// Interface is the DHCP server that deals with both IP address families.
|
||||
type Interface interface {
|
||||
Start() (err error)
|
||||
Stop() (err error)
|
||||
Enabled() (ok bool)
|
||||
|
||||
Leases(flags GetLeasesFlags) (leases []*Lease)
|
||||
SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT)
|
||||
FindMACbyIP(ip net.IP) (mac net.HardwareAddr)
|
||||
|
||||
WriteDiskConfig(c *ServerConfig)
|
||||
}
|
||||
|
||||
// MockInterface is a mock Interface implementation.
|
||||
//
|
||||
// TODO(e.burkov): Move to aghtest when the API stabilized.
|
||||
type MockInterface struct {
|
||||
OnStart func() (err error)
|
||||
OnStop func() (err error)
|
||||
OnEnabled func() (ok bool)
|
||||
OnLeases func(flags GetLeasesFlags) (leases []*Lease)
|
||||
OnSetOnLeaseChanged func(f OnLeaseChangedT)
|
||||
OnFindMACbyIP func(ip net.IP) (mac net.HardwareAddr)
|
||||
OnWriteDiskConfig func(c *ServerConfig)
|
||||
}
|
||||
|
||||
var _ Interface = (*MockInterface)(nil)
|
||||
|
||||
// Start implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) Start() (err error) { return s.OnStart() }
|
||||
|
||||
// Stop implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) Stop() (err error) { return s.OnStop() }
|
||||
|
||||
// Enabled implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) Enabled() (ok bool) { return s.OnEnabled() }
|
||||
|
||||
// Leases implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) Leases(flags GetLeasesFlags) (ls []*Lease) { return s.OnLeases(flags) }
|
||||
|
||||
// SetOnLeaseChanged implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) SetOnLeaseChanged(f OnLeaseChangedT) { s.OnSetOnLeaseChanged(f) }
|
||||
|
||||
// FindMACbyIP implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) FindMACbyIP(ip net.IP) (mac net.HardwareAddr) { return s.OnFindMACbyIP(ip) }
|
||||
|
||||
// WriteDiskConfig implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) WriteDiskConfig(c *ServerConfig) { s.OnWriteDiskConfig(c) }
|
||||
|
||||
// server is the DHCP service that handles DHCPv4, DHCPv6, and HTTP API.
|
||||
type server struct {
|
||||
srv4 DHCPServer
|
||||
srv6 DHCPServer
|
||||
|
||||
@@ -169,27 +214,13 @@ type Server struct {
|
||||
onLeaseChanged []OnLeaseChangedT
|
||||
}
|
||||
|
||||
// GetLeasesFlags are the flags for GetLeases.
|
||||
type GetLeasesFlags uint8
|
||||
// type check
|
||||
var _ Interface = (*server)(nil)
|
||||
|
||||
// GetLeasesFlags values
|
||||
const (
|
||||
LeasesDynamic GetLeasesFlags = 0b0001
|
||||
LeasesStatic GetLeasesFlags = 0b0010
|
||||
|
||||
LeasesAll = LeasesDynamic | LeasesStatic
|
||||
)
|
||||
|
||||
// ServerInterface is an interface for servers.
|
||||
type ServerInterface interface {
|
||||
Enabled() (ok bool)
|
||||
Leases(flags GetLeasesFlags) (leases []*Lease)
|
||||
SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT)
|
||||
}
|
||||
|
||||
// Create - create object
|
||||
func Create(conf *ServerConfig) (s *Server, err error) {
|
||||
s = &Server{
|
||||
// Create initializes and returns the DHCP server handling both address
|
||||
// families. It also registers the corresponding HTTP API endpoints.
|
||||
func Create(conf *ServerConfig) (s *server, err error) {
|
||||
s = &server{
|
||||
conf: &ServerConfig{
|
||||
ConfigModified: conf.ConfigModified,
|
||||
|
||||
@@ -204,35 +235,22 @@ func Create(conf *ServerConfig) (s *Server, err error) {
|
||||
},
|
||||
}
|
||||
|
||||
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
|
||||
if runtime.GOOS == "windows" {
|
||||
// Our DHCP server doesn't work on Windows yet, so
|
||||
// signal that to the front with an HTTP 501.
|
||||
//
|
||||
// TODO(a.garipov): This needs refactoring. We
|
||||
// shouldn't even try and initialize a DHCP server on
|
||||
// Windows, but there are currently too many
|
||||
// interconnected parts--such as HTTP handlers and
|
||||
// frontend--to make that work properly.
|
||||
s.registerNotImplementedHandlers()
|
||||
} else {
|
||||
s.registerHandlers()
|
||||
}
|
||||
|
||||
webHandlersRegistered = true
|
||||
}
|
||||
// TODO(e.burkov): Don't register handlers, see TODO on
|
||||
// [aghhttp.RegisterFunc].
|
||||
s.registerHandlers()
|
||||
|
||||
v4conf := conf.Conf4
|
||||
v4conf.Enabled = s.conf.Enabled
|
||||
if len(v4conf.RangeStart) == 0 {
|
||||
v4conf.Enabled = false
|
||||
}
|
||||
|
||||
v4conf.InterfaceName = s.conf.InterfaceName
|
||||
v4conf.notify = s.onNotify
|
||||
s.srv4, err = v4Create(v4conf)
|
||||
v4conf.Enabled = s.conf.Enabled && v4conf.RangeStart.IsValid()
|
||||
|
||||
s.srv4, err = v4Create(&v4conf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating dhcpv4 srv: %w", err)
|
||||
if v4conf.Enabled {
|
||||
return nil, fmt.Errorf("creating dhcpv4 srv: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("dhcpd: warning: creating dhcpv4 srv: %s", err)
|
||||
}
|
||||
|
||||
v6conf := conf.Conf6
|
||||
@@ -265,12 +283,12 @@ func Create(conf *ServerConfig) (s *Server, err error) {
|
||||
}
|
||||
|
||||
// Enabled returns true when the server is enabled.
|
||||
func (s *Server) Enabled() (ok bool) {
|
||||
func (s *server) Enabled() (ok bool) {
|
||||
return s.conf.Enabled
|
||||
}
|
||||
|
||||
// resetLeases resets all leases in the lease database.
|
||||
func (s *Server) resetLeases() (err error) {
|
||||
func (s *server) resetLeases() (err error) {
|
||||
err = s.srv4.ResetLeases(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -287,7 +305,7 @@ func (s *Server) resetLeases() (err error) {
|
||||
}
|
||||
|
||||
// server calls this function after DB is updated
|
||||
func (s *Server) onNotify(flags uint32) {
|
||||
func (s *server) onNotify(flags uint32) {
|
||||
if flags == LeaseChangedDBStore {
|
||||
err := s.dbStore()
|
||||
if err != nil {
|
||||
@@ -301,31 +319,28 @@ func (s *Server) onNotify(flags uint32) {
|
||||
}
|
||||
|
||||
// SetOnLeaseChanged - set callback
|
||||
func (s *Server) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) {
|
||||
func (s *server) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) {
|
||||
s.onLeaseChanged = append(s.onLeaseChanged, onLeaseChanged)
|
||||
}
|
||||
|
||||
func (s *Server) notify(flags int) {
|
||||
if len(s.onLeaseChanged) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *server) notify(flags int) {
|
||||
for _, f := range s.onLeaseChanged {
|
||||
f(flags)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write configuration
|
||||
func (s *Server) WriteDiskConfig(c *ServerConfig) {
|
||||
func (s *server) WriteDiskConfig(c *ServerConfig) {
|
||||
c.Enabled = s.conf.Enabled
|
||||
c.InterfaceName = s.conf.InterfaceName
|
||||
c.LocalDomainName = s.conf.LocalDomainName
|
||||
|
||||
s.srv4.WriteDiskConfig4(&c.Conf4)
|
||||
s.srv6.WriteDiskConfig6(&c.Conf6)
|
||||
}
|
||||
|
||||
// Start will listen on port 67 and serve DHCP requests.
|
||||
func (s *Server) Start() (err error) {
|
||||
func (s *server) Start() (err error) {
|
||||
err = s.srv4.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -340,7 +355,7 @@ func (s *Server) Start() (err error) {
|
||||
}
|
||||
|
||||
// Stop closes the listening UDP socket
|
||||
func (s *Server) Stop() (err error) {
|
||||
func (s *server) Stop() (err error) {
|
||||
err = s.srv4.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -356,12 +371,12 @@ func (s *Server) Stop() (err error) {
|
||||
|
||||
// Leases returns the list of active IPv4 and IPv6 DHCP leases. It's safe for
|
||||
// concurrent use.
|
||||
func (s *Server) Leases(flags GetLeasesFlags) (leases []*Lease) {
|
||||
func (s *server) Leases(flags GetLeasesFlags) (leases []*Lease) {
|
||||
return append(s.srv4.GetLeases(flags), s.srv6.GetLeases(flags)...)
|
||||
}
|
||||
|
||||
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
|
||||
func (s *Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
|
||||
func (s *server) FindMACbyIP(ip net.IP) net.HardwareAddr {
|
||||
if ip.To4() != nil {
|
||||
return s.srv4.FindMACbyIP(ip)
|
||||
}
|
||||
@@ -369,6 +384,6 @@ func (s *Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
|
||||
}
|
||||
|
||||
// AddStaticLease - add static v4 lease
|
||||
func (s *Server) AddStaticLease(l *Lease) error {
|
||||
func (s *server) AddStaticLease(l *Lease) error {
|
||||
return s.srv4.AddStaticLease(l)
|
||||
}
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func testNotify(flags uint32) {
|
||||
@@ -26,18 +25,18 @@ func testNotify(flags uint32) {
|
||||
// Leases database store/load.
|
||||
func TestDB(t *testing.T) {
|
||||
var err error
|
||||
s := Server{
|
||||
s := server{
|
||||
conf: &ServerConfig{
|
||||
DBFilePath: dbFilename,
|
||||
},
|
||||
}
|
||||
|
||||
s.srv4, err = v4Create(V4ServerConf{
|
||||
s.srv4, err = v4Create(&V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
RangeStart: netip.MustParseAddr("192.168.10.100"),
|
||||
RangeEnd: netip.MustParseAddr("192.168.10.200"),
|
||||
GatewayIP: netip.MustParseAddr("192.168.10.1"),
|
||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
notify: testNotify,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -88,32 +87,6 @@ func TestDB(t *testing.T) {
|
||||
assert.Equal(t, leases[0].Expiry.Unix(), ll[1].Expiry.Unix())
|
||||
}
|
||||
|
||||
func TestIsValidSubnetMask(t *testing.T) {
|
||||
testCases := []struct {
|
||||
mask net.IP
|
||||
want bool
|
||||
}{{
|
||||
mask: net.IP{255, 255, 255, 0},
|
||||
want: true,
|
||||
}, {
|
||||
mask: net.IP{255, 255, 254, 0},
|
||||
want: true,
|
||||
}, {
|
||||
mask: net.IP{255, 255, 252, 0},
|
||||
want: true,
|
||||
}, {
|
||||
mask: net.IP{255, 255, 253, 0},
|
||||
}, {
|
||||
mask: net.IP{255, 255, 255, 1},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.mask.String(), func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, isValidSubnetMask(tc.mask))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeLeases(t *testing.T) {
|
||||
dynLeases := []*Lease{{
|
||||
HWAddr: net.HardwareAddr{1, 2, 3, 4},
|
||||
@@ -140,41 +113,41 @@ func TestNormalizeLeases(t *testing.T) {
|
||||
func TestV4Server_badRange(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
gatewayIP netip.Addr
|
||||
subnetMask netip.Addr
|
||||
wantErrMsg string
|
||||
gatewayIP net.IP
|
||||
subnetMask net.IP
|
||||
}{{
|
||||
name: "gateway_in_range",
|
||||
name: "gateway_in_range",
|
||||
gatewayIP: netip.MustParseAddr("192.168.10.120"),
|
||||
subnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
wantErrMsg: "dhcpv4: gateway ip 192.168.10.120 in the ip range: " +
|
||||
"192.168.10.20-192.168.10.200",
|
||||
gatewayIP: net.IP{192, 168, 10, 120},
|
||||
subnetMask: net.IP{255, 255, 255, 0},
|
||||
}, {
|
||||
name: "outside_range_start",
|
||||
name: "outside_range_start",
|
||||
gatewayIP: netip.MustParseAddr("192.168.10.1"),
|
||||
subnetMask: netip.MustParseAddr("255.255.255.240"),
|
||||
wantErrMsg: "dhcpv4: range start 192.168.10.20 is outside network " +
|
||||
"192.168.10.1/28",
|
||||
gatewayIP: net.IP{192, 168, 10, 1},
|
||||
subnetMask: net.IP{255, 255, 255, 240},
|
||||
}, {
|
||||
name: "outside_range_end",
|
||||
name: "outside_range_end",
|
||||
gatewayIP: netip.MustParseAddr("192.168.10.1"),
|
||||
subnetMask: netip.MustParseAddr("255.255.255.224"),
|
||||
wantErrMsg: "dhcpv4: range end 192.168.10.200 is outside network " +
|
||||
"192.168.10.1/27",
|
||||
gatewayIP: net.IP{192, 168, 10, 1},
|
||||
subnetMask: net.IP{255, 255, 255, 224},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
conf := V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 20},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
RangeStart: netip.MustParseAddr("192.168.10.20"),
|
||||
RangeEnd: netip.MustParseAddr("192.168.10.200"),
|
||||
GatewayIP: tc.gatewayIP,
|
||||
SubnetMask: tc.subnetMask,
|
||||
notify: testNotify,
|
||||
}
|
||||
|
||||
_, err := v4Create(conf)
|
||||
_, err := v4Create(&conf)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
@@ -183,7 +156,7 @@ func TestV4Server_badRange(t *testing.T) {
|
||||
// cloneUDPAddr returns a deep copy of a.
|
||||
func cloneUDPAddr(a *net.UDPAddr) (clone *net.UDPAddr) {
|
||||
return &net.UDPAddr{
|
||||
IP: netutil.CloneIP(a.IP),
|
||||
IP: slices.Clone(a.IP),
|
||||
Port: a.Port,
|
||||
Zone: a.Zone,
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func tryTo4(ip net.IP) (ip4 net.IP, err error) {
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("%v is not an IP address", ip)
|
||||
}
|
||||
|
||||
ip4 = ip.To4()
|
||||
if ip4 == nil {
|
||||
return nil, fmt.Errorf("%v is not an IPv4 address", ip)
|
||||
}
|
||||
|
||||
return ip4, nil
|
||||
}
|
||||
|
||||
// Return TRUE if subnet mask is correct (e.g. 255.255.255.0)
|
||||
func isValidSubnetMask(mask net.IP) bool {
|
||||
var n uint32
|
||||
n = binary.BigEndian.Uint32(mask)
|
||||
for i := 0; i != 32; i++ {
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
if (n & 0x80000000) == 0 {
|
||||
return false
|
||||
}
|
||||
n <<= 1
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,37 +1,36 @@
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
type v4ServerConfJSON struct {
|
||||
GatewayIP net.IP `json:"gateway_ip"`
|
||||
SubnetMask net.IP `json:"subnet_mask"`
|
||||
RangeStart net.IP `json:"range_start"`
|
||||
RangeEnd net.IP `json:"range_end"`
|
||||
LeaseDuration uint32 `json:"lease_duration"`
|
||||
GatewayIP netip.Addr `json:"gateway_ip"`
|
||||
SubnetMask netip.Addr `json:"subnet_mask"`
|
||||
RangeStart netip.Addr `json:"range_start"`
|
||||
RangeEnd netip.Addr `json:"range_end"`
|
||||
LeaseDuration uint32 `json:"lease_duration"`
|
||||
}
|
||||
|
||||
func v4JSONToServerConf(j *v4ServerConfJSON) V4ServerConf {
|
||||
func (j *v4ServerConfJSON) toServerConf() *V4ServerConf {
|
||||
if j == nil {
|
||||
return V4ServerConf{}
|
||||
return &V4ServerConf{}
|
||||
}
|
||||
|
||||
return V4ServerConf{
|
||||
return &V4ServerConf{
|
||||
GatewayIP: j.GatewayIP,
|
||||
SubnetMask: j.SubnetMask,
|
||||
RangeStart: j.RangeStart,
|
||||
@@ -41,8 +40,8 @@ func v4JSONToServerConf(j *v4ServerConfJSON) V4ServerConf {
|
||||
}
|
||||
|
||||
type v6ServerConfJSON struct {
|
||||
RangeStart net.IP `json:"range_start"`
|
||||
LeaseDuration uint32 `json:"lease_duration"`
|
||||
RangeStart netip.Addr `json:"range_start"`
|
||||
LeaseDuration uint32 `json:"lease_duration"`
|
||||
}
|
||||
|
||||
func v6JSONToServerConf(j *v6ServerConfJSON) V6ServerConf {
|
||||
@@ -51,7 +50,7 @@ func v6JSONToServerConf(j *v6ServerConfJSON) V6ServerConf {
|
||||
}
|
||||
|
||||
return V6ServerConf{
|
||||
RangeStart: j.RangeStart,
|
||||
RangeStart: j.RangeStart.AsSlice(),
|
||||
LeaseDuration: j.LeaseDuration,
|
||||
}
|
||||
}
|
||||
@@ -66,7 +65,7 @@ type dhcpStatusResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
status := &dhcpStatusResponse{
|
||||
Enabled: s.conf.Enabled,
|
||||
IfaceName: s.conf.InterfaceName,
|
||||
@@ -80,41 +79,29 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
status.Leases = s.Leases(LeasesDynamic)
|
||||
status.StaticLeases = s.Leases(LeasesStatic)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(status)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Unable to marshal DHCP status json: %s",
|
||||
err,
|
||||
)
|
||||
}
|
||||
_ = aghhttp.WriteJSONResponse(w, r, status)
|
||||
}
|
||||
|
||||
func (s *Server) enableDHCP(ifaceName string) (code int, err error) {
|
||||
func (s *server) enableDHCP(ifaceName string) (code int, err error) {
|
||||
var hasStaticIP bool
|
||||
hasStaticIP, err = aghnet.IfaceHasStaticIP(ifaceName)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
// ErrPermission may happen here on Linux systems where
|
||||
// AdGuard Home is installed using Snap. That doesn't
|
||||
// necessarily mean that the machine doesn't have
|
||||
// a static IP, so we can assume that it has and go on.
|
||||
// If the machine doesn't, we'll get an error later.
|
||||
// ErrPermission may happen here on Linux systems where AdGuard Home
|
||||
// is installed using Snap. That doesn't necessarily mean that the
|
||||
// machine doesn't have a static IP, so we can assume that it has
|
||||
// and go on. If the machine doesn't, we'll get an error later.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2667.
|
||||
//
|
||||
// TODO(a.garipov): I was thinking about moving this
|
||||
// into IfaceHasStaticIP, but then we wouldn't be able
|
||||
// to log it. Think about it more.
|
||||
// TODO(a.garipov): I was thinking about moving this into
|
||||
// IfaceHasStaticIP, but then we wouldn't be able to log it. Think
|
||||
// about it more.
|
||||
log.Info("error while checking static ip: %s; "+
|
||||
"assuming machine has static ip and going on", err)
|
||||
hasStaticIP = true
|
||||
} else if errors.Is(err, aghnet.ErrNoStaticIPInfo) {
|
||||
// Couldn't obtain a definitive answer. Assume static
|
||||
// IP an go on.
|
||||
// Couldn't obtain a definitive answer. Assume static IP an go on.
|
||||
log.Info("can't check for static ip; " +
|
||||
"assuming machine has static ip and going on")
|
||||
hasStaticIP = true
|
||||
@@ -149,34 +136,39 @@ type dhcpServerConfigJSON struct {
|
||||
Enabled aghalg.NullBool `json:"enabled"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPSetConfigV4(
|
||||
func (s *server) handleDHCPSetConfigV4(
|
||||
conf *dhcpServerConfigJSON,
|
||||
) (srv4 DHCPServer, enabled bool, err error) {
|
||||
) (srv DHCPServer, enabled bool, err error) {
|
||||
if conf.V4 == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
v4Conf := v4JSONToServerConf(conf.V4)
|
||||
v4Conf := conf.V4.toServerConf()
|
||||
v4Conf.Enabled = conf.Enabled == aghalg.NBTrue
|
||||
if len(v4Conf.RangeStart) == 0 {
|
||||
if !v4Conf.RangeStart.IsValid() {
|
||||
v4Conf.Enabled = false
|
||||
}
|
||||
|
||||
enabled = v4Conf.Enabled
|
||||
v4Conf.InterfaceName = conf.InterfaceName
|
||||
|
||||
c4 := V4ServerConf{}
|
||||
s.srv4.WriteDiskConfig4(&c4)
|
||||
// Set the default values for the fields not configurable via web API.
|
||||
c4 := &V4ServerConf{
|
||||
notify: s.onNotify,
|
||||
ICMPTimeout: s.conf.Conf4.ICMPTimeout,
|
||||
Options: s.conf.Conf4.Options,
|
||||
}
|
||||
|
||||
s.srv4.WriteDiskConfig4(c4)
|
||||
v4Conf.notify = c4.notify
|
||||
v4Conf.ICMPTimeout = c4.ICMPTimeout
|
||||
v4Conf.Options = c4.Options
|
||||
|
||||
srv4, err = v4Create(v4Conf)
|
||||
srv4, err := v4Create(v4Conf)
|
||||
|
||||
return srv4, enabled, err
|
||||
return srv4, srv4.enabled(), err
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPSetConfigV6(
|
||||
func (s *server) handleDHCPSetConfigV6(
|
||||
conf *dhcpServerConfigJSON,
|
||||
) (srv6 DHCPServer, enabled bool, err error) {
|
||||
if conf.V6 == nil {
|
||||
@@ -205,7 +197,7 @@ func (s *Server) handleDHCPSetConfigV6(
|
||||
return srv6, enabled, err
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
conf := &dhcpServerConfigJSON{}
|
||||
conf.Enabled = aghalg.BoolToNullBool(s.conf.Enabled)
|
||||
conf.InterfaceName = s.conf.InterfaceName
|
||||
@@ -244,22 +236,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if conf.Enabled != aghalg.NBNull {
|
||||
s.conf.Enabled = conf.Enabled == aghalg.NBTrue
|
||||
}
|
||||
|
||||
if conf.InterfaceName != "" {
|
||||
s.conf.InterfaceName = conf.InterfaceName
|
||||
}
|
||||
|
||||
if srv4 != nil {
|
||||
s.srv4 = srv4
|
||||
}
|
||||
|
||||
if srv6 != nil {
|
||||
s.srv6 = srv6
|
||||
}
|
||||
|
||||
s.setConfFromJSON(conf, srv4, srv6)
|
||||
s.conf.ConfigModified()
|
||||
|
||||
err = s.dbLoad()
|
||||
@@ -278,16 +255,36 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
type netInterfaceJSON struct {
|
||||
Name string `json:"name"`
|
||||
HardwareAddr string `json:"hardware_address"`
|
||||
Flags string `json:"flags"`
|
||||
GatewayIP net.IP `json:"gateway_ip"`
|
||||
Addrs4 []net.IP `json:"ipv4_addresses"`
|
||||
Addrs6 []net.IP `json:"ipv6_addresses"`
|
||||
// setConfFromJSON sets configuration parameters in s from the new configuration
|
||||
// decoded from JSON.
|
||||
func (s *server) setConfFromJSON(conf *dhcpServerConfigJSON, srv4, srv6 DHCPServer) {
|
||||
if conf.Enabled != aghalg.NBNull {
|
||||
s.conf.Enabled = conf.Enabled == aghalg.NBTrue
|
||||
}
|
||||
|
||||
if conf.InterfaceName != "" {
|
||||
s.conf.InterfaceName = conf.InterfaceName
|
||||
}
|
||||
|
||||
if srv4 != nil {
|
||||
s.srv4 = srv4
|
||||
}
|
||||
|
||||
if srv6 != nil {
|
||||
s.srv6 = srv6
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
type netInterfaceJSON struct {
|
||||
Name string `json:"name"`
|
||||
HardwareAddr string `json:"hardware_address"`
|
||||
Flags string `json:"flags"`
|
||||
GatewayIP netip.Addr `json:"gateway_ip"`
|
||||
Addrs4 []netip.Addr `json:"ipv4_addresses"`
|
||||
Addrs6 []netip.Addr `json:"ipv6_addresses"`
|
||||
}
|
||||
|
||||
func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
response := map[string]netInterfaceJSON{}
|
||||
|
||||
ifaces, err := net.Interfaces()
|
||||
@@ -345,13 +342,18 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
// ignore link-local
|
||||
//
|
||||
// TODO(e.burkov): Try to listen DHCP on LLA as well.
|
||||
if ipnet.IP.IsLinkLocalUnicast() {
|
||||
continue
|
||||
}
|
||||
if ipnet.IP.To4() != nil {
|
||||
jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP)
|
||||
|
||||
if ip4 := ipnet.IP.To4(); ip4 != nil {
|
||||
addr := netip.AddrFrom4(*(*[4]byte)(ip4))
|
||||
jsonIface.Addrs4 = append(jsonIface.Addrs4, addr)
|
||||
} else {
|
||||
jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP)
|
||||
addr := netip.AddrFrom16(*(*[16]byte)(ipnet.IP))
|
||||
jsonIface.Addrs6 = append(jsonIface.Addrs6, addr)
|
||||
}
|
||||
}
|
||||
if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 {
|
||||
@@ -406,31 +408,37 @@ type dhcpSearchResult struct {
|
||||
V6 dhcpSearchV6Result `json:"v6"`
|
||||
}
|
||||
|
||||
// Perform the following tasks:
|
||||
// . Search for another DHCP server running
|
||||
// . Check if a static IP is configured for the network interface
|
||||
// Respond with results
|
||||
func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
|
||||
// This use of ReadAll is safe, because request's body is now limited.
|
||||
body, err := io.ReadAll(r.Body)
|
||||
// findActiveServerReq is the JSON structure for the request to find active DHCP
|
||||
// servers.
|
||||
type findActiveServerReq struct {
|
||||
Interface string `json:"interface"`
|
||||
}
|
||||
|
||||
// handleDHCPFindActiveServer performs the following tasks:
|
||||
// 1. searches for another DHCP server in the network;
|
||||
// 2. check if a static IP is configured for the network interface;
|
||||
// 3. responds with the results.
|
||||
func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
|
||||
if aghhttp.WriteTextPlainDeprecated(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
req := &findActiveServerReq{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("failed to read request body: %s", err)
|
||||
log.Error(msg)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ifaceName := strings.TrimSpace(string(body))
|
||||
ifaceName := req.Interface
|
||||
if ifaceName == "" {
|
||||
msg := "empty interface name specified"
|
||||
log.Error(msg)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "empty interface name")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
result := dhcpSearchResult{
|
||||
result := &dhcpSearchResult{
|
||||
V4: dhcpSearchV4Result{
|
||||
OtherServer: dhcpSearchOtherResult{
|
||||
Found: "no",
|
||||
@@ -455,6 +463,14 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
|
||||
result.V4.StaticIP.IP = aghnet.GetSubnet(ifaceName).String()
|
||||
}
|
||||
|
||||
setOtherDHCPResult(ifaceName, result)
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, result)
|
||||
}
|
||||
|
||||
// setOtherDHCPResult sets the results of the check for another DHCP server in
|
||||
// result.
|
||||
func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) {
|
||||
found4, found6, err4, err6 := aghnet.CheckOtherDHCP(ifaceName)
|
||||
if err4 != nil {
|
||||
result.V4.OtherServer.Found = "error"
|
||||
@@ -462,27 +478,16 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
|
||||
} else if found4 {
|
||||
result.V4.OtherServer.Found = "yes"
|
||||
}
|
||||
|
||||
if err6 != nil {
|
||||
result.V6.OtherServer.Found = "error"
|
||||
result.V6.OtherServer.Error = err6.Error()
|
||||
} else if found6 {
|
||||
result.V6.OtherServer.Found = "yes"
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(result)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Failed to marshal DHCP found json: %s",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
l := &Lease{}
|
||||
err := json.NewDecoder(r.Body).Decode(l)
|
||||
if err != nil {
|
||||
@@ -497,21 +502,16 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
ip4 := l.IP.To4()
|
||||
|
||||
if ip4 == nil {
|
||||
var srv DHCPServer
|
||||
if ip4 := l.IP.To4(); ip4 != nil {
|
||||
l.IP = ip4
|
||||
srv = s.srv4
|
||||
} else {
|
||||
l.IP = l.IP.To16()
|
||||
|
||||
err = s.srv6.AddStaticLease(l)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
}
|
||||
|
||||
return
|
||||
srv = s.srv6
|
||||
}
|
||||
|
||||
l.IP = ip4
|
||||
err = s.srv4.AddStaticLease(l)
|
||||
err = srv.AddStaticLease(l)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -519,7 +519,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
l := &Lease{}
|
||||
err := json.NewDecoder(r.Body).Decode(l)
|
||||
if err != nil {
|
||||
@@ -556,14 +556,7 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultDHCPLeaseTTL is the default time-to-live for leases.
|
||||
DefaultDHCPLeaseTTL = uint32(timeutil.Day / time.Second)
|
||||
// DefaultDHCPTimeoutICMP is the default timeout for waiting ICMP responses.
|
||||
DefaultDHCPTimeoutICMP = 1000
|
||||
)
|
||||
|
||||
func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
err := s.Stop()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
|
||||
@@ -587,7 +580,7 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
DBFilePath: s.conf.DBFilePath,
|
||||
}
|
||||
|
||||
v4conf := V4ServerConf{
|
||||
v4conf := &V4ServerConf{
|
||||
LeaseDuration: DefaultDHCPLeaseTTL,
|
||||
ICMPTimeout: DefaultDHCPTimeoutICMP,
|
||||
notify: s.onNotify,
|
||||
@@ -603,7 +596,7 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
s.conf.ConfigModified()
|
||||
}
|
||||
|
||||
func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
|
||||
err := s.resetLeases()
|
||||
if err != nil {
|
||||
msg := "resetting leases: %s"
|
||||
@@ -613,7 +606,11 @@ func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) registerHandlers() {
|
||||
func (s *server) registerHandlers() {
|
||||
if s.conf.HTTPRegister == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.handleDHCPStatus)
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.handleDHCPInterfaces)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.handleDHCPSetConfig)
|
||||
@@ -623,44 +620,3 @@ func (s *Server) registerHandlers() {
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.handleReset)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.handleResetLeases)
|
||||
}
|
||||
|
||||
// jsonError is a generic JSON error response.
|
||||
//
|
||||
// TODO(a.garipov): Merge together with the implementations in .../home and
|
||||
// other packages after refactoring the web handler registering.
|
||||
type jsonError struct {
|
||||
// Message is the error message, an opaque string.
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// notImplemented returns a handler that replies to any request with an HTTP 501
|
||||
// Not Implemented status and a JSON error with the provided message msg.
|
||||
//
|
||||
// TODO(a.garipov): Either take the logger from the server after we've
|
||||
// refactored logging or make this not a method of *Server.
|
||||
func (s *Server) notImplemented(msg string) (f func(http.ResponseWriter, *http.Request)) {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
|
||||
err := json.NewEncoder(w).Encode(&jsonError{
|
||||
Message: msg,
|
||||
})
|
||||
if err != nil {
|
||||
log.Debug("writing 501 json response: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) registerNotImplementedHandlers() {
|
||||
h := s.notImplemented("dhcp is not supported on windows")
|
||||
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", h)
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", h)
|
||||
}
|
||||
48
internal/dhcpd/http_windows.go
Normal file
48
internal/dhcpd/http_windows.go
Normal file
@@ -0,0 +1,48 @@
|
||||
//go:build windows
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
)
|
||||
|
||||
// jsonError is a generic JSON error response.
|
||||
//
|
||||
// TODO(a.garipov): Merge together with the implementations in .../home and
|
||||
// other packages after refactoring the web handler registering.
|
||||
type jsonError struct {
|
||||
// Message is the error message, an opaque string.
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// notImplemented is a handler that replies to any request with an HTTP 501 Not
|
||||
// Implemented status and a JSON error with the provided message msg.
|
||||
//
|
||||
// TODO(a.garipov): Either take the logger from the server after we've
|
||||
// refactored logging or make this not a method of *Server.
|
||||
func (s *server) notImplemented(w http.ResponseWriter, r *http.Request) {
|
||||
_ = aghhttp.WriteJSONResponseCode(w, r, http.StatusNotImplemented, &jsonError{
|
||||
Message: aghos.Unsupported("dhcp").Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// registerHandlers sets the handlers for DHCP HTTP API that always respond with
|
||||
// an HTTP 501, since DHCP server doesn't work on Windows yet.
|
||||
//
|
||||
// TODO(a.garipov): This needs refactoring. We shouldn't even try and
|
||||
// initialize a DHCP server on Windows, but there are currently too many
|
||||
// interconnected parts--such as HTTP handlers and frontend--to make that work
|
||||
// properly.
|
||||
func (s *server) registerHandlers() {
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.notImplemented)
|
||||
}
|
||||
@@ -1,23 +1,28 @@
|
||||
//go:build windows
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_notImplemented(t *testing.T) {
|
||||
s := &Server{}
|
||||
h := s.notImplemented("never!")
|
||||
s := &server{}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest(http.MethodGet, "/unsupported", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
h(w, r)
|
||||
s.notImplemented(w, r)
|
||||
assert.Equal(t, http.StatusNotImplemented, w.Code)
|
||||
assert.Equal(t, `{"message":"never!"}`+"\n", w.Body.String())
|
||||
|
||||
wantStr := fmt.Sprintf("{%q:%q}", "message", aghos.Unsupported("dhcp"))
|
||||
assert.JSONEq(t, wantStr, w.Body.String())
|
||||
}
|
||||
@@ -27,6 +27,8 @@ const maxRangeLen = math.MaxUint32
|
||||
|
||||
// newIPRange creates a new IP address range. start must be less than end. The
|
||||
// resulting range must not be greater than maxRangeLen.
|
||||
//
|
||||
// TODO(e.burkov): Use netip.Addr.
|
||||
func newIPRange(start, end net.IP) (r *ipRange, err error) {
|
||||
defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -9,26 +8,31 @@ import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
)
|
||||
|
||||
// The aliases for DHCP option types available for explicit declaration.
|
||||
//
|
||||
// TODO(e.burkov): Add an option for classless routes.
|
||||
const (
|
||||
hexTyp = "hex"
|
||||
ipTyp = "ip"
|
||||
ipsTyp = "ips"
|
||||
textTyp = "text"
|
||||
typDel = "del"
|
||||
typBool = "bool"
|
||||
typDur = "dur"
|
||||
typHex = "hex"
|
||||
typIP = "ip"
|
||||
typIPs = "ips"
|
||||
typText = "text"
|
||||
typU8 = "u8"
|
||||
typU16 = "u16"
|
||||
)
|
||||
|
||||
// parseDHCPOptionHex parses a DHCP option as a hex-encoded string. For
|
||||
// example:
|
||||
//
|
||||
// 252 hex 736f636b733a2f2f70726f78792e6578616d706c652e6f7267
|
||||
//
|
||||
// parseDHCPOptionHex parses a DHCP option as a hex-encoded string.
|
||||
func parseDHCPOptionHex(s string) (val dhcpv4.OptionValue, err error) {
|
||||
var data []byte
|
||||
data, err = hex.DecodeString(s)
|
||||
@@ -39,15 +43,12 @@ func parseDHCPOptionHex(s string) (val dhcpv4.OptionValue, err error) {
|
||||
return dhcpv4.OptionGeneric{Data: data}, nil
|
||||
}
|
||||
|
||||
// parseDHCPOptionIP parses a DHCP option as a single IP address. For example:
|
||||
//
|
||||
// 6 ip 192.168.1.1
|
||||
//
|
||||
// parseDHCPOptionIP parses a DHCP option as a single IP address.
|
||||
func parseDHCPOptionIP(s string) (val dhcpv4.OptionValue, err error) {
|
||||
var ip net.IP
|
||||
// All DHCPv4 options require IPv4, so don't put the 16-byte version.
|
||||
// Otherwise, the clients will receive weird data that looks like four
|
||||
// IPv4 addresses.
|
||||
// Otherwise, the clients will receive weird data that looks like four IPv4
|
||||
// addresses.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2688.
|
||||
if ip, err = netutil.ParseIPv4(s); err != nil {
|
||||
@@ -58,133 +59,343 @@ func parseDHCPOptionIP(s string) (val dhcpv4.OptionValue, err error) {
|
||||
}
|
||||
|
||||
// parseDHCPOptionIPs parses a DHCP option as a comma-separates list of IP
|
||||
// addresses. For example:
|
||||
//
|
||||
// 6 ips 192.168.1.1,192.168.1.2
|
||||
//
|
||||
// addresses.
|
||||
func parseDHCPOptionIPs(s string) (val dhcpv4.OptionValue, err error) {
|
||||
var ips dhcpv4.IPs
|
||||
var ip net.IP
|
||||
var ip dhcpv4.OptionValue
|
||||
for i, ipStr := range strings.Split(s, ",") {
|
||||
// See notes in the ipDHCPOptionParserHandler.
|
||||
if ip, err = netutil.ParseIPv4(ipStr); err != nil {
|
||||
ip, err = parseDHCPOptionIP(ipStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing ip at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
ips = append(ips, ip)
|
||||
ips = append(ips, net.IP(ip.(dhcpv4.IP)))
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
// parseDHCPOptionText parses a DHCP option as a simple UTF-8 encoded
|
||||
// text. For example:
|
||||
//
|
||||
// 252 text http://192.168.1.1/wpad.dat
|
||||
//
|
||||
func parseDHCPOptionText(s string) (val dhcpv4.OptionValue) {
|
||||
return dhcpv4.OptionGeneric{Data: []byte(s)}
|
||||
// parseDHCPOptionDur parses a DHCP option as a duration in a human-readable
|
||||
// form.
|
||||
func parseDHCPOptionDur(s string) (val dhcpv4.OptionValue, err error) {
|
||||
var v timeutil.Duration
|
||||
err = v.UnmarshalText([]byte(s))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding dur: %w", err)
|
||||
}
|
||||
|
||||
return dhcpv4.Duration(v.Duration), nil
|
||||
}
|
||||
|
||||
// parseDHCPOption parses an option. See the documentation of parseDHCPOption*
|
||||
// for more info.
|
||||
func parseDHCPOption(s string) (opt dhcpv4.Option, err error) {
|
||||
// parseDHCPOptionUint parses a DHCP option as an unsigned integer. bitSize is
|
||||
// expected to be 8 or 16.
|
||||
func parseDHCPOptionUint(s string, bitSize int) (val dhcpv4.OptionValue, err error) {
|
||||
var v uint64
|
||||
v, err = strconv.ParseUint(s, 10, bitSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding u%d: %w", bitSize, err)
|
||||
}
|
||||
|
||||
switch bitSize {
|
||||
case 8:
|
||||
return dhcpv4.OptionGeneric{Data: []byte{uint8(v)}}, nil
|
||||
case 16:
|
||||
return dhcpv4.Uint16(v), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported size of integer %d", bitSize)
|
||||
}
|
||||
}
|
||||
|
||||
// parseDHCPOptionBool parses a DHCP option as a boolean value. See
|
||||
// [strconv.ParseBool] for available values.
|
||||
func parseDHCPOptionBool(s string) (val dhcpv4.OptionValue, err error) {
|
||||
var v bool
|
||||
v, err = strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding bool: %w", err)
|
||||
}
|
||||
|
||||
rawVal := [1]byte{}
|
||||
if v {
|
||||
rawVal[0] = 1
|
||||
}
|
||||
|
||||
return dhcpv4.OptionGeneric{Data: rawVal[:]}, nil
|
||||
}
|
||||
|
||||
// parseDHCPOptionVal parses a DHCP option value considering typ.
|
||||
func parseDHCPOptionVal(typ, valStr string) (val dhcpv4.OptionValue, err error) {
|
||||
switch typ {
|
||||
case typBool:
|
||||
val, err = parseDHCPOptionBool(valStr)
|
||||
case typDel:
|
||||
val = dhcpv4.OptionGeneric{Data: nil}
|
||||
case typDur:
|
||||
val, err = parseDHCPOptionDur(valStr)
|
||||
case typHex:
|
||||
val, err = parseDHCPOptionHex(valStr)
|
||||
case typIP:
|
||||
val, err = parseDHCPOptionIP(valStr)
|
||||
case typIPs:
|
||||
val, err = parseDHCPOptionIPs(valStr)
|
||||
case typText:
|
||||
val = dhcpv4.String(valStr)
|
||||
case typU8:
|
||||
val, err = parseDHCPOptionUint(valStr, 8)
|
||||
case typU16:
|
||||
val, err = parseDHCPOptionUint(valStr, 16)
|
||||
default:
|
||||
err = fmt.Errorf("unknown option type %q", typ)
|
||||
}
|
||||
|
||||
return val, err
|
||||
}
|
||||
|
||||
// parseDHCPOption parses an option. For the del option value is ignored. The
|
||||
// examples of possible option strings:
|
||||
//
|
||||
// - 1 bool true
|
||||
// - 2 del
|
||||
// - 3 dur 2h5s
|
||||
// - 4 hex 736f636b733a2f2f70726f78792e6578616d706c652e6f7267
|
||||
// - 5 ip 192.168.1.1
|
||||
// - 6 ips 192.168.1.1,192.168.1.2
|
||||
// - 7 text http://192.168.1.1/wpad.dat
|
||||
// - 8 u8 255
|
||||
// - 9 u16 65535
|
||||
func parseDHCPOption(s string) (code dhcpv4.OptionCode, val dhcpv4.OptionValue, err error) {
|
||||
defer func() { err = errors.Annotate(err, "invalid option string %q: %w", s) }()
|
||||
|
||||
s = strings.TrimSpace(s)
|
||||
parts := strings.SplitN(s, " ", 3)
|
||||
if len(parts) < 3 {
|
||||
return opt, errors.Error("need at least three fields")
|
||||
|
||||
var valStr string
|
||||
if pl := len(parts); pl < 3 {
|
||||
if pl < 2 || parts[1] != typDel {
|
||||
return nil, nil, errors.Error("bad option format")
|
||||
}
|
||||
} else {
|
||||
valStr = parts[2]
|
||||
}
|
||||
|
||||
var code64 uint64
|
||||
code64, err = strconv.ParseUint(parts[0], 10, 8)
|
||||
if err != nil {
|
||||
return opt, fmt.Errorf("parsing option code: %w", err)
|
||||
}
|
||||
|
||||
var optVal dhcpv4.OptionValue
|
||||
switch typ, val := parts[1], parts[2]; typ {
|
||||
case hexTyp:
|
||||
optVal, err = parseDHCPOptionHex(val)
|
||||
case ipTyp:
|
||||
optVal, err = parseDHCPOptionIP(val)
|
||||
case ipsTyp:
|
||||
optVal, err = parseDHCPOptionIPs(val)
|
||||
case textTyp:
|
||||
optVal = parseDHCPOptionText(val)
|
||||
default:
|
||||
return opt, fmt.Errorf("unknown option type %q", typ)
|
||||
return nil, nil, fmt.Errorf("parsing option code: %w", err)
|
||||
}
|
||||
|
||||
val, err = parseDHCPOptionVal(parts[1], valStr)
|
||||
if err != nil {
|
||||
return opt, err
|
||||
// Don't wrap an error since it's informative enough as is and there
|
||||
// also the deferred annotation.
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return dhcpv4.Option{
|
||||
Code: dhcpv4.GenericOptionCode(code64),
|
||||
Value: optVal,
|
||||
}, nil
|
||||
return dhcpv4.GenericOptionCode(code64), val, nil
|
||||
}
|
||||
|
||||
// prepareOptions builds the set of DHCP options according to host requirements
|
||||
// document and values from conf.
|
||||
func prepareOptions(conf V4ServerConf) (opts dhcpv4.Options) {
|
||||
// Set default values for host configuration parameters listed in Appendix
|
||||
// A of RFC-2131. Those parameters, if requested by client, should be
|
||||
// returned with values defined by Host Requirements Document.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc2131#appendix-A.
|
||||
//
|
||||
// See also https://datatracker.ietf.org/doc/html/rfc1122,
|
||||
// https://datatracker.ietf.org/doc/html/rfc1123, and
|
||||
// https://datatracker.ietf.org/doc/html/rfc2132.
|
||||
opts = dhcpv4.Options{
|
||||
func (s *v4Server) prepareOptions() {
|
||||
// Set default values of host configuration parameters listed in Appendix A
|
||||
// of RFC-2131.
|
||||
s.implicitOpts = dhcpv4.OptionsFromList(
|
||||
// IP-Layer Per Host
|
||||
dhcpv4.OptionNonLocalSourceRouting.Code(): []byte{0},
|
||||
|
||||
// Set the current recommended default time to live for the
|
||||
// Internet Protocol which is 64, see
|
||||
// https://datatracker.ietf.org/doc/html/rfc1700.
|
||||
dhcpv4.OptionDefaultIPTTL.Code(): []byte{64},
|
||||
// An Internet host that includes embedded gateway code MUST have a
|
||||
// configuration switch to disable the gateway function, and this switch
|
||||
// MUST default to the non-gateway mode.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.3.5.
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionIPForwarding, []byte{0x0}),
|
||||
|
||||
// A host that supports non-local source-routing MUST have a
|
||||
// configurable switch to disable forwarding, and this switch MUST
|
||||
// default to disabled.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.3.5.
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionNonLocalSourceRouting, []byte{0x0}),
|
||||
|
||||
// Do not set the Policy Filter Option since it only makes sense when
|
||||
// the non-local source routing is enabled.
|
||||
|
||||
// The minimum legal value is 576.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc2132#section-4.4.
|
||||
dhcpv4.Option{
|
||||
Code: dhcpv4.OptionMaximumDatagramAssemblySize,
|
||||
Value: dhcpv4.Uint16(576),
|
||||
},
|
||||
|
||||
// Set the current recommended default time to live for the Internet
|
||||
// Protocol which is 64.
|
||||
//
|
||||
// See https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml#ip-parameters-2.
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionDefaultIPTTL, []byte{0x40}),
|
||||
|
||||
// For example, after the PTMU estimate is decreased, the timeout should
|
||||
// be set to 10 minutes; once this timer expires and a larger MTU is
|
||||
// attempted, the timeout can be set to a much smaller value.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1191#section-6.6.
|
||||
dhcpv4.Option{
|
||||
Code: dhcpv4.OptionPathMTUAgingTimeout,
|
||||
Value: dhcpv4.Duration(10 * time.Minute),
|
||||
},
|
||||
|
||||
// There is a table describing the MTU values representing all major
|
||||
// data-link technologies in use in the Internet so that each set of
|
||||
// similar MTUs is associated with a plateau value equal to the lowest
|
||||
// MTU in the group.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1191#section-7.
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionPathMTUPlateauTable, []byte{
|
||||
0x0, 0x44,
|
||||
0x1, 0x28,
|
||||
0x1, 0xFC,
|
||||
0x3, 0xEE,
|
||||
0x5, 0xD4,
|
||||
0x7, 0xD2,
|
||||
0x11, 0x0,
|
||||
0x1F, 0xE6,
|
||||
0x45, 0xFA,
|
||||
}),
|
||||
|
||||
// IP-Layer Per Interface
|
||||
|
||||
dhcpv4.OptionPerformMaskDiscovery.Code(): []byte{0},
|
||||
dhcpv4.OptionMaskSupplier.Code(): []byte{0},
|
||||
dhcpv4.OptionPerformRouterDiscovery.Code(): []byte{1},
|
||||
// The all-routers address is preferred wherever possible, see
|
||||
// https://datatracker.ietf.org/doc/html/rfc1256#section-5.1.
|
||||
dhcpv4.OptionRouterSolicitationAddress.Code(): netutil.IPv4allrouter(),
|
||||
dhcpv4.OptionBroadcastAddress.Code(): netutil.IPv4bcast(),
|
||||
// Since nearly all networks in the Internet currently support an MTU of
|
||||
// 576 or greater, we strongly recommend the use of 576 for datagrams
|
||||
// sent to non-local networks.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.3.3.
|
||||
dhcpv4.Option{
|
||||
Code: dhcpv4.OptionInterfaceMTU,
|
||||
Value: dhcpv4.Uint16(576),
|
||||
},
|
||||
|
||||
// Set the All Subnets Are Local Option to false since commonly the
|
||||
// connected hosts aren't expected to be multihomed.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.3.3.
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionAllSubnetsAreLocal, []byte{0x00}),
|
||||
|
||||
// Set the Perform Mask Discovery Option to false to provide the subnet
|
||||
// mask by options only.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.2.9.
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionPerformMaskDiscovery, []byte{0x00}),
|
||||
|
||||
// A system MUST NOT send an Address Mask Reply unless it is an
|
||||
// authoritative agent for address masks. An authoritative agent may be
|
||||
// a host or a gateway, but it MUST be explicitly configured as a
|
||||
// address mask agent.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.2.9.
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionMaskSupplier, []byte{0x00}),
|
||||
|
||||
// Set the Perform Router Discovery Option to true as per Router
|
||||
// Discovery Document.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1256#section-5.1.
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionPerformRouterDiscovery, []byte{0x01}),
|
||||
|
||||
// The all-routers address is preferred wherever possible.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1256#section-5.1.
|
||||
dhcpv4.Option{
|
||||
Code: dhcpv4.OptionRouterSolicitationAddress,
|
||||
Value: dhcpv4.IP(netutil.IPv4allrouter()),
|
||||
},
|
||||
|
||||
// Don't set the Static Routes Option since it should be set up by
|
||||
// system administrator.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.3.1.2.
|
||||
|
||||
// A datagram with the destination address of limited broadcast will be
|
||||
// received by every host on the connected physical network but will not
|
||||
// be forwarded outside that network.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.1.3.
|
||||
dhcpv4.OptBroadcastAddress(netutil.IPv4bcast()),
|
||||
|
||||
// Link-Layer Per Interface
|
||||
|
||||
dhcpv4.OptionTrailerEncapsulation.Code(): []byte{0},
|
||||
dhcpv4.OptionEthernetEncapsulation.Code(): []byte{0},
|
||||
// If the system does not dynamically negotiate use of the trailer
|
||||
// protocol on a per-destination basis, the default configuration MUST
|
||||
// disable the protocol.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-2.3.1.
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionTrailerEncapsulation, []byte{0x00}),
|
||||
|
||||
// For proxy ARP situations, the timeout needs to be on the order of a
|
||||
// minute.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-2.3.2.1.
|
||||
dhcpv4.Option{
|
||||
Code: dhcpv4.OptionArpCacheTimeout,
|
||||
Value: dhcpv4.Duration(time.Minute),
|
||||
},
|
||||
|
||||
// An Internet host that implements sending both the RFC-894 and the
|
||||
// RFC-1042 encapsulations MUST provide a configuration switch to select
|
||||
// which is sent, and this switch MUST default to RFC-894.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-2.3.3.
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionEthernetEncapsulation, []byte{0x00}),
|
||||
|
||||
// TCP Per Host
|
||||
|
||||
dhcpv4.OptionTCPKeepaliveInterval.Code(): dhcpv4.Duration(0).ToBytes(),
|
||||
dhcpv4.OptionTCPKeepaliveGarbage.Code(): []byte{0},
|
||||
// A fixed value must be at least big enough for the Internet diameter,
|
||||
// i.e., the longest possible path. A reasonable value is about twice
|
||||
// the diameter, to allow for continued Internet growth.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.1.7.
|
||||
dhcpv4.Option{
|
||||
Code: dhcpv4.OptionDefaulTCPTTL,
|
||||
Value: dhcpv4.Duration(60 * time.Second),
|
||||
},
|
||||
|
||||
// The interval MUST be configurable and MUST default to no less than
|
||||
// two hours.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-4.2.3.6.
|
||||
dhcpv4.Option{
|
||||
Code: dhcpv4.OptionTCPKeepaliveInterval,
|
||||
Value: dhcpv4.Duration(2 * time.Hour),
|
||||
},
|
||||
|
||||
// Unfortunately, some misbehaved TCP implementations fail to respond to
|
||||
// a probe segment unless it contains data.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1122#section-4.2.3.6.
|
||||
dhcpv4.OptGeneric(dhcpv4.OptionTCPKeepaliveGarbage, []byte{0x01}),
|
||||
|
||||
// Values From Configuration
|
||||
dhcpv4.OptRouter(s.conf.GatewayIP.AsSlice()),
|
||||
|
||||
dhcpv4.OptionRouter.Code(): netutil.CloneIP(conf.subnet.IP),
|
||||
dhcpv4.OptionSubnetMask.Code(): dhcpv4.IPMask(conf.subnet.Mask).ToBytes(),
|
||||
}
|
||||
dhcpv4.OptSubnetMask(s.conf.SubnetMask.AsSlice()),
|
||||
)
|
||||
|
||||
// Set values for explicitly configured options.
|
||||
for i, o := range conf.Options {
|
||||
opt, err := parseDHCPOption(o)
|
||||
s.explicitOpts = dhcpv4.Options{}
|
||||
for i, o := range s.conf.Options {
|
||||
code, val, err := parseDHCPOption(o)
|
||||
if err != nil {
|
||||
log.Error("dhcpv4: bad option string at index %d: %s", i, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
opts.Update(opt)
|
||||
s.explicitOpts.Update(dhcpv4.Option{Code: code, Value: val})
|
||||
// Remove those from the implicit options.
|
||||
delete(s.implicitOpts, code.Code())
|
||||
}
|
||||
|
||||
return opts
|
||||
log.Debug("dhcpv4: implicit options:\n%s", s.implicitOpts.Summary(nil))
|
||||
log.Debug("dhcpv4: explicit options:\n%s", s.explicitOpts.Summary(nil))
|
||||
|
||||
if len(s.explicitOpts) == 0 {
|
||||
s.explicitOpts = nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -7,171 +6,262 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseOpt(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
in string
|
||||
wantCode dhcpv4.OptionCode
|
||||
wantVal dhcpv4.OptionValue
|
||||
wantErrMsg string
|
||||
wantOpt dhcpv4.Option
|
||||
}{{
|
||||
name: "hex_success",
|
||||
in: "6 hex c0a80101c0a80102",
|
||||
name: "hex_success",
|
||||
in: "6 hex c0a80101c0a80102",
|
||||
wantCode: dhcpv4.GenericOptionCode(6),
|
||||
wantVal: dhcpv4.OptionGeneric{Data: []byte{
|
||||
0xC0, 0xA8, 0x01, 0x01,
|
||||
0xC0, 0xA8, 0x01, 0x02,
|
||||
}},
|
||||
wantErrMsg: "",
|
||||
wantOpt: dhcpv4.OptDNS(
|
||||
net.IP{0xC0, 0xA8, 0x01, 0x01},
|
||||
net.IP{0xC0, 0xA8, 0x01, 0x02},
|
||||
),
|
||||
}, {
|
||||
name: "ip_success",
|
||||
in: "6 ip 1.2.3.4",
|
||||
wantCode: dhcpv4.GenericOptionCode(6),
|
||||
wantVal: dhcpv4.IP(net.IP{0x01, 0x02, 0x03, 0x04}),
|
||||
wantErrMsg: "",
|
||||
wantOpt: dhcpv4.OptDNS(
|
||||
net.IP{0x01, 0x02, 0x03, 0x04},
|
||||
),
|
||||
}, {
|
||||
name: "ip_fail_v6",
|
||||
in: "6 ip ::1234",
|
||||
wantErrMsg: "invalid option string \"6 ip ::1234\": bad ipv4 address \"::1234\"",
|
||||
wantOpt: dhcpv4.Option{},
|
||||
}, {
|
||||
name: "ips_success",
|
||||
in: "6 ips 192.168.1.1,192.168.1.2",
|
||||
name: "ips_success",
|
||||
in: "6 ips 192.168.1.1,192.168.1.2",
|
||||
wantCode: dhcpv4.GenericOptionCode(6),
|
||||
wantVal: dhcpv4.IPs([]net.IP{
|
||||
{0xC0, 0xA8, 0x01, 0x01},
|
||||
{0xC0, 0xA8, 0x01, 0x02},
|
||||
}),
|
||||
wantErrMsg: "",
|
||||
wantOpt: dhcpv4.OptDNS(
|
||||
net.IP{0xC0, 0xA8, 0x01, 0x01},
|
||||
net.IP{0xC0, 0xA8, 0x01, 0x02},
|
||||
),
|
||||
}, {
|
||||
name: "text_success",
|
||||
in: "252 text http://192.168.1.1/",
|
||||
wantCode: dhcpv4.GenericOptionCode(252),
|
||||
wantVal: dhcpv4.String("http://192.168.1.1/"),
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "del_success",
|
||||
in: "61 del",
|
||||
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionClientIdentifier),
|
||||
wantVal: dhcpv4.OptionGeneric{Data: nil},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "bool_success",
|
||||
in: "19 bool true",
|
||||
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionIPForwarding),
|
||||
wantVal: dhcpv4.OptionGeneric{Data: []byte{0x01}},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "bool_success_false",
|
||||
in: "19 bool F",
|
||||
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionIPForwarding),
|
||||
wantVal: dhcpv4.OptionGeneric{Data: []byte{0x00}},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "dur_success",
|
||||
in: "24 dur 2h5s",
|
||||
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionPathMTUAgingTimeout),
|
||||
wantVal: dhcpv4.Duration(2*time.Hour + 5*time.Second),
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "u8_success",
|
||||
in: "23 u8 64",
|
||||
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionDefaultIPTTL),
|
||||
wantVal: dhcpv4.OptionGeneric{Data: []byte{0x40}},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "u16_success",
|
||||
in: "22 u16 1234",
|
||||
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionMaximumDatagramAssemblySize),
|
||||
wantVal: dhcpv4.Uint16(1234),
|
||||
wantErrMsg: "",
|
||||
wantOpt: dhcpv4.OptGeneric(
|
||||
dhcpv4.GenericOptionCode(252),
|
||||
[]byte("http://192.168.1.1/"),
|
||||
),
|
||||
}, {
|
||||
name: "bad_parts",
|
||||
in: "6 ip",
|
||||
wantErrMsg: `invalid option string "6 ip": need at least three fields`,
|
||||
wantOpt: dhcpv4.Option{},
|
||||
wantCode: nil,
|
||||
wantVal: nil,
|
||||
wantErrMsg: `invalid option string "6 ip": bad option format`,
|
||||
}, {
|
||||
name: "bad_code",
|
||||
in: "256 ip 1.1.1.1",
|
||||
name: "bad_code",
|
||||
in: "256 ip 1.1.1.1",
|
||||
wantCode: nil,
|
||||
wantVal: nil,
|
||||
wantErrMsg: `invalid option string "256 ip 1.1.1.1": parsing option code: ` +
|
||||
`strconv.ParseUint: parsing "256": value out of range`,
|
||||
wantOpt: dhcpv4.Option{},
|
||||
}, {
|
||||
name: "bad_type",
|
||||
in: "6 bad 1.1.1.1",
|
||||
wantCode: nil,
|
||||
wantVal: nil,
|
||||
wantErrMsg: `invalid option string "6 bad 1.1.1.1": unknown option type "bad"`,
|
||||
wantOpt: dhcpv4.Option{},
|
||||
}, {
|
||||
name: "hex_error",
|
||||
in: "6 hex ZZZ",
|
||||
name: "hex_error",
|
||||
in: "6 hex ZZZ",
|
||||
wantCode: nil,
|
||||
wantVal: nil,
|
||||
wantErrMsg: `invalid option string "6 hex ZZZ": decoding hex: ` +
|
||||
`encoding/hex: invalid byte: U+005A 'Z'`,
|
||||
wantOpt: dhcpv4.Option{},
|
||||
}, {
|
||||
name: "ip_error",
|
||||
in: "6 ip 1.2.3.x",
|
||||
wantCode: nil,
|
||||
wantVal: nil,
|
||||
wantErrMsg: "invalid option string \"6 ip 1.2.3.x\": bad ipv4 address \"1.2.3.x\"",
|
||||
wantOpt: dhcpv4.Option{},
|
||||
}, {
|
||||
name: "ips_error",
|
||||
in: "6 ips 192.168.1.1,192.168.1.x",
|
||||
name: "ip_error_v6",
|
||||
in: "6 ip ::1234",
|
||||
wantCode: nil,
|
||||
wantVal: nil,
|
||||
wantErrMsg: "invalid option string \"6 ip ::1234\": bad ipv4 address \"::1234\"",
|
||||
}, {
|
||||
name: "ips_error",
|
||||
in: "6 ips 192.168.1.1,192.168.1.x",
|
||||
wantCode: nil,
|
||||
wantVal: nil,
|
||||
wantErrMsg: "invalid option string \"6 ips 192.168.1.1,192.168.1.x\": " +
|
||||
"parsing ip at index 1: bad ipv4 address \"192.168.1.x\"",
|
||||
wantOpt: dhcpv4.Option{},
|
||||
}, {
|
||||
name: "bool_error",
|
||||
in: "19 bool yes",
|
||||
wantCode: nil,
|
||||
wantVal: nil,
|
||||
wantErrMsg: "invalid option string \"19 bool yes\": decoding bool: " +
|
||||
"strconv.ParseBool: parsing \"yes\": invalid syntax",
|
||||
}, {
|
||||
name: "dur_error",
|
||||
in: "24 dur 3y",
|
||||
wantCode: nil,
|
||||
wantVal: nil,
|
||||
wantErrMsg: "invalid option string \"24 dur 3y\": decoding dur: " +
|
||||
"unmarshaling duration: time: unknown unit \"y\" in duration \"3y\"",
|
||||
}, {
|
||||
name: "u8_error",
|
||||
in: "23 u8 256",
|
||||
wantCode: nil,
|
||||
wantVal: nil,
|
||||
wantErrMsg: "invalid option string \"23 u8 256\": decoding u8: " +
|
||||
"strconv.ParseUint: parsing \"256\": value out of range",
|
||||
}, {
|
||||
name: "u16_error",
|
||||
in: "23 u16 65536",
|
||||
wantCode: nil,
|
||||
wantVal: nil,
|
||||
wantErrMsg: "invalid option string \"23 u16 65536\": decoding u16: " +
|
||||
"strconv.ParseUint: parsing \"65536\": value out of range",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
opt, err := parseDHCPOption(tc.in)
|
||||
if tc.wantErrMsg != "" {
|
||||
require.Error(t, err)
|
||||
code, val, err := parseDHCPOption(tc.in)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantOpt.Code.Code(), opt.Code.Code())
|
||||
assert.Equal(t, tc.wantOpt.Value.ToBytes(), opt.Value.ToBytes())
|
||||
assert.Equal(t, tc.wantCode, code)
|
||||
assert.Equal(t, tc.wantVal, val)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareOptions(t *testing.T) {
|
||||
allDefault := dhcpv4.Options{
|
||||
dhcpv4.OptionNonLocalSourceRouting.Code(): []byte{0},
|
||||
dhcpv4.OptionDefaultIPTTL.Code(): []byte{64},
|
||||
dhcpv4.OptionPerformMaskDiscovery.Code(): []byte{0},
|
||||
dhcpv4.OptionMaskSupplier.Code(): []byte{0},
|
||||
dhcpv4.OptionPerformRouterDiscovery.Code(): []byte{1},
|
||||
dhcpv4.OptionRouterSolicitationAddress.Code(): []byte{224, 0, 0, 2},
|
||||
dhcpv4.OptionBroadcastAddress.Code(): []byte{255, 255, 255, 255},
|
||||
dhcpv4.OptionTrailerEncapsulation.Code(): []byte{0},
|
||||
dhcpv4.OptionEthernetEncapsulation.Code(): []byte{0},
|
||||
dhcpv4.OptionTCPKeepaliveInterval.Code(): []byte{0, 0, 0, 0},
|
||||
dhcpv4.OptionTCPKeepaliveGarbage.Code(): []byte{0},
|
||||
}
|
||||
oneIP, otherIP := net.IP{1, 2, 3, 4}, net.IP{5, 6, 7, 8}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
opts []string
|
||||
checks dhcpv4.Options
|
||||
name string
|
||||
wantExplicit dhcpv4.Options
|
||||
opts []string
|
||||
}{{
|
||||
name: "all_default",
|
||||
checks: allDefault,
|
||||
name: "all_default",
|
||||
wantExplicit: nil,
|
||||
opts: nil,
|
||||
}, {
|
||||
name: "configured_ip",
|
||||
wantExplicit: dhcpv4.OptionsFromList(
|
||||
dhcpv4.OptBroadcastAddress(oneIP),
|
||||
),
|
||||
opts: []string{
|
||||
fmt.Sprintf("%d ip %s", dhcpv4.OptionBroadcastAddress, oneIP),
|
||||
},
|
||||
checks: dhcpv4.Options{
|
||||
dhcpv4.OptionBroadcastAddress.Code(): oneIP,
|
||||
},
|
||||
}, {
|
||||
name: "configured_ips",
|
||||
wantExplicit: dhcpv4.OptionsFromList(
|
||||
dhcpv4.Option{
|
||||
Code: dhcpv4.OptionDomainNameServer,
|
||||
Value: dhcpv4.IPs{oneIP, otherIP},
|
||||
},
|
||||
),
|
||||
opts: []string{
|
||||
fmt.Sprintf("%d ips %s,%s", dhcpv4.OptionDomainNameServer, oneIP, otherIP),
|
||||
},
|
||||
checks: dhcpv4.Options{
|
||||
dhcpv4.OptionDomainNameServer.Code(): append(oneIP, otherIP...),
|
||||
},
|
||||
}, {
|
||||
name: "configured_bad",
|
||||
name: "configured_bad",
|
||||
wantExplicit: nil,
|
||||
opts: []string{
|
||||
"19 bool yes",
|
||||
"24 dur 3y",
|
||||
"23 u8 256",
|
||||
"23 u16 65536",
|
||||
"20 hex",
|
||||
"23 hex abc",
|
||||
"32 ips 1,2,3,4",
|
||||
"28 256.256.256.256",
|
||||
},
|
||||
checks: allDefault,
|
||||
}, {
|
||||
name: "configured_del",
|
||||
wantExplicit: dhcpv4.OptionsFromList(
|
||||
dhcpv4.OptBroadcastAddress(nil),
|
||||
),
|
||||
opts: []string{
|
||||
"28 del",
|
||||
},
|
||||
}, {
|
||||
name: "rewritten_del",
|
||||
wantExplicit: dhcpv4.OptionsFromList(
|
||||
dhcpv4.OptBroadcastAddress(netutil.IPv4bcast()),
|
||||
),
|
||||
opts: []string{
|
||||
"28 del",
|
||||
"28 ip 255.255.255.255",
|
||||
},
|
||||
}, {
|
||||
name: "configured_and_del",
|
||||
wantExplicit: dhcpv4.OptionsFromList(
|
||||
dhcpv4.Option{
|
||||
Code: dhcpv4.OptionGeoConf,
|
||||
Value: dhcpv4.String("cba"),
|
||||
},
|
||||
),
|
||||
opts: []string{
|
||||
"123 text abc",
|
||||
"123 del",
|
||||
"123 text cba",
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
opts := prepareOptions(V4ServerConf{
|
||||
// Just to avoid nil pointer dereference.
|
||||
subnet: &net.IPNet{},
|
||||
s := &v4Server{
|
||||
conf: &V4ServerConf{
|
||||
Options: tc.opts,
|
||||
})
|
||||
for c, v := range tc.checks {
|
||||
optVal := opts.Get(dhcpv4.GenericOptionCode(c))
|
||||
require.NotNil(t, optVal)
|
||||
},
|
||||
}
|
||||
|
||||
assert.Len(t, optVal, len(v))
|
||||
assert.Equal(t, v, optVal)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s.prepareOptions()
|
||||
|
||||
assert.Equal(t, tc.wantExplicit, s.explicitOpts)
|
||||
|
||||
for c := range s.explicitOpts {
|
||||
assert.NotContains(t, s.implicitOpts, c)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package dhcpd
|
||||
|
||||
|
||||
@@ -65,39 +65,42 @@ func hwAddrToLinkLayerAddr(hwa net.HardwareAddr) (lla []byte, err error) {
|
||||
}
|
||||
|
||||
// Create an ICMPv6.RouterAdvertisement packet with all necessary options.
|
||||
// Data scheme:
|
||||
//
|
||||
// ICMPv6:
|
||||
// type[1]
|
||||
// code[1]
|
||||
// chksum[2]
|
||||
// body (RouterAdvertisement):
|
||||
// Cur Hop Limit[1]
|
||||
// Flags[1]: MO......
|
||||
// Router Lifetime[2]
|
||||
// Reachable Time[4]
|
||||
// Retrans Timer[4]
|
||||
// Option=Prefix Information(3):
|
||||
// Type[1]
|
||||
// Length * 8bytes[1]
|
||||
// Prefix Length[1]
|
||||
// Flags[1]: LA......
|
||||
// Valid Lifetime[4]
|
||||
// Preferred Lifetime[4]
|
||||
// Reserved[4]
|
||||
// Prefix[16]
|
||||
// Option=MTU(5):
|
||||
// Type[1]
|
||||
// Length * 8bytes[1]
|
||||
// Reserved[2]
|
||||
// MTU[4]
|
||||
// Option=Source link-layer address(1):
|
||||
// Link-Layer Address[8/24]
|
||||
// Option=Recursive DNS Server(25):
|
||||
// Type[1]
|
||||
// Length * 8bytes[1]
|
||||
// Reserved[2]
|
||||
// Lifetime[4]
|
||||
// Addresses of IPv6 Recursive DNS Servers[16]
|
||||
// ICMPv6:
|
||||
// - type[1]
|
||||
// - code[1]
|
||||
// - chksum[2]
|
||||
// - body (RouterAdvertisement):
|
||||
// - Cur Hop Limit[1]
|
||||
// - Flags[1]: MO......
|
||||
// - Router Lifetime[2]
|
||||
// - Reachable Time[4]
|
||||
// - Retrans Timer[4]
|
||||
// - Option=Prefix Information(3):
|
||||
// - Type[1]
|
||||
// - Length * 8bytes[1]
|
||||
// - Prefix Length[1]
|
||||
// - Flags[1]: LA......
|
||||
// - Valid Lifetime[4]
|
||||
// - Preferred Lifetime[4]
|
||||
// - Reserved[4]
|
||||
// - Prefix[16]
|
||||
// - Option=MTU(5):
|
||||
// - Type[1]
|
||||
// - Length * 8bytes[1]
|
||||
// - Reserved[2]
|
||||
// - MTU[4]
|
||||
// - Option=Source link-layer address(1):
|
||||
// - Link-Layer Address[8/24]
|
||||
// - Option=Recursive DNS Server(25):
|
||||
// - Type[1]
|
||||
// - Length * 8bytes[1]
|
||||
// - Reserved[2]
|
||||
// - Lifetime[4]
|
||||
// - Addresses of IPv6 Recursive DNS Servers[16]
|
||||
//
|
||||
// TODO(a.garipov): Replace with an existing implementation from a dependency.
|
||||
func createICMPv6RAPacket(params icmpv6RA) (data []byte, err error) {
|
||||
var lla []byte
|
||||
lla, err = hwAddrToLinkLayerAddr(params.sourceLinkLayerAddress)
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DHCPServer - DHCP server interface
|
||||
type DHCPServer interface {
|
||||
// ResetLeases resets leases.
|
||||
ResetLeases(leases []*Lease) (err error)
|
||||
// GetLeases returns deep clones of the current leases.
|
||||
GetLeases(flags GetLeasesFlags) (leases []*Lease)
|
||||
// AddStaticLease - add a static lease
|
||||
AddStaticLease(l *Lease) (err error)
|
||||
// RemoveStaticLease - remove a static lease
|
||||
RemoveStaticLease(l *Lease) (err error)
|
||||
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
|
||||
FindMACbyIP(ip net.IP) net.HardwareAddr
|
||||
|
||||
// WriteDiskConfig4 - copy disk configuration
|
||||
WriteDiskConfig4(c *V4ServerConf)
|
||||
// WriteDiskConfig6 - copy disk configuration
|
||||
WriteDiskConfig6(c *V6ServerConf)
|
||||
|
||||
// Start - start server
|
||||
Start() (err error)
|
||||
// Stop - stop server
|
||||
Stop() (err error)
|
||||
getLeasesRef() []*Lease
|
||||
}
|
||||
|
||||
// V4ServerConf - server configuration
|
||||
type V4ServerConf struct {
|
||||
Enabled bool `yaml:"-" json:"-"`
|
||||
InterfaceName string `yaml:"-" json:"-"`
|
||||
|
||||
GatewayIP net.IP `yaml:"gateway_ip" json:"gateway_ip"`
|
||||
SubnetMask net.IP `yaml:"subnet_mask" json:"subnet_mask"`
|
||||
// broadcastIP is the broadcasting address pre-calculated from the
|
||||
// configured gateway IP and subnet mask.
|
||||
broadcastIP net.IP
|
||||
|
||||
// The first & the last IP address for dynamic leases
|
||||
// Bytes [0..2] of the last allowed IP address must match the first IP
|
||||
RangeStart net.IP `yaml:"range_start" json:"range_start"`
|
||||
RangeEnd net.IP `yaml:"range_end" json:"range_end"`
|
||||
|
||||
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
|
||||
|
||||
// IP conflict detector: time (ms) to wait for ICMP reply
|
||||
// 0: disable
|
||||
ICMPTimeout uint32 `yaml:"icmp_timeout_msec" json:"-"`
|
||||
|
||||
// Custom Options.
|
||||
//
|
||||
// Option with arbitrary hexadecimal data:
|
||||
// DEC_CODE hex HEX_DATA
|
||||
// where DEC_CODE is a decimal DHCPv4 option code in range [1..255]
|
||||
//
|
||||
// Option with IP data (only 1 IP is supported):
|
||||
// DEC_CODE ip IP_ADDR
|
||||
Options []string `yaml:"options" json:"-"`
|
||||
|
||||
ipRange *ipRange
|
||||
|
||||
leaseTime time.Duration // the time during which a dynamic lease is considered valid
|
||||
dnsIPAddrs []net.IP // IPv4 addresses to return to DHCP clients as DNS server addresses
|
||||
|
||||
// subnet contains the DHCP server's subnet. The IP is the IP of the
|
||||
// gateway.
|
||||
subnet *net.IPNet
|
||||
|
||||
// notify is a way to signal to other components that leases have
|
||||
// change. notify must be called outside of locked sections, since the
|
||||
// clients might want to get the new data.
|
||||
//
|
||||
// TODO(a.garipov): This is utter madness and must be refactored. It
|
||||
// just begs for deadlock bugs and other nastiness.
|
||||
notify func(uint32)
|
||||
}
|
||||
|
||||
// V6ServerConf - server configuration
|
||||
type V6ServerConf struct {
|
||||
Enabled bool `yaml:"-" json:"-"`
|
||||
InterfaceName string `yaml:"-" json:"-"`
|
||||
|
||||
// The first IP address for dynamic leases
|
||||
// The last allowed IP address ends with 0xff byte
|
||||
RangeStart net.IP `yaml:"range_start" json:"range_start"`
|
||||
|
||||
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
|
||||
|
||||
RASLAACOnly bool `yaml:"ra_slaac_only" json:"-"` // send ICMPv6.RA packets without MO flags
|
||||
RAAllowSLAAC bool `yaml:"ra_allow_slaac" json:"-"` // send ICMPv6.RA packets with MO flags
|
||||
|
||||
ipStart net.IP // starting IP address for dynamic leases
|
||||
leaseTime time.Duration // the time during which a dynamic lease is considered valid
|
||||
dnsIPAddrs []net.IP // IPv6 addresses to return to DHCP clients as DNS server addresses
|
||||
|
||||
// Server calls this function when leases data changes
|
||||
notify func(uint32)
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Currently used defaults for ifaceDNSAddrs.
|
||||
const (
|
||||
defaultMaxAttempts int = 10
|
||||
|
||||
defaultBackoff time.Duration = 500 * time.Millisecond
|
||||
)
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -9,15 +8,19 @@ import "net"
|
||||
|
||||
type winServer struct{}
|
||||
|
||||
func (s *winServer) ResetLeases(_ []*Lease) (err error) { return nil }
|
||||
func (s *winServer) GetLeases(_ GetLeasesFlags) (leases []*Lease) { return nil }
|
||||
func (s *winServer) getLeasesRef() []*Lease { return nil }
|
||||
func (s *winServer) AddStaticLease(_ *Lease) (err error) { return nil }
|
||||
func (s *winServer) RemoveStaticLease(_ *Lease) (err error) { return nil }
|
||||
func (s *winServer) FindMACbyIP(ip net.IP) (mac net.HardwareAddr) { return nil }
|
||||
func (s *winServer) WriteDiskConfig4(c *V4ServerConf) {}
|
||||
func (s *winServer) WriteDiskConfig6(c *V6ServerConf) {}
|
||||
func (s *winServer) Start() (err error) { return nil }
|
||||
func (s *winServer) Stop() (err error) { return nil }
|
||||
func v4Create(conf V4ServerConf) (DHCPServer, error) { return &winServer{}, nil }
|
||||
func v6Create(conf V6ServerConf) (DHCPServer, error) { return &winServer{}, nil }
|
||||
// type check
|
||||
var _ DHCPServer = winServer{}
|
||||
|
||||
func (winServer) ResetLeases(_ []*Lease) (err error) { return nil }
|
||||
func (winServer) GetLeases(_ GetLeasesFlags) (leases []*Lease) { return nil }
|
||||
func (winServer) getLeasesRef() []*Lease { return nil }
|
||||
func (winServer) AddStaticLease(_ *Lease) (err error) { return nil }
|
||||
func (winServer) RemoveStaticLease(_ *Lease) (err error) { return nil }
|
||||
func (winServer) FindMACbyIP(_ net.IP) (mac net.HardwareAddr) { return nil }
|
||||
func (winServer) WriteDiskConfig4(_ *V4ServerConf) {}
|
||||
func (winServer) WriteDiskConfig6(_ *V6ServerConf) {}
|
||||
func (winServer) Start() (err error) { return nil }
|
||||
func (winServer) Stop() (err error) { return nil }
|
||||
|
||||
func v4Create(_ *V4ServerConf) (s DHCPServer, err error) { return winServer{}, nil }
|
||||
func v6Create(_ V6ServerConf) (s DHCPServer, err error) { return winServer{}, nil }
|
||||
|
||||
@@ -1,558 +0,0 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/mdlayher/raw"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func notify4(flags uint32) {
|
||||
}
|
||||
|
||||
// defaultV4ServerConf returns the default configuration for *v4Server to use in
|
||||
// tests.
|
||||
func defaultV4ServerConf() (conf V4ServerConf) {
|
||||
return V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
notify: notify4,
|
||||
}
|
||||
}
|
||||
|
||||
// defaultSrv prepares the default DHCPServer to use in tests. The underlying
|
||||
// type of s is *v4Server.
|
||||
func defaultSrv(t *testing.T) (s DHCPServer) {
|
||||
t.Helper()
|
||||
|
||||
var err error
|
||||
s, err = v4Create(defaultV4ServerConf())
|
||||
require.NoError(t, err)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func TestV4_AddRemove_static(t *testing.T) {
|
||||
s := defaultSrv(t)
|
||||
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Empty(t, ls)
|
||||
|
||||
// Add static lease.
|
||||
l := &Lease{
|
||||
Hostname: "static-1.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
}
|
||||
|
||||
err := s.AddStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.AddStaticLease(l)
|
||||
assert.Error(t, err)
|
||||
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
require.Len(t, ls, 1)
|
||||
|
||||
assert.True(t, l.IP.Equal(ls[0].IP))
|
||||
assert.Equal(t, l.HWAddr, ls[0].HWAddr)
|
||||
assert.True(t, ls[0].IsStatic())
|
||||
|
||||
// Try to remove static lease.
|
||||
err = s.RemoveStaticLease(&Lease{
|
||||
IP: net.IP{192, 168, 10, 110},
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Remove static lease.
|
||||
err = s.RemoveStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
assert.Empty(t, ls)
|
||||
}
|
||||
|
||||
func TestV4_AddReplace(t *testing.T) {
|
||||
sIface := defaultSrv(t)
|
||||
|
||||
s, ok := sIface.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
dynLeases := []Lease{{
|
||||
Hostname: "dynamic-1.local",
|
||||
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
}, {
|
||||
Hostname: "dynamic-2.local",
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 151},
|
||||
}}
|
||||
|
||||
for i := range dynLeases {
|
||||
err := s.addLease(&dynLeases[i])
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
stLeases := []*Lease{{
|
||||
Hostname: "static-1.local",
|
||||
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
}, {
|
||||
Hostname: "static-2.local",
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 152},
|
||||
}}
|
||||
|
||||
for _, l := range stLeases {
|
||||
err := s.AddStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
require.Len(t, ls, 2)
|
||||
|
||||
for i, l := range ls {
|
||||
assert.True(t, stLeases[i].IP.Equal(l.IP))
|
||||
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
|
||||
assert.True(t, l.IsStatic())
|
||||
}
|
||||
}
|
||||
|
||||
func TestV4Server_Process_optionsPriority(t *testing.T) {
|
||||
defaultIP := net.IP{192, 168, 1, 1}
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
|
||||
// prepareSrv creates a *v4Server and sets the opt6IPs in the initial
|
||||
// configuration of the server as the value for DHCP option 6.
|
||||
prepareSrv := func(t *testing.T, opt6IPs []net.IP) (s *v4Server) {
|
||||
t.Helper()
|
||||
|
||||
conf := defaultV4ServerConf()
|
||||
if len(opt6IPs) > 0 {
|
||||
b := &strings.Builder{}
|
||||
stringutil.WriteToBuilder(b, "6 ips ", opt6IPs[0].String())
|
||||
for _, ip := range opt6IPs[1:] {
|
||||
stringutil.WriteToBuilder(b, ",", ip.String())
|
||||
}
|
||||
conf.Options = []string{b.String()}
|
||||
}
|
||||
|
||||
ss, err := v4Create(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
var ok bool
|
||||
s, ok = ss.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
s.conf.dnsIPAddrs = []net.IP{defaultIP}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// checkResp creates a discovery message with DHCP option 6 requested amd
|
||||
// asserts the response to contain wantIPs in this option.
|
||||
checkResp := func(t *testing.T, s *v4Server, wantIPs []net.IP) {
|
||||
t.Helper()
|
||||
|
||||
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
req, err := dhcpv4.NewDiscovery(mac, dhcpv4.WithRequestedOptions(
|
||||
dhcpv4.OptionDomainNameServer,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp *dhcpv4.DHCPv4
|
||||
resp, err = dhcpv4.NewReplyFromRequest(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
res := s.process(req, resp)
|
||||
require.Equal(t, 1, res)
|
||||
|
||||
o := resp.GetOneOption(dhcpv4.OptionDomainNameServer)
|
||||
require.NotEmpty(t, o)
|
||||
|
||||
wantData := []byte{}
|
||||
for _, ip := range wantIPs {
|
||||
wantData = append(wantData, ip...)
|
||||
}
|
||||
assert.Equal(t, o, wantData)
|
||||
}
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
s := prepareSrv(t, nil)
|
||||
|
||||
checkResp(t, s, []net.IP{defaultIP})
|
||||
})
|
||||
|
||||
t.Run("explicitly_configured", func(t *testing.T) {
|
||||
s := prepareSrv(t, []net.IP{knownIP, knownIP})
|
||||
|
||||
checkResp(t, s, []net.IP{knownIP, knownIP})
|
||||
})
|
||||
}
|
||||
|
||||
func TestV4StaticLease_Get(t *testing.T) {
|
||||
sIface := defaultSrv(t)
|
||||
|
||||
s, ok := sIface.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
|
||||
|
||||
l := &Lease{
|
||||
Hostname: "static-1.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
}
|
||||
err := s.AddStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
|
||||
var req, resp *dhcpv4.DHCPv4
|
||||
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
t.Run("discover", func(t *testing.T) {
|
||||
req, err = dhcpv4.NewDiscovery(mac, dhcpv4.WithRequestedOptions(
|
||||
dhcpv4.OptionDomainNameServer,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err = dhcpv4.NewReplyFromRequest(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
})
|
||||
|
||||
// Don't continue if we got any errors in the previous subtest.
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("offer", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, l.IP.Equal(resp.YourIPAddr))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
|
||||
assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
})
|
||||
|
||||
t.Run("request", func(t *testing.T) {
|
||||
req, err = dhcpv4.NewRequestFromOffer(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err = dhcpv4.NewReplyFromRequest(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("ack", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, l.IP.Equal(resp.YourIPAddr))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
|
||||
assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
})
|
||||
|
||||
dnsAddrs := resp.DNS()
|
||||
require.Len(t, dnsAddrs, 1)
|
||||
|
||||
assert.True(t, s.conf.GatewayIP.Equal(dnsAddrs[0]))
|
||||
|
||||
t.Run("check_lease", func(t *testing.T) {
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
require.Len(t, ls, 1)
|
||||
|
||||
assert.True(t, l.IP.Equal(ls[0].IP))
|
||||
assert.Equal(t, mac, ls[0].HWAddr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestV4DynamicLease_Get(t *testing.T) {
|
||||
conf := defaultV4ServerConf()
|
||||
conf.Options = []string{
|
||||
"81 hex 303132",
|
||||
"82 ip 1.2.3.4",
|
||||
}
|
||||
|
||||
var err error
|
||||
sIface, err := v4Create(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, ok := sIface.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
|
||||
|
||||
var req, resp *dhcpv4.DHCPv4
|
||||
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
t.Run("discover", func(t *testing.T) {
|
||||
req, err = dhcpv4.NewDiscovery(mac, dhcpv4.WithRequestedOptions(
|
||||
dhcpv4.OptionFQDN,
|
||||
dhcpv4.OptionRelayAgentInformation,
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err = dhcpv4.NewReplyFromRequest(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
})
|
||||
|
||||
// Don't continue if we got any errors in the previous subtest.
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("offer", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
|
||||
assert.Equal(t, s.conf.RangeStart, resp.YourIPAddr)
|
||||
assert.Equal(t, s.conf.GatewayIP, resp.ServerIdentifier())
|
||||
|
||||
router := resp.Router()
|
||||
require.Len(t, router, 1)
|
||||
|
||||
assert.Equal(t, s.conf.GatewayIP, router[0])
|
||||
|
||||
assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
assert.Equal(t, []byte("012"), resp.Options.Get(dhcpv4.OptionFQDN))
|
||||
|
||||
rai := resp.RelayAgentInfo()
|
||||
require.NotNil(t, rai)
|
||||
assert.Equal(t, net.IP{1, 2, 3, 4}, net.IP(rai.ToBytes()))
|
||||
})
|
||||
|
||||
t.Run("request", func(t *testing.T) {
|
||||
req, err = dhcpv4.NewRequestFromOffer(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err = dhcpv4.NewReplyFromRequest(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("ack", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr))
|
||||
|
||||
router := resp.Router()
|
||||
require.Len(t, router, 1)
|
||||
|
||||
assert.Equal(t, s.conf.GatewayIP, router[0])
|
||||
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
|
||||
assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
})
|
||||
|
||||
dnsAddrs := resp.DNS()
|
||||
require.Len(t, dnsAddrs, 1)
|
||||
|
||||
assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0]))
|
||||
|
||||
// check lease
|
||||
t.Run("check_lease", func(t *testing.T) {
|
||||
ls := s.GetLeases(LeasesDynamic)
|
||||
require.Len(t, ls, 1)
|
||||
|
||||
assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP))
|
||||
assert.Equal(t, mac, ls[0].HWAddr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNormalizeHostname(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
hostname string
|
||||
wantErrMsg string
|
||||
want string
|
||||
}{{
|
||||
name: "success",
|
||||
hostname: "example.com",
|
||||
wantErrMsg: "",
|
||||
want: "example.com",
|
||||
}, {
|
||||
name: "success_empty",
|
||||
hostname: "",
|
||||
wantErrMsg: "",
|
||||
want: "",
|
||||
}, {
|
||||
name: "success_spaces",
|
||||
hostname: "my device 01",
|
||||
wantErrMsg: "",
|
||||
want: "my-device-01",
|
||||
}, {
|
||||
name: "success_underscores",
|
||||
hostname: "my_device_01",
|
||||
wantErrMsg: "",
|
||||
want: "my-device-01",
|
||||
}, {
|
||||
name: "error_part",
|
||||
hostname: "device !!!",
|
||||
wantErrMsg: "",
|
||||
want: "device",
|
||||
}, {
|
||||
name: "error_part_spaces",
|
||||
hostname: "device ! ! !",
|
||||
wantErrMsg: "",
|
||||
want: "device",
|
||||
}, {
|
||||
name: "error",
|
||||
hostname: "!!!",
|
||||
wantErrMsg: `normalizing "!!!": no valid parts`,
|
||||
want: "",
|
||||
}, {
|
||||
name: "error_spaces",
|
||||
hostname: "! ! !",
|
||||
wantErrMsg: `normalizing "! ! !": no valid parts`,
|
||||
want: "",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := normalizeHostname(tc.hostname)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fakePacketConn is a mock implementation of net.PacketConn to simplify
|
||||
// testing.
|
||||
type fakePacketConn struct {
|
||||
// writeTo is used to substitute net.PacketConn's WriteTo method.
|
||||
writeTo func(p []byte, addr net.Addr) (n int, err error)
|
||||
// net.PacketConn is embedded here simply to make *fakePacketConn a
|
||||
// net.PacketConn without actually implementing all methods.
|
||||
net.PacketConn
|
||||
}
|
||||
|
||||
// WriteTo implements net.PacketConn interface for *fakePacketConn.
|
||||
func (fc *fakePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return fc.writeTo(p, addr)
|
||||
}
|
||||
|
||||
func TestV4Server_Send(t *testing.T) {
|
||||
s := &v4Server{}
|
||||
|
||||
var (
|
||||
defaultIP = net.IP{99, 99, 99, 99}
|
||||
knownIP = net.IP{4, 2, 4, 2}
|
||||
knownMAC = net.HardwareAddr{6, 5, 4, 3, 2, 1}
|
||||
)
|
||||
|
||||
defaultPeer := &net.UDPAddr{
|
||||
IP: defaultIP,
|
||||
// Use neither client nor server port to check it actually
|
||||
// changed.
|
||||
Port: dhcpv4.ClientPort + dhcpv4.ServerPort,
|
||||
}
|
||||
defaultResp := &dhcpv4.DHCPv4{}
|
||||
|
||||
testCases := []struct {
|
||||
want net.Addr
|
||||
req *dhcpv4.DHCPv4
|
||||
resp *dhcpv4.DHCPv4
|
||||
name string
|
||||
}{{
|
||||
name: "giaddr",
|
||||
req: &dhcpv4.DHCPv4{GatewayIPAddr: knownIP},
|
||||
resp: defaultResp,
|
||||
want: &net.UDPAddr{
|
||||
IP: knownIP,
|
||||
Port: dhcpv4.ServerPort,
|
||||
},
|
||||
}, {
|
||||
name: "nak",
|
||||
req: &dhcpv4.DHCPv4{},
|
||||
resp: &dhcpv4.DHCPv4{
|
||||
Options: dhcpv4.OptionsFromList(
|
||||
dhcpv4.OptMessageType(dhcpv4.MessageTypeNak),
|
||||
),
|
||||
},
|
||||
want: defaultPeer,
|
||||
}, {
|
||||
name: "ciaddr",
|
||||
req: &dhcpv4.DHCPv4{ClientIPAddr: knownIP},
|
||||
resp: &dhcpv4.DHCPv4{},
|
||||
want: &net.UDPAddr{
|
||||
IP: knownIP,
|
||||
Port: dhcpv4.ClientPort,
|
||||
},
|
||||
}, {
|
||||
name: "chaddr",
|
||||
req: &dhcpv4.DHCPv4{ClientHWAddr: knownMAC},
|
||||
resp: &dhcpv4.DHCPv4{YourIPAddr: knownIP},
|
||||
want: &dhcpUnicastAddr{
|
||||
Addr: raw.Addr{HardwareAddr: knownMAC},
|
||||
yiaddr: knownIP,
|
||||
},
|
||||
}, {
|
||||
name: "who_are_you",
|
||||
req: &dhcpv4.DHCPv4{},
|
||||
resp: &dhcpv4.DHCPv4{},
|
||||
want: defaultPeer,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
conn := &fakePacketConn{
|
||||
writeTo: func(_ []byte, addr net.Addr) (_ int, _ error) {
|
||||
assert.Equal(t, tc.want, addr)
|
||||
|
||||
return 0, nil
|
||||
},
|
||||
}
|
||||
|
||||
s.send(cloneUDPAddr(defaultPeer), conn, tc.req, tc.resp)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("giaddr_nak", func(t *testing.T) {
|
||||
req := &dhcpv4.DHCPv4{
|
||||
GatewayIPAddr: knownIP,
|
||||
}
|
||||
// Ensure the request is for unicast.
|
||||
req.SetUnicast()
|
||||
resp := &dhcpv4.DHCPv4{
|
||||
Options: dhcpv4.OptionsFromList(
|
||||
dhcpv4.OptMessageType(dhcpv4.MessageTypeNak),
|
||||
),
|
||||
}
|
||||
want := &net.UDPAddr{
|
||||
IP: req.GatewayIPAddr,
|
||||
Port: dhcpv4.ServerPort,
|
||||
}
|
||||
|
||||
conn := &fakePacketConn{
|
||||
writeTo: func(_ []byte, addr net.Addr) (n int, err error) {
|
||||
assert.Equal(t, want, addr)
|
||||
|
||||
return 0, nil
|
||||
},
|
||||
}
|
||||
|
||||
s.send(cloneUDPAddr(defaultPeer), conn, req, resp)
|
||||
assert.True(t, resp.IsBroadcast())
|
||||
})
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user