Merge branch 'master' into 4728-cap-check

This commit is contained in:
Eugene Burkov
2023-02-06 15:45:10 +03:00
630 changed files with 28309 additions and 38061 deletions

View File

@@ -1,32 +1,64 @@
// Package aghalg contains common generic algorithms and data structures.
//
// TODO(a.garipov): Update to use type parameters in Go 1.18.
// TODO(a.garipov): Move parts of this into golibs.
package aghalg
import (
"fmt"
"sort"
"golang.org/x/exp/constraints"
"golang.org/x/exp/slices"
)
// comparable is an alias for interface{}. Values passed as arguments of this
// type alias must be comparable.
// Coalesce returns the first non-zero value. It is named after function
// COALESCE in SQL. If values or all its elements are empty, it returns a zero
// value.
//
// TODO(a.garipov): Remove in Go 1.18.
type comparable = interface{}
// T is comparable, because Go currently doesn't have a comparableWithZeroValue
// constraint.
//
// TODO(a.garipov): Think of ways to merge with [CoalesceSlice].
func Coalesce[T comparable](values ...T) (res T) {
var zero T
for _, v := range values {
if v != zero {
return v
}
}
return zero
}
// CoalesceSlice returns the first non-zero value. It is named after function
// COALESCE in SQL. If values or all its elements are empty, it returns nil.
//
// TODO(a.garipov): Think of ways to merge with [Coalesce].
func CoalesceSlice[E any, S []E](values ...S) (res S) {
for _, v := range values {
if v != nil {
return v
}
}
return nil
}
// UniqChecker allows validating uniqueness of comparable items.
type UniqChecker map[comparable]int64
//
// TODO(a.garipov): The Ordered constraint is only really necessary in Validate.
// Consider ways of making this constraint comparable instead.
type UniqChecker[T constraints.Ordered] map[T]int64
// Add adds a value to the validator. v must not be nil.
func (uc UniqChecker) Add(elems ...comparable) {
func (uc UniqChecker[T]) Add(elems ...T) {
for _, e := range elems {
uc[e]++
}
}
// Merge returns a checker containing data from both uc and other.
func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
merged = make(UniqChecker, len(uc)+len(other))
func (uc UniqChecker[T]) Merge(other UniqChecker[T]) (merged UniqChecker[T]) {
merged = make(UniqChecker[T], len(uc)+len(other))
for elem, num := range uc {
merged[elem] += num
}
@@ -39,10 +71,8 @@ func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
}
// Validate returns an error enumerating all elements that aren't unique.
// isBefore is an optional sorting function to make the error message
// deterministic.
func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err error) {
var dup []comparable
func (uc UniqChecker[T]) Validate() (err error) {
var dup []T
for elem, num := range uc {
if num > 1 {
dup = append(dup, elem)
@@ -53,23 +83,7 @@ func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err
return nil
}
if isBefore != nil {
sort.Slice(dup, func(i, j int) (less bool) {
return isBefore(dup[i], dup[j])
})
}
slices.Sort(dup)
return fmt.Errorf("duplicated values: %v", dup)
}
// IntIsBefore is a helper sort function for UniqChecker.Validate.
// a and b must be of type int.
func IntIsBefore(a, b comparable) (less bool) {
return a.(int) < b.(int)
}
// StringIsBefore is a helper sort function for UniqChecker.Validate.
// a and b must be of type string.
func StringIsBefore(a, b comparable) (less bool) {
return a.(string) < b.(string)
}

View File

@@ -4,6 +4,8 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/AdguardTeam/golibs/mathutil"
)
// NullBool is a nullable boolean. Use these in JSON requests and responses
@@ -33,11 +35,15 @@ func (nb NullBool) String() (s string) {
// BoolToNullBool converts a bool into a NullBool.
func BoolToNullBool(cond bool) (nb NullBool) {
if cond {
return NBTrue
}
return NBFalse - mathutil.BoolToNumber[NullBool](cond)
}
return NBFalse
// type check
var _ json.Marshaler = NBNull
// MarshalJSON implements the json.Marshaler interface for NullBool.
func (nb NullBool) MarshalJSON() (b []byte, err error) {
return []byte(nb.String()), nil
}
// type check

View File

@@ -10,6 +10,52 @@ import (
"github.com/stretchr/testify/require"
)
func TestNullBool_MarshalJSON(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
want []byte
in aghalg.NullBool
}{{
name: "null",
wantErrMsg: "",
want: []byte("null"),
in: aghalg.NBNull,
}, {
name: "true",
wantErrMsg: "",
want: []byte("true"),
in: aghalg.NBTrue,
}, {
name: "false",
wantErrMsg: "",
want: []byte("false"),
in: aghalg.NBFalse,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := tc.in.MarshalJSON()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
t.Run("json", func(t *testing.T) {
in := &struct {
A aghalg.NullBool
}{
A: aghalg.NBTrue,
}
got, err := json.Marshal(in)
require.NoError(t, err)
assert.Equal(t, []byte(`{"A":true}`), got)
})
}
func TestNullBool_UnmarshalJSON(t *testing.T) {
testCases := []struct {
name string

View File

@@ -0,0 +1,33 @@
// Package aghchan contains channel utilities.
package aghchan
import (
"fmt"
"time"
)
// Receive returns an error if it cannot receive a value form c before timeout
// runs out.
func Receive[T any](c <-chan T, timeout time.Duration) (v T, ok bool, err error) {
var zero T
timeoutCh := time.After(timeout)
select {
case <-timeoutCh:
// TODO(a.garipov): Consider implementing [errors.Aser] for
// os.ErrTimeout.
return zero, false, fmt.Errorf("did not receive after %s", timeout)
case v, ok = <-c:
return v, ok, nil
}
}
// MustReceive panics if it cannot receive a value form c before timeout runs
// out.
func MustReceive[T any](c <-chan T, timeout time.Duration) (v T, ok bool) {
v, ok, err := Receive(c, timeout)
if err != nil {
panic(err)
}
return v, ok
}

View File

@@ -2,13 +2,27 @@
package aghhttp
import (
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log"
)
// HTTP scheme constants.
const (
SchemeHTTP = "http"
SchemeHTTPS = "https"
)
// RegisterFunc is the function that sets the handler to handle the URL for the
// method.
//
// TODO(e.burkov, a.garipov): Get rid of it.
type RegisterFunc func(method, url string, handler http.HandlerFunc)
// OK responds with word OK.
func OK(w http.ResponseWriter) {
if _, err := io.WriteString(w, "OK\n"); err != nil {
@@ -17,8 +31,52 @@ func OK(w http.ResponseWriter) {
}
// Error writes formatted message to w and also logs it.
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...any) {
text := fmt.Sprintf(format, args...)
log.Error("%s %s: %s", r.Method, r.URL, text)
log.Error("%s %s %s: %s", r.Method, r.Host, r.URL, text)
http.Error(w, text, code)
}
// UserAgent returns the ID of the service as a User-Agent string. It can also
// be used as the value of the Server HTTP header.
func UserAgent() (ua string) {
return fmt.Sprintf("AdGuardHome/%s", version.Version())
}
// textPlainDeprMsg is the message returned to API users when they try to use
// an API that used to accept "text/plain" but doesn't anymore.
const textPlainDeprMsg = `using this api with the text/plain content-type is deprecated; ` +
`use application/json`
// WriteTextPlainDeprecated responds to the request with a message about
// deprecation and removal of a plain-text API if the request is made with the
// "text/plain" content-type.
func WriteTextPlainDeprecated(w http.ResponseWriter, r *http.Request) (isPlainText bool) {
if r.Header.Get(HdrNameContentType) != HdrValTextPlain {
return false
}
Error(r, w, http.StatusUnsupportedMediaType, textPlainDeprMsg)
return true
}
// WriteJSONResponse sets the content-type header in w.Header() to
// "application/json", writes a header with a "200 OK" status, encodes resp to
// w, calls [Error] on any returned error, and returns it as well.
func WriteJSONResponse(w http.ResponseWriter, r *http.Request, resp any) (err error) {
return WriteJSONResponseCode(w, r, http.StatusOK, resp)
}
// WriteJSONResponseCode is like [WriteJSONResponse] but adds the ability to
// redefine the status code.
func WriteJSONResponseCode(w http.ResponseWriter, r *http.Request, code int, resp any) (err error) {
w.WriteHeader(code)
w.Header().Set(HdrNameContentType, HdrValApplicationJSON)
err = json.NewEncoder(w).Encode(resp)
if err != nil {
Error(r, w, http.StatusInternalServerError, "encoding resp: %s", err)
}
return err
}

View File

@@ -0,0 +1,25 @@
package aghhttp
// HTTP Headers
// HTTP header name constants.
//
// TODO(a.garipov): Remove unused.
const (
HdrNameAcceptEncoding = "Accept-Encoding"
HdrNameAccessControlAllowOrigin = "Access-Control-Allow-Origin"
HdrNameAltSvc = "Alt-Svc"
HdrNameContentEncoding = "Content-Encoding"
HdrNameContentType = "Content-Type"
HdrNameOrigin = "Origin"
HdrNameServer = "Server"
HdrNameTrailer = "Trailer"
HdrNameUserAgent = "User-Agent"
HdrNameVary = "Vary"
)
// HTTP header value constants.
const (
HdrValApplicationJSON = "application/json"
HdrValTextPlain = "text/plain"
)

View File

@@ -4,6 +4,9 @@ package aghio
import (
"fmt"
"io"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mathutil"
)
// LimitReachedError records the limit and the operation that caused it.
@@ -11,22 +14,22 @@ type LimitReachedError struct {
Limit int64
}
// Error implements the error interface for LimitReachedError.
// Error implements the [error] interface for *LimitReachedError.
//
// TODO(a.garipov): Think about error string format.
func (lre *LimitReachedError) Error() string {
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
}
// limitedReader is a wrapper for io.Reader with limited reader and dealing with
// errors package.
// limitedReader is a wrapper for [io.Reader] limiting the input and dealing
// with errors package.
type limitedReader struct {
r io.Reader
limit int64
n int64
}
// Read implements Reader interface.
// Read implements the [io.Reader] interface.
func (lr *limitedReader) Read(p []byte) (n int, err error) {
if lr.n == 0 {
return 0, &LimitReachedError{
@@ -34,9 +37,7 @@ func (lr *limitedReader) Read(p []byte) (n int, err error) {
}
}
if int64(len(p)) > lr.n {
p = p[:lr.n]
}
p = p[:mathutil.Min(lr.n, int64(len(p)))]
n, err = lr.r.Read(p)
lr.n -= int64(n)
@@ -48,7 +49,7 @@ func (lr *limitedReader) Read(p []byte) (n int, err error) {
// n bytes read.
func LimitReader(r io.Reader, n int64) (limited io.Reader, err error) {
if n < 0 {
return nil, fmt.Errorf("aghio: invalid n in LimitReader: %d", n)
return nil, errors.Error("limit must be non-negative")
}
return &limitedReader{

View File

@@ -24,7 +24,7 @@ func TestLimitReader(t *testing.T) {
name: "zero",
n: 0,
}, {
wantErrMsg: "aghio: invalid n in LimitReader: -1",
wantErrMsg: "limit must be non-negative",
name: "negative",
n: -1,
}}

View File

@@ -5,10 +5,11 @@ import (
"bytes"
"fmt"
"net"
"net/netip"
"sync"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/slices"
)
// ARPDB: The Network Neighborhood Database
@@ -54,7 +55,7 @@ type Neighbor struct {
Name string
// IP contains either IPv4 or IPv6.
IP net.IP
IP netip.Addr
// MAC contains the hardware address.
MAC net.HardwareAddr
@@ -64,8 +65,8 @@ type Neighbor struct {
func (n Neighbor) Clone() (clone Neighbor) {
return Neighbor{
Name: n.Name,
IP: netutil.CloneIP(n.IP),
MAC: netutil.CloneMAC(n.MAC),
IP: n.IP,
MAC: slices.Clone(n.MAC),
}
}

View File

@@ -1,11 +1,11 @@
//go:build darwin || freebsd
// +build darwin freebsd
package aghnet
import (
"bufio"
"net"
"net/netip"
"strings"
"sync"
@@ -33,8 +33,7 @@ func newARPDB() (arp *cmdARPDB) {
// parseArpA parses the output of the "arp -a -n" command on macOS and FreeBSD.
// The expected input format:
//
// host.name (192.168.0.1) at ff:ff:ff:ff:ff:ff on en0 ifscope [ethernet]
//
// host.name (192.168.0.1) at ff:ff:ff:ff:ff:ff on en0 ifscope [ethernet]
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
ns = make([]Neighbor, 0, lenHint)
for sc.Scan() {
@@ -49,22 +48,28 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
if ipStr := fields[1]; len(ipStr) < 2 {
continue
} else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil {
} else if ip, err := netip.ParseAddr(ipStr[1 : len(ipStr)-1]); err != nil {
log.Debug("arpdb: parsing arp output: ip: %s", err)
continue
} else {
n.IP = ip
}
hwStr := fields[3]
if mac, err := net.ParseMAC(hwStr); err != nil {
mac, err := net.ParseMAC(hwStr)
if err != nil {
log.Debug("arpdb: parsing arp output: mac: %s", err)
continue
} else {
n.MAC = mac
}
host := fields[0]
if err := netutil.ValidateDomainName(host); err != nil {
log.Debug("parsing arp output: %s", err)
err = netutil.ValidateDomainName(host)
if err != nil {
log.Debug("arpdb: parsing arp output: host: %s", err)
} else {
n.Name = host
}

View File

@@ -1,10 +1,10 @@
//go:build darwin || freebsd
// +build darwin freebsd
package aghnet
import (
"net"
"net/netip"
)
const arpAOutput = `
@@ -18,14 +18,14 @@ hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [
var wantNeighs = []Neighbor{{
Name: "hostname.one",
IP: net.IPv4(192, 168, 1, 2),
IP: netip.MustParseAddr("192.168.1.2"),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
Name: "hostname.two",
IP: net.ParseIP("::ffff:ffff"),
IP: netip.MustParseAddr("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}, {
Name: "",
IP: net.ParseIP("::1234"),
IP: netip.MustParseAddr("::1234"),
MAC: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
}}

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet
@@ -8,6 +7,7 @@ import (
"fmt"
"io/fs"
"net"
"net/netip"
"strings"
"sync"
@@ -95,7 +95,8 @@ func (arp *fsysARPDB) Refresh() (err error) {
}
n := Neighbor{}
if n.IP = net.ParseIP(fields[0]); n.IP == nil || n.IP.IsUnspecified() {
n.IP, err = netip.ParseAddr(fields[0])
if err != nil || n.IP.IsUnspecified() {
continue
} else if n.MAC, err = net.ParseMAC(fields[3]); err != nil {
continue
@@ -117,9 +118,8 @@ func (arp *fsysARPDB) Neighbors() (ns []Neighbor) {
// parseArpAWrt parses the output of the "arp -a -n" command on OpenWrt. The
// expected input format:
//
// IP address HW type Flags HW address Mask Device
// 192.168.11.98 0x1 0x2 5a:92:df:a9:7e:28 * wan
//
// IP address HW type Flags HW address Mask Device
// 192.168.11.98 0x1 0x2 5a:92:df:a9:7e:28 * wan
func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
if !sc.Scan() {
// Skip the header.
@@ -137,15 +137,19 @@ func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil || n.IP.IsUnspecified() {
ip, err := netip.ParseAddr(fields[0])
if err != nil || n.IP.IsUnspecified() {
log.Debug("arpdb: parsing arp output: ip: %s", err)
continue
} else {
n.IP = ip
}
hwStr := fields[3]
if mac, err := net.ParseMAC(hwStr); err != nil {
log.Debug("parsing arp output: %s", err)
mac, err := net.ParseMAC(hwStr)
if err != nil {
log.Debug("arpdb: parsing arp output: mac: %s", err)
continue
} else {
@@ -161,8 +165,7 @@ func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
// parseArpA parses the output of the "arp -a -n" command on Linux. The
// expected input format:
//
// hostname (192.168.1.1) at ab:cd:ef:ab:cd:ef [ether] on enp0s3
//
// hostname (192.168.1.1) at ab:cd:ef:ab:cd:ef [ether] on enp0s3
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
ns = make([]Neighbor, 0, lenHint)
for sc.Scan() {
@@ -177,7 +180,9 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
if ipStr := fields[1]; len(ipStr) < 2 {
continue
} else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil {
} else if ip, err := netip.ParseAddr(ipStr[1 : len(ipStr)-1]); err != nil {
log.Debug("arpdb: parsing arp output: ip: %s", err)
continue
} else {
n.IP = ip
@@ -185,7 +190,7 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
hwStr := fields[3]
if mac, err := net.ParseMAC(hwStr); err != nil {
log.Debug("parsing arp output: %s", err)
log.Debug("arpdb: parsing arp output: mac: %s", err)
continue
} else {
@@ -194,7 +199,7 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
host := fields[0]
if verr := netutil.ValidateDomainName(host); verr != nil {
log.Debug("parsing arp output: %s", verr)
log.Debug("arpdb: parsing arp output: host: %s", verr)
} else {
n.Name = host
}
@@ -208,8 +213,7 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
// parseIPNeigh parses the output of the "ip neigh" command on Linux. The
// expected input format:
//
// 192.168.1.1 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef REACHABLE
//
// 192.168.1.1 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef REACHABLE
func parseIPNeigh(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
ns = make([]Neighbor, 0, lenHint)
for sc.Scan() {
@@ -222,14 +226,18 @@ func parseIPNeigh(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil {
ip, err := netip.ParseAddr(fields[0])
if err != nil {
log.Debug("arpdb: parsing arp output: ip: %s", err)
continue
} else {
n.IP = ip
}
if mac, err := net.ParseMAC(fields[4]); err != nil {
log.Debug("parsing arp output: %s", err)
mac, err := net.ParseMAC(fields[4])
if err != nil {
log.Debug("arpdb: parsing arp output: mac: %s", err)
continue
} else {

View File

@@ -1,10 +1,10 @@
//go:build linux
// +build linux
package aghnet
import (
"net"
"net/netip"
"sync"
"testing"
"testing/fstest"
@@ -34,10 +34,10 @@ const ipNeighOutput = `
::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE`
var wantNeighs = []Neighbor{{
IP: net.IPv4(192, 168, 1, 2),
IP: netip.MustParseAddr("192.168.1.2"),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
IP: net.ParseIP("::ffff:ffff"),
IP: netip.MustParseAddr("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}}

View File

@@ -1,11 +1,11 @@
//go:build openbsd
// +build openbsd
package aghnet
import (
"bufio"
"net"
"net/netip"
"strings"
"sync"
@@ -32,9 +32,8 @@ func newARPDB() (arp *cmdARPDB) {
// parseArpA parses the output of the "arp -a -n" command on OpenBSD. The
// expected input format:
//
// Host Ethernet Address Netif Expire Flags
// 192.168.1.1 ab:cd:ef:ab:cd:ef em0 19m59s
//
// Host Ethernet Address Netif Expire Flags
// 192.168.1.1 ab:cd:ef:ab:cd:ef em0 19m59s
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
// Skip the header.
if !sc.Scan() {
@@ -52,14 +51,18 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil {
ip, err := netip.ParseAddr(fields[0])
if err != nil {
log.Debug("arpdb: parsing arp output: ip: %s", err)
continue
} else {
n.IP = ip
}
if mac, err := net.ParseMAC(fields[1]); err != nil {
log.Debug("parsing arp output: %s", err)
mac, err := net.ParseMAC(fields[1])
if err != nil {
log.Debug("arpdb: parsing arp output: mac: %s", err)
continue
} else {

View File

@@ -1,10 +1,10 @@
//go:build openbsd
// +build openbsd
package aghnet
import (
"net"
"net/netip"
)
const arpAOutput = `
@@ -16,9 +16,9 @@ Host Ethernet Address Netif Expire Flags
`
var wantNeighs = []Neighbor{{
IP: net.IPv4(192, 168, 1, 2),
IP: netip.MustParseAddr("192.168.1.2"),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
IP: net.ParseIP("::ffff:ffff"),
IP: netip.MustParseAddr("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}}

View File

@@ -2,6 +2,7 @@ package aghnet
import (
"net"
"net/netip"
"sync"
"testing"
@@ -35,7 +36,7 @@ func (arp *TestARPDB) Neighbors() (ns []Neighbor) {
}
func TestARPDBS(t *testing.T) {
knownIP := net.IP{1, 2, 3, 4}
knownIP := netip.MustParseAddr("1.2.3.4")
knownMAC := net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}
succRefrCount, failRefrCount := 0, 0

View File

@@ -1,11 +1,11 @@
//go:build windows
// +build windows
package aghnet
import (
"bufio"
"net"
"net/netip"
"strings"
"sync"
)
@@ -25,12 +25,10 @@ func newARPDB() (arp *cmdARPDB) {
// parseArpA parses the output of the "arp /a" command on Windows. The expected
// input format (the first line is empty):
//
//
// Interface: 192.168.56.16 --- 0x7
// Internet Address Physical Address Type
// 192.168.56.1 0a-00-27-00-00-00 dynamic
// 192.168.56.255 ff-ff-ff-ff-ff-ff static
//
// Interface: 192.168.56.16 --- 0x7
// Internet Address Physical Address Type
// 192.168.56.1 0a-00-27-00-00-00 dynamic
// 192.168.56.255 ff-ff-ff-ff-ff-ff static
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
ns = make([]Neighbor, 0, lenHint)
for sc.Scan() {
@@ -46,13 +44,15 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil {
ip, err := netip.ParseAddr(fields[0])
if err != nil {
continue
} else {
n.IP = ip
}
if mac, err := net.ParseMAC(fields[1]); err != nil {
mac, err := net.ParseMAC(fields[1])
if err != nil {
continue
} else {
n.MAC = mac

View File

@@ -1,10 +1,10 @@
//go:build windows
// +build windows
package aghnet
import (
"net"
"net/netip"
)
const arpAOutput = `
@@ -15,9 +15,9 @@ Interface: 192.168.1.1 --- 0x7
::ffff:ffff ef-cd-ab-ef-cd-ab static`
var wantNeighs = []Neighbor{{
IP: net.IPv4(192, 168, 1, 2),
IP: netip.MustParseAddr("192.168.1.2"),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
IP: net.ParseIP("::ffff:ffff"),
IP: netip.MustParseAddr("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}}

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package aghnet
@@ -7,6 +6,7 @@ import (
"bytes"
"fmt"
"net"
"net/netip"
"os"
"time"
@@ -39,48 +39,44 @@ func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
}
// ifaceIPv4Subnet returns the first suitable IPv4 subnetwork iface has.
func ifaceIPv4Subnet(iface *net.Interface) (subnet *net.IPNet, err error) {
func ifaceIPv4Subnet(iface *net.Interface) (subnet netip.Prefix, err error) {
var addrs []net.Addr
if addrs, err = iface.Addrs(); err != nil {
return nil, err
return netip.Prefix{}, err
}
for _, a := range addrs {
var ip net.IP
var maskLen int
switch a := a.(type) {
case *net.IPAddr:
subnet = &net.IPNet{
IP: a.IP,
Mask: a.IP.DefaultMask(),
}
ip = a.IP
maskLen, _ = ip.DefaultMask().Size()
case *net.IPNet:
subnet = a
ip = a.IP
maskLen, _ = a.Mask.Size()
default:
continue
}
if ip4 := subnet.IP.To4(); ip4 != nil {
subnet.IP = ip4
return subnet, nil
if ip = ip.To4(); ip != nil {
return netip.PrefixFrom(netip.AddrFrom4(*(*[4]byte)(ip)), maskLen), nil
}
}
return nil, fmt.Errorf("interface %s has no ipv4 addresses", iface.Name)
return netip.Prefix{}, fmt.Errorf("interface %s has no ipv4 addresses", iface.Name)
}
// checkOtherDHCPv4 sends a DHCP request to the specified network interface, and
// waits for a response for a period defined by defaultDiscoverTime.
func checkOtherDHCPv4(iface *net.Interface) (ok bool, err error) {
var subnet *net.IPNet
var subnet netip.Prefix
if subnet, err = ifaceIPv4Subnet(iface); err != nil {
return false, err
}
// Resolve broadcast addr.
dst := netutil.IPPort{
IP: BroadcastFromIPNet(subnet),
Port: 67,
}.String()
dst := netip.AddrPortFrom(BroadcastFromPref(subnet), 67).String()
var dstAddr *net.UDPAddr
if dstAddr, err = net.ResolveUDPAddr("udp4", dst); err != nil {
return false, fmt.Errorf("couldn't resolve UDP address %s: %w", dst, err)

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -45,11 +45,11 @@ func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
// GenerateHostname generates the hostname from ip. In case of using IPv4 the
// result should be like:
//
// 192-168-10-1
// 192-168-10-1
//
// In case of using IPv6, the result is like:
//
// ff80-f076-0000-0000-0000-0000-0000-0010
// ff80-f076-0000-0000-0000-0000-0000-0010
//
// ip must be either an IPv4 or an IPv6.
func GenerateHostname(ip net.IP) (hostname string) {

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"io"
"io/fs"
"net"
"net/netip"
"path"
"strings"
"sync"
@@ -19,6 +19,7 @@ import (
"github.com/AdguardTeam/urlfilter/filterlist"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/maps"
)
// DefaultHostsPaths returns the slice of paths default for the operating system
@@ -55,7 +56,7 @@ func (rm *requestMatcher) MatchRequest(
) (res *urlfilter.DNSResult, ok bool) {
switch req.DNSType {
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
log.Debug("%s: handling the request", hostsContainerPref)
log.Debug("%s: handling the request for %s", hostsContainerPref, req.Hostname)
default:
return nil, false
}
@@ -70,8 +71,7 @@ func (rm *requestMatcher) MatchRequest(
// rule or an empty string if the last doesn't exist. The returned rules are in
// a processed format like:
//
// ip host1 host2 ...
//
// ip host1 host2 ...
func (rm *requestMatcher) Translate(rule string) (hostRule string) {
rm.stateLock.RLock()
defer rm.stateLock.RUnlock()
@@ -107,10 +107,10 @@ type HostsContainer struct {
done chan struct{}
// updates is the channel for receiving updated hosts.
updates chan *netutil.IPMap
updates chan HostsRecords
// last is the set of hosts that was cached within last detected change.
last *netutil.IPMap
last HostsRecords
// fsys is the working file system to read hosts files from.
fsys fs.FS
@@ -125,6 +125,27 @@ type HostsContainer struct {
listID int
}
// HostsRecords is a mapping of an IP address to its hosts data.
type HostsRecords map[netip.Addr]*HostsRecord
// HostsRecord represents a single hosts file record.
type HostsRecord struct {
Aliases *stringutil.Set
Canonical string
}
// equal returns true if all fields of rec are equal to field in other or they
// both are nil.
func (rec *HostsRecord) equal(other *HostsRecord) (ok bool) {
if rec == nil {
return other == nil
} else if other == nil {
return false
}
return rec.Canonical == other.Canonical && rec.Aliases.Equal(other.Aliases)
}
// ErrNoHostsPaths is returned when there are no valid paths to watch passed to
// the HostsContainer.
const ErrNoHostsPaths errors.Error = "no valid paths to hosts files provided"
@@ -159,7 +180,7 @@ func NewHostsContainer(
},
listID: listID,
done: make(chan struct{}, 1),
updates: make(chan *netutil.IPMap, 1),
updates: make(chan HostsRecords, 1),
fsys: fsys,
w: w,
patterns: patterns,
@@ -197,9 +218,8 @@ func (hc *HostsContainer) Close() (err error) {
return nil
}
// Upd returns the channel into which the updates are sent. The receivable
// map's values are guaranteed to be of type of *stringutil.Set.
func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) {
// Upd returns the channel into which the updates are sent.
func (hc *HostsContainer) Upd() (updates <-chan HostsRecords) {
return hc.updates
}
@@ -265,7 +285,7 @@ type hostsParser struct {
// table stores only the unique IP-hostname pairs. It's also sent to the
// updates channel afterwards.
table *netutil.IPMap
table HostsRecords
}
// newHostsParser creates a new *hostsParser with buffers of size taken from the
@@ -274,7 +294,7 @@ func (hc *HostsContainer) newHostsParser() (hp *hostsParser) {
return &hostsParser{
rulesBuilder: &strings.Builder{},
translations: map[string]string{},
table: netutil.NewIPMap(hc.last.Len()),
table: make(HostsRecords, len(hc.last)),
}
}
@@ -286,25 +306,26 @@ func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err
s := bufio.NewScanner(r)
for s.Scan() {
ip, hosts := hp.parseLine(s.Text())
if ip == nil || len(hosts) == 0 {
if ip == (netip.Addr{}) || len(hosts) == 0 {
continue
}
hp.addPairs(ip, hosts)
hp.addRecord(ip, hosts)
}
return nil, true, s.Err()
}
// parseLine parses the line having the hosts syntax ignoring invalid ones.
func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
func (hp *hostsParser) parseLine(line string) (ip netip.Addr, hosts []string) {
fields := strings.Fields(line)
if len(fields) < 2 {
return nil, nil
return netip.Addr{}, nil
}
if ip = net.ParseIP(fields[0]); ip == nil {
return nil, nil
ip, err := netip.ParseAddr(fields[0])
if err != nil {
return netip.Addr{}, nil
}
for _, f := range fields[1:] {
@@ -322,7 +343,7 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
// See https://github.com/AdguardTeam/AdGuardHome/issues/3946.
//
// TODO(e.burkov): Investigate if hosts may contain DNS-SD domains.
err := netutil.ValidateDomainName(f)
err = netutil.ValidateDomainName(f)
if err != nil {
log.Error("%s: host %q is invalid, ignoring", hostsContainerPref, f)
@@ -335,43 +356,47 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
return ip, hosts
}
// addPair puts the pair of ip and host to the rules builder if needed. For
// each ip the first member of hosts will become the main one.
func (hp *hostsParser) addPairs(ip net.IP, hosts []string) {
v, ok := hp.table.Get(ip)
// addRecord puts the record for the IP address to the rules builder if needed.
// The first host is considered to be the canonical name for the IP address.
// hosts must have at least one name.
func (hp *hostsParser) addRecord(ip netip.Addr, hosts []string) {
line := strings.Join(append([]string{ip.String()}, hosts...), " ")
rec, ok := hp.table[ip]
if !ok {
// This ip is added at the first time.
v = stringutil.NewSet()
hp.table.Set(ip, v)
rec = &HostsRecord{
Aliases: stringutil.NewSet(),
}
rec.Canonical, hosts = hosts[0], hosts[1:]
hp.addRules(ip, rec.Canonical, line)
hp.table[ip] = rec
}
var set *stringutil.Set
set, ok = v.(*stringutil.Set)
if !ok {
log.Debug("%s: adding pairs: unexpected value type %T", hostsContainerPref, v)
return
}
processed := strings.Join(append([]string{ip.String()}, hosts...), " ")
for _, h := range hosts {
if set.Has(h) {
for _, host := range hosts {
if rec.Canonical == host || rec.Aliases.Has(host) {
continue
}
set.Add(h)
rec.Aliases.Add(host)
rule, rulePtr := hp.writeRules(h, ip)
hp.translations[rule], hp.translations[rulePtr] = processed, processed
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, h)
hp.addRules(ip, host, line)
}
}
// addRules adds rules and rule translations for the line.
func (hp *hostsParser) addRules(ip netip.Addr, host, line string) {
rule, rulePtr := hp.writeRules(host, ip)
hp.translations[rule], hp.translations[rulePtr] = line, line
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, host)
}
// writeRules writes the actual rule for the qtype and the PTR for the host-ip
// pair into internal builders.
func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) {
arpa, err := netutil.IPToReversedAddr(ip)
func (hp *hostsParser) writeRules(host string, ip netip.Addr) (rule, rulePtr string) {
// TODO(a.garipov): Add a netip.Addr version to netutil.
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
if err != nil {
return "", ""
}
@@ -389,7 +414,7 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string)
var qtype string
// The validation of the IP address has been performed earlier so it is
// guaranteed to be either an IPv4 or an IPv6.
if ip.To4() != nil {
if ip.Is4() {
qtype = "A"
} else {
qtype = "AAAA"
@@ -416,37 +441,8 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string)
return rule, rulePtr
}
// equalSet returns true if the internal hosts table just parsed equals target.
func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) {
if target == nil {
// hp.table shouldn't appear nil since it's initialized on each refresh.
return target == hp.table
}
if hp.table.Len() != target.Len() {
return false
}
hp.table.Range(func(ip net.IP, b interface{}) (cont bool) {
// ok is set to true if the target doesn't contain ip or if the
// appropriate hosts set isn't equal to the checked one.
if a, hasIP := target.Get(ip); !hasIP {
ok = true
} else if hosts, aok := a.(*stringutil.Set); aok {
ok = !hosts.Equal(b.(*stringutil.Set))
}
// Continue only if maps has no discrepancies.
return !ok
})
// Return true if every value from the IP map has no discrepancies with the
// appropriate one from the target.
return !ok
}
// sendUpd tries to send the parsed data to the ch.
func (hp *hostsParser) sendUpd(ch chan *netutil.IPMap) {
func (hp *hostsParser) sendUpd(ch chan HostsRecords) {
log.Debug("%s: sending upd", hostsContainerPref)
upd := hp.table
@@ -484,14 +480,15 @@ func (hc *HostsContainer) refresh() (err error) {
return fmt.Errorf("refreshing : %w", err)
}
if hp.equalSet(hc.last) {
// hc.last is nil on the first refresh, so let that one through.
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) {
log.Debug("%s: no changes detected", hostsContainerPref)
return nil
}
defer hp.sendUpd(hc.updates)
hc.last = hp.table.ShallowClone()
hc.last = maps.Clone(hp.table)
var rulesStrg *filterlist.RuleStorage
if rulesStrg, err = hp.newStrg(hc.listID); err != nil {

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build !(windows || linux)
// +build !windows,!linux
package aghnet

View File

@@ -3,6 +3,7 @@ package aghnet
import (
"io/fs"
"net"
"net/netip"
"path"
"strings"
"sync/atomic"
@@ -10,8 +11,10 @@ import (
"testing/fstest"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter"
@@ -134,7 +137,7 @@ func TestNewHostsContainer(t *testing.T) {
func TestHostsContainer_refresh(t *testing.T) {
// TODO(e.burkov): Test the case with no actual updates.
ip := net.IP{127, 0, 0, 1}
ip := netutil.IPv4Localhost()
ipStr := ip.String()
testFS := fstest.MapFS{"dir/file1": &fstest.MapFile{Data: []byte(ipStr + ` hostname` + nl)}}
@@ -159,31 +162,37 @@ func TestHostsContainer_refresh(t *testing.T) {
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
checkRefresh := func(t *testing.T, wantHosts *stringutil.Set) {
upd, ok := <-hc.Upd()
checkRefresh := func(t *testing.T, want *HostsRecord) {
t.Helper()
upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
require.True(t, ok)
require.NotNil(t, upd)
assert.Equal(t, 1, upd.Len())
assert.Len(t, upd, 1)
v, ok := upd.Get(ip)
rec, ok := upd[ip]
require.True(t, ok)
require.NotNil(t, rec)
var set *stringutil.Set
set, ok = v.(*stringutil.Set)
require.True(t, ok)
assert.True(t, set.Equal(wantHosts))
assert.Truef(t, rec.equal(want), "%+v != %+v", rec, want)
}
t.Run("initial_refresh", func(t *testing.T) {
checkRefresh(t, stringutil.NewSet("hostname"))
checkRefresh(t, &HostsRecord{
Aliases: stringutil.NewSet(),
Canonical: "hostname",
})
})
t.Run("second_refresh", func(t *testing.T) {
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
eventsCh <- event{}
checkRefresh(t, stringutil.NewSet("hostname", "alias"))
checkRefresh(t, &HostsRecord{
Aliases: stringutil.NewSet("alias"),
Canonical: "hostname",
})
})
t.Run("double_refresh", func(t *testing.T) {
@@ -363,10 +372,15 @@ func TestHostsContainer(t *testing.T) {
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
testCases := []struct {
want []*rules.DNSRewrite
name string
req *urlfilter.DNSRequest
name string
want []*rules.DNSRewrite
}{{
req: &urlfilter.DNSRequest{
Hostname: "simplehost",
DNSType: dns.TypeA,
},
name: "simple",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 1),
@@ -376,27 +390,12 @@ func TestHostsContainer(t *testing.T) {
Value: net.ParseIP("::1"),
RRType: dns.TypeAAAA,
}},
name: "simple",
req: &urlfilter.DNSRequest{
Hostname: "simplehost",
DNSType: dns.TypeA,
},
}, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 0),
RRType: dns.TypeA,
}, {
RCode: dns.RcodeSuccess,
Value: net.ParseIP("::"),
RRType: dns.TypeAAAA,
}},
name: "hello_alias",
req: &urlfilter.DNSRequest{
Hostname: "hello.world",
DNSType: dns.TypeA,
},
}, {
name: "hello_alias",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 0),
@@ -406,26 +405,41 @@ func TestHostsContainer(t *testing.T) {
Value: net.ParseIP("::"),
RRType: dns.TypeAAAA,
}},
name: "other_line_alias",
}, {
req: &urlfilter.DNSRequest{
Hostname: "hello.world.again",
DNSType: dns.TypeA,
},
name: "other_line_alias",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 0),
RRType: dns.TypeA,
}, {
RCode: dns.RcodeSuccess,
Value: net.ParseIP("::"),
RRType: dns.TypeAAAA,
}},
}, {
want: []*rules.DNSRewrite{},
name: "hello_subdomain",
req: &urlfilter.DNSRequest{
Hostname: "say.hello",
DNSType: dns.TypeA,
},
}, {
name: "hello_subdomain",
want: []*rules.DNSRewrite{},
name: "hello_alias_subdomain",
}, {
req: &urlfilter.DNSRequest{
Hostname: "say.hello.world",
DNSType: dns.TypeA,
},
name: "hello_alias_subdomain",
want: []*rules.DNSRewrite{},
}, {
req: &urlfilter.DNSRequest{
Hostname: "for.testing",
DNSType: dns.TypeA,
},
name: "lots_of_aliases",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeA,
@@ -435,37 +449,37 @@ func TestHostsContainer(t *testing.T) {
RRType: dns.TypeAAAA,
Value: net.ParseIP("::2"),
}},
name: "lots_of_aliases",
req: &urlfilter.DNSRequest{
Hostname: "for.testing",
DNSType: dns.TypeA,
},
}, {
req: &urlfilter.DNSRequest{
Hostname: "1.0.0.1.in-addr.arpa",
DNSType: dns.TypePTR,
},
name: "reverse",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypePTR,
Value: "simplehost.",
}},
name: "reverse",
req: &urlfilter.DNSRequest{
Hostname: "1.0.0.1.in-addr.arpa",
DNSType: dns.TypePTR,
},
}, {
want: []*rules.DNSRewrite{},
name: "non-existing",
req: &urlfilter.DNSRequest{
Hostname: "nonexisting",
Hostname: "nonexistent.example",
DNSType: dns.TypeA,
},
name: "non-existing",
want: []*rules.DNSRewrite{},
}, {
want: nil,
name: "bad_type",
req: &urlfilter.DNSRequest{
Hostname: "1.0.0.1.in-addr.arpa",
DNSType: dns.TypeSRV,
},
name: "bad_type",
want: nil,
}, {
req: &urlfilter.DNSRequest{
Hostname: "domain",
DNSType: dns.TypeA,
},
name: "issue_4216_4_6",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeA,
@@ -475,12 +489,12 @@ func TestHostsContainer(t *testing.T) {
RRType: dns.TypeAAAA,
Value: net.ParseIP("::42"),
}},
name: "issue_4216_4_6",
}, {
req: &urlfilter.DNSRequest{
Hostname: "domain",
Hostname: "domain4",
DNSType: dns.TypeA,
},
}, {
name: "issue_4216_4",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeA,
@@ -490,12 +504,12 @@ func TestHostsContainer(t *testing.T) {
RRType: dns.TypeA,
Value: net.IPv4(1, 3, 5, 7),
}},
name: "issue_4216_4",
req: &urlfilter.DNSRequest{
Hostname: "domain4",
DNSType: dns.TypeA,
},
}, {
req: &urlfilter.DNSRequest{
Hostname: "domain6",
DNSType: dns.TypeAAAA,
},
name: "issue_4216_6",
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypeAAAA,
@@ -505,11 +519,6 @@ func TestHostsContainer(t *testing.T) {
RRType: dns.TypeAAAA,
Value: net.ParseIP("::31"),
}},
name: "issue_4216_6",
req: &urlfilter.DNSRequest{
Hostname: "domain6",
DNSType: dns.TypeAAAA,
},
}}
stubWatcher := aghtest.FSWatcher{
@@ -551,13 +560,13 @@ func TestHostsContainer(t *testing.T) {
}
func TestUniqueRules_ParseLine(t *testing.T) {
ip := net.IP{127, 0, 0, 1}
ip := netutil.IPv4Localhost()
ipStr := ip.String()
testCases := []struct {
name string
line string
wantIP net.IP
wantIP netip.Addr
wantHosts []string
}{{
name: "simple",
@@ -572,7 +581,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
}, {
name: "invalid_line",
line: ipStr,
wantIP: nil,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "invalid_line_hostname",
@@ -587,7 +596,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
}, {
name: "whole_comment",
line: `# ` + ipStr + ` hostname`,
wantIP: nil,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "partial_comment",
@@ -597,7 +606,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
}, {
name: "empty",
line: ``,
wantIP: nil,
wantIP: netip.Addr{},
wantHosts: nil,
}}
@@ -605,7 +614,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
hp := hostsParser{}
t.Run(tc.name, func(t *testing.T) {
got, hosts := hp.parseLine(tc.line)
assert.True(t, tc.wantIP.Equal(got))
assert.Equal(t, tc.wantIP, got)
assert.Equal(t, tc.wantHosts, hosts)
})
}

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet
@@ -16,7 +15,7 @@ import (
func defaultHostsPaths() (paths []string) {
sysDir, err := windows.GetSystemDirectory()
if err != nil {
log.Error("getting system directory: %s", err)
log.Error("aghnet: getting system directory: %s", err)
return []string{}
}

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd netbsd openbsd solaris
//go:build darwin || freebsd || openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -18,7 +18,7 @@ type IpsetManager interface {
//
// The syntax of the ipsetConf is:
//
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
//
// If ipsetConf is empty, msg and err are nil. The error is of type
// *aghos.UnsupportedError if the OS is not supported.

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet
@@ -19,27 +18,18 @@ import (
// How to test on a real Linux machine:
//
// 1. Run:
// 1. Run "sudo ipset create example_set hash:ip family ipv4".
//
// sudo ipset create example_set hash:ip family ipv4
// 2. Run "sudo ipset list example_set". The Members field should be empty.
//
// 2. Run:
// 3. Add the line "example.com/example_set" to your AdGuardHome.yaml.
//
// sudo ipset list example_set
// 4. Start AdGuardHome.
//
// The Members field should be empty.
// 5. Make requests to example.com and its subdomains.
//
// 3. Add the line "example.com/example_set" to your AdGuardHome.yaml.
//
// 4. Start AdGuardHome.
//
// 5. Make requests to example.com and its subdomains.
//
// 6. Run:
//
// sudo ipset list example_set
//
// The Members field should contain the resolved IP addresses.
// 6. Run "sudo ipset list example_set". The Members field should contain the
// resolved IP addresses.
// newIpsetMgr returns a new Linux ipset manager.
func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) {

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build !linux
// +build !linux
package aghnet

View File

@@ -7,12 +7,12 @@ import (
"fmt"
"io"
"net"
"net/netip"
"syscall"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
)
// Variables and functions to substitute in tests.
@@ -47,26 +47,31 @@ func IfaceSetStaticIP(ifaceName string) (err error) {
//
// TODO(e.burkov): Investigate if the gateway address may be fetched in another
// way since not every machine has the software installed.
func GatewayIP(ifaceName string) (ip net.IP) {
func GatewayIP(ifaceName string) (ip netip.Addr) {
code, out, err := aghosRunCommand("ip", "route", "show", "dev", ifaceName)
if err != nil {
log.Debug("%s", err)
return nil
return netip.Addr{}
} else if code != 0 {
log.Debug("fetching gateway ip: unexpected exit code: %d", code)
return nil
return netip.Addr{}
}
fields := bytes.Fields(out)
// The meaningful "ip route" command output should contain the word
// "default" at first field and default gateway IP address at third field.
if len(fields) < 3 || string(fields[0]) != "default" {
return nil
return netip.Addr{}
}
return net.ParseIP(string(fields[2]))
ip, err = netip.ParseAddr(string(fields[2]))
if err != nil {
return netip.Addr{}
}
return ip
}
// CanBindPrivilegedPorts checks if current process can bind to privileged
@@ -78,9 +83,9 @@ func CanBindPrivilegedPorts() (can bool, err error) {
// NetInterface represents an entry of network interfaces map.
type NetInterface struct {
// Addresses are the network interface addresses.
Addresses []net.IP `json:"ip_addresses,omitempty"`
Addresses []netip.Addr `json:"ip_addresses,omitempty"`
// Subnets are the IP networks for this network interface.
Subnets []*net.IPNet `json:"-"`
Subnets []netip.Prefix `json:"-"`
Name string `json:"name"`
HardwareAddr net.HardwareAddr `json:"hardware_address"`
Flags net.Flags `json:"flags"`
@@ -101,63 +106,88 @@ func (iface NetInterface) MarshalJSON() ([]byte, error) {
})
}
func NetInterfaceFrom(iface *net.Interface) (niface *NetInterface, err error) {
niface = &NetInterface{
Name: iface.Name,
HardwareAddr: iface.HardwareAddr,
Flags: iface.Flags,
MTU: iface.MTU,
}
addrs, err := iface.Addrs()
if err != nil {
return nil, fmt.Errorf("failed to get addresses for interface %s: %w", iface.Name, err)
}
// Collect network interface addresses.
for _, addr := range addrs {
n, ok := addr.(*net.IPNet)
if !ok {
// Should be *net.IPNet, this is weird.
return nil, fmt.Errorf("expected %[2]s to be %[1]T, got %[2]T", n, addr)
} else if ip4 := n.IP.To4(); ip4 != nil {
n.IP = ip4
}
ip, ok := netip.AddrFromSlice(n.IP)
if !ok {
return nil, fmt.Errorf("bad address %s", n.IP)
}
ip = ip.Unmap()
if ip.IsLinkLocalUnicast() {
// Ignore link-local IPv4.
if ip.Is4() {
continue
}
ip = ip.WithZone(iface.Name)
}
ones, _ := n.Mask.Size()
p := netip.PrefixFrom(ip, ones)
niface.Addresses = append(niface.Addresses, ip)
niface.Subnets = append(niface.Subnets, p)
}
return niface, nil
}
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and
// WEB only we do not return link-local addresses here.
//
// TODO(e.burkov): Can't properly test the function since it's nontrivial to
// substitute net.Interface.Addrs and the net.InterfaceAddrs can't be used.
func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) {
func GetValidNetInterfacesForWeb() (nifaces []*NetInterface, err error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("couldn't get interfaces: %w", err)
return nil, fmt.Errorf("getting interfaces: %w", err)
} else if len(ifaces) == 0 {
return nil, errors.Error("couldn't find any legible interface")
return nil, errors.Error("no legible interfaces")
}
for _, iface := range ifaces {
var addrs []net.Addr
addrs, err = iface.Addrs()
for i := range ifaces {
var niface *NetInterface
niface, err = NetInterfaceFrom(&ifaces[i])
if err != nil {
return nil, fmt.Errorf("failed to get addresses for interface %s: %w", iface.Name, err)
}
netIface := &NetInterface{
MTU: iface.MTU,
Name: iface.Name,
HardwareAddr: iface.HardwareAddr,
Flags: iface.Flags,
}
// Collect network interface addresses.
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
// Should be net.IPNet, this is weird.
return nil, fmt.Errorf("got %s that is not net.IPNet, it is %T", addr, addr)
}
// Ignore link-local.
if ipNet.IP.IsLinkLocalUnicast() {
continue
}
netIface.Addresses = append(netIface.Addresses, ipNet.IP)
netIface.Subnets = append(netIface.Subnets, ipNet)
}
// Discard interfaces with no addresses.
if len(netIface.Addresses) != 0 {
netIfaces = append(netIfaces, netIface)
return nil, err
} else if len(niface.Addresses) != 0 {
// Discard interfaces with no addresses.
nifaces = append(nifaces, niface)
}
}
return netIfaces, nil
return nifaces, nil
}
// GetInterfaceByIP returns the name of interface containing provided ip.
// InterfaceByIP returns the name of the interface bound to ip.
//
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb.
func GetInterfaceByIP(ip net.IP) string {
// TODO(a.garipov, e.burkov): This function is technically incorrect, since one
// IP address can be shared by multiple interfaces in some configurations.
//
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func InterfaceByIP(ip netip.Addr) (ifaceName string) {
ifaces, err := GetValidNetInterfacesForWeb()
if err != nil {
return ""
@@ -165,7 +195,7 @@ func GetInterfaceByIP(ip net.IP) string {
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
if ip.Equal(addr) {
if ip == addr {
return iface.Name
}
}
@@ -174,15 +204,16 @@ func GetInterfaceByIP(ip net.IP) string {
return ""
}
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if
// GetSubnet returns the subnet corresponding to the interface of zero prefix if
// the search fails.
//
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb.
func GetSubnet(ifaceName string) *net.IPNet {
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func GetSubnet(ifaceName string) (p netip.Prefix) {
netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil {
log.Error("Could not get network interfaces info: %v", err)
return nil
return p
}
for _, netIface := range netIfaces {
@@ -191,14 +222,14 @@ func GetSubnet(ifaceName string) *net.IPNet {
}
}
return nil
return p
}
// CheckPort checks if the port is available for binding. network is expected
// to be one of "udp" and "tcp".
func CheckPort(network string, ip net.IP, port int) (err error) {
func CheckPort(network string, ipp netip.AddrPort) (err error) {
var c io.Closer
addr := netutil.IPPort{IP: ip, Port: port}.String()
addr := ipp.String()
switch network {
case "tcp":
c, err = net.Listen(network, addr)
@@ -248,18 +279,23 @@ func CollectAllIfacesAddrs() (addrs []string, err error) {
return addrs, nil
}
// BroadcastFromIPNet calculates the broadcast IP address for n.
func BroadcastFromIPNet(n *net.IPNet) (dc net.IP) {
dc = netutil.CloneIP(n.IP)
mask := n.Mask
if mask == nil {
mask = dc.DefaultMask()
// BroadcastFromPref calculates the broadcast IP address for p.
func BroadcastFromPref(p netip.Prefix) (bc netip.Addr) {
bc = p.Addr().Unmap()
if !bc.IsValid() {
return netip.Addr{}
}
for i, b := range mask {
dc[i] |= ^b
maskLen, addrLen := p.Bits(), bc.BitLen()
if maskLen == addrLen {
return bc
}
return dc
ipBytes := bc.AsSlice()
for i := maskLen; i < addrLen; i++ {
ipBytes[i/8] |= 1 << (7 - (i % 8))
}
bc, _ = netip.AddrFromSlice(ipBytes)
return bc
}

View File

@@ -1,5 +1,4 @@
//go:build darwin || freebsd || openbsd
// +build darwin freebsd openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build darwin
// +build darwin
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build freebsd
// +build freebsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build freebsd
// +build freebsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet
@@ -7,12 +6,13 @@ import (
"bufio"
"fmt"
"io"
"net"
"net/netip"
"os"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/renameio/maybe"
"golang.org/x/sys/unix"
@@ -22,17 +22,27 @@ import (
const dhcpcdConf = "etc/dhcpcd.conf"
func canBindPrivilegedPorts() (can bool, err error) {
cnbs, err := unix.PrctlRetInt(
res, err := unix.PrctlRetInt(
unix.PR_CAP_AMBIENT,
unix.PR_CAP_AMBIENT_RAISE,
unix.CAP_NET_BIND_SERVICE,
0,
0,
)
if err != nil {
if errors.Is(err, unix.EINVAL) {
// Older versions of Linux kernel do not support this. Print a
// warning and check admin rights.
log.Info("warning: cannot check capability cap_net_bind_service: %s", err)
} else {
return false, err
}
}
// Don't check the error because it's always nil on Linux.
adm, _ := aghos.HaveAdminRights()
return cnbs == 1 || adm, err
return res == 1 || adm, nil
}
// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
@@ -141,7 +151,7 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
// interface through dhcpcd.conf.
func ifaceSetStaticIP(ifaceName string) (err error) {
ipNet := GetSubnet(ifaceName)
if ipNet.IP == nil {
if !ipNet.Addr().IsValid() {
return errors.Error("can't get IP address")
}
@@ -164,7 +174,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
// dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that
// configure the interface to have a static IP.
func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf string) {
func dhcpcdConfIface(ifaceName string, subnet netip.Prefix, gateway netip.Addr) (conf string) {
b := &strings.Builder{}
stringutil.WriteToBuilder(
b,
@@ -173,15 +183,15 @@ func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf stri
" added by AdGuard Home.\ninterface ",
ifaceName,
"\nstatic ip_address=",
ipNet.String(),
subnet.String(),
"\n",
)
if gwIP != nil {
stringutil.WriteToBuilder(b, "static routers=", gwIP.String(), "\n")
if gateway != (netip.Addr{}) {
stringutil.WriteToBuilder(b, "static routers=", gateway.String(), "\n")
}
stringutil.WriteToBuilder(b, "static domain_name_servers=", ipNet.IP.String(), "\n\n")
stringutil.WriteToBuilder(b, "static domain_name_servers=", subnet.Addr().String(), "\n\n")
return b.String()
}

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build openbsd
// +build openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build openbsd
// +build openbsd
package aghnet

View File

@@ -6,11 +6,11 @@ import (
"fmt"
"io/fs"
"net"
"net/netip"
"os"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
@@ -19,7 +19,7 @@ import (
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
testutil.DiscardLogOutput(m)
}
// testdata is the filesystem containing data for testing the package.
@@ -93,34 +93,29 @@ func TestGatewayIP(t *testing.T) {
const cmd = "ip route show dev " + ifaceName
testCases := []struct {
name string
shell mapShell
want net.IP
want netip.Addr
name string
}{{
name: "success_v4",
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: net.IP{1, 2, 3, 4}.To16(),
want: netip.MustParseAddr("1.2.3.4"),
name: "success_v4",
}, {
name: "success_v6",
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: net.IP{
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0xFF, 0xFF,
},
want: netip.MustParseAddr("::ffff"),
name: "success_v6",
}, {
name: "bad_output",
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: nil,
want: netip.Addr{},
name: "bad_output",
}, {
name: "err_runcmd",
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: nil,
want: netip.Addr{},
name: "err_runcmd",
}, {
name: "bad_code",
shell: theOnlyCmd(cmd, 1, "", nil),
want: nil,
want: netip.Addr{},
name: "bad_code",
}}
for _, tc := range testCases {
@@ -132,7 +127,7 @@ func TestGatewayIP(t *testing.T) {
}
}
func TestGetInterfaceByIP(t *testing.T) {
func TestInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb()
require.NoError(t, err)
require.NotEmpty(t, ifaces)
@@ -142,7 +137,7 @@ func TestGetInterfaceByIP(t *testing.T) {
require.NotEmpty(t, iface.Addresses)
for _, ip := range iface.Addresses {
ifaceName := GetInterfaceByIP(ip)
ifaceName := InterfaceByIP(ip)
require.Equal(t, iface.Name, ifaceName)
}
})
@@ -150,65 +145,61 @@ func TestGetInterfaceByIP(t *testing.T) {
}
func TestBroadcastFromIPNet(t *testing.T) {
known6 := net.IP{
1, 2, 3, 4,
5, 6, 7, 8,
9, 10, 11, 12,
13, 14, 15, 16,
}
known4 := netip.MustParseAddr("192.168.0.1")
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
testCases := []struct {
name string
subnet *net.IPNet
want net.IP
pref netip.Prefix
want netip.Addr
name string
}{{
pref: netip.PrefixFrom(known4, 0),
want: fullBroadcast4,
name: "full",
subnet: &net.IPNet{
IP: net.IP{192, 168, 0, 1},
Mask: net.IPMask{255, 255, 15, 0},
},
want: net.IP{192, 168, 240, 255},
}, {
name: "ipv6_no_mask",
subnet: &net.IPNet{
IP: known6,
},
pref: netip.PrefixFrom(known4, 20),
want: netip.MustParseAddr("192.168.15.255"),
name: "full",
}, {
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
want: known6,
name: "ipv6_no_mask",
}, {
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
want: known4,
name: "ipv4_no_mask",
subnet: &net.IPNet{
IP: net.IP{192, 168, 1, 2},
},
want: net.IP{192, 168, 1, 255},
}, {
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
want: fullBroadcast4,
name: "unspecified",
subnet: &net.IPNet{
IP: net.IP{0, 0, 0, 0},
Mask: net.IPMask{0, 0, 0, 0},
},
want: net.IPv4bcast,
}, {
pref: netip.Prefix{},
want: netip.Addr{},
name: "invalid",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
bc := BroadcastFromIPNet(tc.subnet)
assert.True(t, bc.Equal(tc.want), bc)
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
})
}
}
func TestCheckPort(t *testing.T) {
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:")
l, err := net.Listen("tcp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
ipp := netutil.IPPortFromAddr(l.Addr())
require.NotNil(t, ipp)
require.NotNil(t, ipp.IP)
require.NotZero(t, ipp.Port)
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("tcp", ipp.IP, ipp.Port)
err = CheckPort("tcp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
@@ -216,16 +207,15 @@ func TestCheckPort(t *testing.T) {
})
t.Run("udp_bound", func(t *testing.T) {
conn, err := net.ListenPacket("udp", "127.0.0.1:")
conn, err := net.ListenPacket("udp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, conn.Close)
ipp := netutil.IPPortFromAddr(conn.LocalAddr())
require.NotNil(t, ipp)
require.NotNil(t, ipp.IP)
require.NotZero(t, ipp.Port)
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("udp", ipp.IP, ipp.Port)
err = CheckPort("udp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
@@ -233,12 +223,12 @@ func TestCheckPort(t *testing.T) {
})
t.Run("bad_network", func(t *testing.T) {
err := CheckPort("bad_network", nil, 0)
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
assert.NoError(t, err)
})
t.Run("can_bind", func(t *testing.T) {
err := CheckPort("udp", net.IP{0, 0, 0, 0}, 0)
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
assert.NoError(t, err)
})
}
@@ -322,18 +312,18 @@ func TestNetInterface_MarshalJSON(t *testing.T) {
`"mtu":1500` +
`}` + "\n"
ip4, ip6 := net.IP{1, 2, 3, 4}, net.IP{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
mask4, mask6 := net.CIDRMask(24, netutil.IPv4BitLen), net.CIDRMask(8, netutil.IPv6BitLen)
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
require.True(t, ok)
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
require.True(t, ok)
net4 := netip.PrefixFrom(ip4, 24)
net6 := netip.PrefixFrom(ip6, 8)
iface := &NetInterface{
Addresses: []net.IP{ip4, ip6},
Subnets: []*net.IPNet{{
IP: ip4.Mask(mask4),
Mask: mask4,
}, {
IP: ip6.Mask(mask6),
Mask: mask6,
}},
Addresses: []netip.Addr{ip4, ip6},
Subnets: []netip.Prefix{net4, net6},
Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast,

View File

@@ -1,5 +1,4 @@
//go:build openbsd || freebsd || linux || darwin
// +build openbsd freebsd linux darwin
//go:build darwin || freebsd || linux || openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -19,7 +19,7 @@ type SystemResolvers interface {
}
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If
// defined by refreshIvl. It disables auto-refreshing if refreshIvl is 0. If
// nil is passed for hostGenFunc, the default generator will be used.
func NewSystemResolvers(
hostGenFunc HostGenFunc,

View File

@@ -1,5 +1,4 @@
//go:build !windows
// +build !windows
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build !windows
// +build !windows
package aghnet
@@ -7,6 +6,7 @@ import (
"context"
"testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -18,9 +18,8 @@ func createTestSystemResolversImpl(
t.Helper()
sr := createTestSystemResolvers(t, hostGenFunc)
require.IsType(t, (*systemResolvers)(nil), sr)
return sr.(*systemResolvers)
return testutil.RequireTypeAssert[*systemResolvers](t, sr)
}
func TestSystemResolvers_Refresh(t *testing.T) {

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet
@@ -62,9 +61,8 @@ func writeExit(w io.WriteCloser) {
// scanAddrs scans the DNS addresses from nslookup's output. The expected
// output of nslookup looks like this:
//
// Default Server: 192-168-1-1.qualified.domain.ru
// Address: 192.168.1.1
//
// Default Server: 192-168-1-1.qualified.domain.ru
// Address: 192.168.1.1
func scanAddrs(s *bufio.Scanner) (addrs []string) {
for s.Scan() {
line := strings.TrimSpace(s.Text())

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -1,11 +1,11 @@
package aghos
package aghos_test
import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/testutil"
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
testutil.DiscardLogOutput(m)
}

View File

@@ -1,5 +1,4 @@
//go:build mips || mips64
// +build mips mips64
// This file is an adapted version of github.com/josharian/native.

View File

@@ -1,5 +1,4 @@
//go:build amd64 || 386 || arm || arm64 || mipsle || mips64le || ppc64le
// +build amd64 386 arm arm64 mipsle mips64le ppc64le
// This file is an adapted version of github.com/josharian/native.

View File

@@ -0,0 +1,57 @@
package aghos
import (
"io/fs"
"path"
"testing"
"testing/fstest"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// errFS is an fs.FS implementation, method Open of which always returns
// errFSOpen.
type errFS struct{}
// errFSOpen is returned from errFS.Open.
const errFSOpen errors.Error = "test open error"
// Open implements the fs.FS interface for *errFS. fsys is always nil and err
// is always errFSOpen.
func (efs *errFS) Open(name string) (fsys fs.File, err error) {
return nil, errFSOpen
}
func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}
t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err)
assert.True(t, ok)
})
t.Run("invalid_argument", func(t *testing.T) {
_, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errFSOpen)
assert.False(t, ok)
})
t.Run("ignore_dirs", func(t *testing.T) {
const dirName = "dir"
testFS := fstest.MapFS{
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
}
patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)
assert.Empty(t, patterns)
assert.True(t, ok)
})
}

View File

@@ -1,13 +1,13 @@
package aghos
package aghos_test
import (
"bufio"
"io"
"io/fs"
"path"
"testing"
"testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -16,7 +16,7 @@ import (
func TestFileWalker_Walk(t *testing.T) {
const attribute = `000`
makeFileWalker := func(_ string) (fw FileWalker) {
makeFileWalker := func(_ string) (fw aghos.FileWalker) {
return func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
@@ -113,7 +113,7 @@ func TestFileWalker_Walk(t *testing.T) {
f := fstest.MapFS{
filename: &fstest.MapFile{Data: []byte("[]")},
}
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
patterns = append(patterns, s.Text())
@@ -134,7 +134,7 @@ func TestFileWalker_Walk(t *testing.T) {
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
}
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
return nil, true, rerr
}).Walk(f, "*")
require.ErrorIs(t, err, rerr)
@@ -142,45 +142,3 @@ func TestFileWalker_Walk(t *testing.T) {
assert.False(t, ok)
})
}
type errFS struct {
fs.GlobFS
}
const errErrFSOpen errors.Error = "this error is always returned"
func (efs *errFS) Open(name string) (fs.File, error) {
return nil, errErrFSOpen
}
func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}
t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err)
assert.True(t, ok)
})
t.Run("invalid_argument", func(t *testing.T) {
_, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errErrFSOpen)
assert.False(t, ok)
})
t.Run("ignore_dirs", func(t *testing.T) {
const dirName = "dir"
testFS := fstest.MapFS{
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
}
patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)
assert.Empty(t, patterns)
assert.True(t, ok)
})
}

View File

@@ -17,6 +17,8 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/mathutil"
"golang.org/x/exp/slices"
)
// UnsupportedError is returned by functions and methods when a particular
@@ -60,9 +62,8 @@ const MaxCmdOutputSize = 64 * 1024
func RunCommand(command string, arguments ...string) (code int, output []byte, err error) {
cmd := exec.Command(command, arguments...)
out, err := cmd.Output()
if len(out) > MaxCmdOutputSize {
out = out[:MaxCmdOutputSize]
}
out = out[:mathutil.Min(len(out), MaxCmdOutputSize)]
if err != nil {
if eerr := new(exec.ExitError); errors.As(err, &eerr) {
@@ -121,13 +122,12 @@ func PIDByCommand(command string, except ...int) (pid int, err error) {
}
// parsePSOutput scans the output of ps searching the largest PID of the process
// associated with cmdName ignoring PIDs from ignore. A valid line from
// r should look like these:
//
// 123 ./example-cmd
// 1230 some/base/path/example-cmd
// 3210 example-cmd
// associated with cmdName ignoring PIDs from ignore. A valid line from r
// should look like these:
//
// 123 ./example-cmd
// 1230 some/base/path/example-cmd
// 3210 example-cmd
func parsePSOutput(r io.Reader, cmdName string, ignore []int) (largest, instNum int, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
@@ -137,14 +137,12 @@ func parsePSOutput(r io.Reader, cmdName string, ignore []int) (largest, instNum
}
cur, aerr := strconv.Atoi(fields[0])
if aerr != nil || cur < 0 || intIn(cur, ignore) {
if aerr != nil || cur < 0 || slices.Contains(ignore, cur) {
continue
}
instNum++
if cur > largest {
largest = cur
}
largest = mathutil.Max(largest, cur)
}
if err = s.Err(); err != nil {
return 0, 0, fmt.Errorf("scanning stdout: %w", err)
@@ -153,27 +151,21 @@ func parsePSOutput(r io.Reader, cmdName string, ignore []int) (largest, instNum
return largest, instNum, nil
}
// intIn returns true if nums contains n.
func intIn(n int, nums []int) (ok bool) {
for _, nn := range nums {
if n == nn {
return true
}
}
return false
}
// IsOpenWrt returns true if host OS is OpenWrt.
func IsOpenWrt() (ok bool) {
return isOpenWrt()
}
// RootDirFS returns the fs.FS rooted at the operating system's root.
// RootDirFS returns the [fs.FS] rooted at the operating system's root. On
// Windows it returns the fs.FS rooted at the volume of the system directory
// (usually, C:).
func RootDirFS() (fsys fs.FS) {
// Use empty string since os.DirFS implicitly prepends a slash to it. This
// behavior is undocumented but it currently works.
return os.DirFS("")
return rootDirFS()
}
// NotifyReconfigureSignal notifies c on receiving reconfigure signals.
func NotifyReconfigureSignal(c chan<- os.Signal) {
notifyReconfigureSignal(c)
}
// NotifyShutdownSignal notifies c on receiving shutdown signals.
@@ -181,6 +173,11 @@ func NotifyShutdownSignal(c chan<- os.Signal) {
notifyShutdownSignal(c)
}
// IsReconfigureSignal returns true if sig is a reconfigure signal.
func IsReconfigureSignal(sig os.Signal) (ok bool) {
return isReconfigureSignal(sig)
}
// IsShutdownSignal returns true if sig is a shutdown signal.
func IsShutdownSignal(sig os.Signal) (ok bool) {
return isShutdownSignal(sig)

View File

@@ -1,5 +1,4 @@
//go:build darwin || netbsd || openbsd
// +build darwin netbsd openbsd
//go:build darwin || openbsd
package aghos

View File

@@ -1,5 +1,4 @@
//go:build freebsd
// +build freebsd
package aghos

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghos

View File

@@ -1,19 +1,31 @@
//go:build darwin || freebsd || linux || openbsd
// +build darwin freebsd linux openbsd
package aghos
import (
"io/fs"
"os"
"os/signal"
"golang.org/x/sys/unix"
)
func rootDirFS() (fsys fs.FS) {
return os.DirFS("/")
}
func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGHUP)
}
func notifyShutdownSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
}
func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == unix.SIGHUP
}
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case

View File

@@ -1,16 +1,31 @@
//go:build windows
// +build windows
package aghos
import (
"io/fs"
"os"
"os/signal"
"path/filepath"
"syscall"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/sys/windows"
)
func rootDirFS() (fsys fs.FS) {
// TODO(a.garipov): Use a better way if golang/go#44279 is ever resolved.
sysDir, err := windows.GetSystemDirectory()
if err != nil {
log.Error("aghos: getting root filesystem: %s; using C:", err)
// Assume that C: is the safe default.
return os.DirFS("C:")
}
return os.DirFS(filepath.VolumeName(sysDir))
}
func setRlimit(val uint64) (err error) {
return Unsupported("setrlimit")
}
@@ -40,17 +55,23 @@ func isOpenWrt() (ok bool) {
return false
}
func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, windows.SIGHUP)
}
func notifyShutdownSignal(c chan<- os.Signal) {
// syscall.SIGTERM is processed automatically. See go doc os/signal,
// section Windows.
signal.Notify(c, os.Interrupt)
}
func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == windows.SIGHUP
}
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case
os.Interrupt,
syscall.SIGTERM:
case os.Interrupt, syscall.SIGTERM:
return true
default:
return false

View File

@@ -1,5 +1,4 @@
//go:build !(windows || plan9)
// +build !windows,!plan9
//go:build !windows
package aghos

View File

@@ -1,5 +1,4 @@
//go:build windows || plan9
// +build windows plan9
//go:build windows
package aghos

View File

@@ -1,5 +1,4 @@
//go:build darwin || freebsd || linux || netbsd || openbsd
// +build darwin freebsd linux netbsd openbsd
//go:build darwin || freebsd || linux || openbsd
package aghos

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghos

View File

@@ -3,21 +3,11 @@ package aghtest
import (
"io"
"os"
"testing"
"github.com/AdguardTeam/golibs/log"
)
// DiscardLogOutput runs tests with discarded logger output.
func DiscardLogOutput(m *testing.M) {
// TODO(e.burkov): Refactor code and tests to not use the global mutable
// logger.
log.SetOutput(io.Discard)
os.Exit(m.Run())
}
// ReplaceLogWriter moves logger output to w and uses Cleanup method of t to
// revert changes.
func ReplaceLogWriter(t testing.TB, w io.Writer) {

View File

@@ -1,20 +0,0 @@
package aghtest
import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
// Exchanger is a mock aghnet.Exchanger implementation for tests.
type Exchanger struct {
Ups upstream.Upstream
}
// Exchange implements aghnet.Exchanger interface for *Exchanger.
func (e *Exchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
if e.Ups == nil {
e.Ups = &TestErrUpstream{}
}
return e.Ups.Exchange(req)
}

View File

@@ -1,23 +0,0 @@
package aghtest
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
type FSWatcher struct {
OnEvents func() (e <-chan struct{})
OnAdd func(name string) (err error)
OnClose func() (err error)
}
// Events implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
return w.OnEvents()
}
// Add implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Close implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}

View File

@@ -0,0 +1,181 @@
package aghtest
import (
"context"
"io/fs"
"net"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
// Interface Mocks
//
// Keep entities in this file in alphabetic order.
// Standard Library
// Package fs
// type check
var _ fs.FS = &FS{}
// FS is a mock [fs.FS] implementation for tests.
type FS struct {
OnOpen func(name string) (fs.File, error)
}
// Open implements the [fs.FS] interface for *FS.
func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}
// type check
var _ fs.GlobFS = &GlobFS{}
// GlobFS is a mock [fs.GlobFS] implementation for tests.
type GlobFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnGlob func(pattern string) ([]string, error)
}
// Glob implements the [fs.GlobFS] interface for *GlobFS.
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}
// type check
var _ fs.StatFS = &StatFS{}
// StatFS is a mock [fs.StatFS] implementation for tests.
type StatFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnStat func(name string) (fs.FileInfo, error)
}
// Stat implements the [fs.StatFS] interface for *StatFS.
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}
// Package net
// type check
var _ net.Listener = (*Listener)(nil)
// Listener is a mock [net.Listener] implementation for tests.
type Listener struct {
OnAccept func() (conn net.Conn, err error)
OnAddr func() (addr net.Addr)
OnClose func() (err error)
}
// Accept implements the [net.Listener] interface for *Listener.
func (l *Listener) Accept() (conn net.Conn, err error) {
return l.OnAccept()
}
// Addr implements the [net.Listener] interface for *Listener.
func (l *Listener) Addr() (addr net.Addr) {
return l.OnAddr()
}
// Close implements the [net.Listener] interface for *Listener.
func (l *Listener) Close() (err error) {
return l.OnClose()
}
// Module adguard-home
// Package aghos
// type check
var _ aghos.FSWatcher = (*FSWatcher)(nil)
// FSWatcher is a mock [aghos.FSWatcher] implementation for tests.
type FSWatcher struct {
OnEvents func() (e <-chan struct{})
OnAdd func(name string) (err error)
OnClose func() (err error)
}
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
return w.OnEvents()
}
// Add implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}
// Package agh
// type check
var _ agh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil)
// ServiceWithConfig is a mock [agh.ServiceWithConfig] implementation for tests.
type ServiceWithConfig[ConfigType any] struct {
OnStart func() (err error)
OnShutdown func(ctx context.Context) (err error)
OnConfig func() (c ConfigType)
}
// Start implements the [agh.ServiceWithConfig] interface for
// *ServiceWithConfig.
func (s *ServiceWithConfig[_]) Start() (err error) {
return s.OnStart()
}
// Shutdown implements the [agh.ServiceWithConfig] interface for
// *ServiceWithConfig.
func (s *ServiceWithConfig[_]) Shutdown(ctx context.Context) (err error) {
return s.OnShutdown(ctx)
}
// Config implements the [agh.ServiceWithConfig] interface for
// *ServiceWithConfig.
func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) {
return s.OnConfig()
}
// Module dnsproxy
// Package upstream
// type check
var _ upstream.Upstream = (*UpstreamMock)(nil)
// UpstreamMock is a mock [upstream.Upstream] implementation for tests.
//
// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and
// rename it to just Upstream.
type UpstreamMock struct {
OnAddress func() (addr string)
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
OnClose func() (err error)
}
// Address implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Address() (addr string) {
return u.OnAddress()
}
// Exchange implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
return u.OnExchange(req)
}
// Close implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Close() (err error) {
return u.OnClose()
}

View File

@@ -0,0 +1,3 @@
package aghtest_test
// Put interface checks that cause import cycles here.

View File

@@ -1,46 +0,0 @@
package aghtest
import "io/fs"
// type check
var _ fs.FS = &FS{}
// FS is a mock fs.FS implementation to use in tests.
type FS struct {
OnOpen func(name string) (fs.File, error)
}
// Open implements the fs.FS interface for *FS.
func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}
// type check
var _ fs.StatFS = &StatFS{}
// StatFS is a mock fs.StatFS implementation to use in tests.
type StatFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnStat func(name string) (fs.FileInfo, error)
}
// Stat implements the fs.StatFS interface for *StatFS.
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}
// type check
var _ fs.GlobFS = &GlobFS{}
// GlobFS is a mock fs.GlobFS implementation to use in tests.
type GlobFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnGlob func(pattern string) ([]string, error)
}
// Glob implements the fs.GlobFS interface for *GlobFS.
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}

View File

@@ -5,13 +5,19 @@ import (
"encoding/hex"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns"
)
// Additional Upstream Testing Utilities
// Upstream is a mock implementation of upstream.Upstream.
//
// TODO(a.garipov): Replace with UpstreamMock and rename it to just Upstream.
type Upstream struct {
// CName is a map of hostname to canonical name.
CName map[string][]string
@@ -19,13 +25,11 @@ type Upstream struct {
IPv4 map[string][]net.IP
// IPv6 is a map of hostname to IPv6.
IPv6 map[string][]net.IP
// Reverse is a map of address to domain name.
Reverse map[string][]string
// Addr is the address for Address method.
Addr string
}
// Exchange implements the upstream.Upstream interface for *Upstream.
var _ upstream.Upstream = (*Upstream)(nil)
// Exchange implements the [upstream.Upstream] interface for *Upstream.
//
// TODO(a.garipov): Split further into handlers.
func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
@@ -59,10 +63,6 @@ func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
for _, ip := range u.IPv6[name] {
resp.Answer = append(resp.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
}
case dns.TypePTR:
for _, name := range u.Reverse[name] {
resp.Answer = append(resp.Answer, &dns.PTR{Hdr: hdr, Ptr: name})
}
}
if len(resp.Answer) == 0 {
resp.SetRcode(m, dns.RcodeNameError)
@@ -71,79 +71,157 @@ func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
return resp, nil
}
// Address implements upstream.Upstream interface for *Upstream.
// Address implements [upstream.Upstream] interface for *Upstream.
func (u *Upstream) Address() string {
return u.Addr
return "todo.upstream.example"
}
// TestBlockUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestBlockUpstream struct {
Hostname string
// lock protects reqNum.
lock sync.RWMutex
reqNum int
Block bool
// Close implements [upstream.Upstream] interface for *Upstream.
func (u *Upstream) Close() (err error) {
return nil
}
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
u.lock.Lock()
defer u.lock.Unlock()
u.reqNum++
hash := sha256.Sum256([]byte(u.Hostname))
hashToReturn := hex.EncodeToString(hash[:])
if !u.Block {
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
// MatchedResponse is a test helper that returns a response with answer if req
// has question type qt, and target targ. Otherwise, it returns nil.
//
// req must not be nil and req.Question must have a length of 1. Answer is
// interpreted in the following ways:
//
// - For A and AAAA queries, answer must be an IP address of the corresponding
// protocol version.
//
// - For PTR queries, answer should be a domain name in the response.
//
// If the answer does not correspond to the question type, MatchedResponse panics.
// Panics are used instead of [testing.TB], because the helper is intended to
// use in [UpstreamMock.OnExchange] callbacks, which are usually called in a
// separate goroutine.
//
// TODO(a.garipov): Consider adding version with DNS class as well.
func MatchedResponse(req *dns.Msg, qt uint16, targ, answer string) (resp *dns.Msg) {
if req == nil || len(req.Question) != 1 {
panic(fmt.Errorf("bad req: %+v", req))
}
m := &dns.Msg{}
m.SetReply(r)
m.Answer = []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: r.Question[0].Name,
},
Txt: []string{
hashToReturn,
},
q := req.Question[0]
targ = dns.Fqdn(targ)
if q.Qclass != dns.ClassINET || q.Qtype != qt || q.Name != targ {
return nil
}
respHdr := dns.RR_Header{
Name: targ,
Rrtype: qt,
Class: dns.ClassINET,
Ttl: 60,
}
resp = new(dns.Msg).SetReply(req)
switch qt {
case dns.TypeA:
resp.Answer = mustAnsA(respHdr, answer)
case dns.TypeAAAA:
resp.Answer = mustAnsAAAA(respHdr, answer)
case dns.TypePTR:
resp.Answer = []dns.RR{&dns.PTR{
Hdr: respHdr,
Ptr: answer,
}}
default:
panic(fmt.Errorf("aghtest: bad question type: %s", dns.Type(qt)))
}
return resp
}
// mustAnsA returns valid answer records if s is a valid IPv4 address.
// Otherwise, mustAnsA panics.
func mustAnsA(respHdr dns.RR_Header, s string) (ans []dns.RR) {
ip, err := netip.ParseAddr(s)
if err != nil || !ip.Is4() {
panic(fmt.Errorf("aghtest: bad A answer: %+v", s))
}
return []dns.RR{&dns.A{
Hdr: respHdr,
A: ip.AsSlice(),
}}
}
// mustAnsAAAA returns valid answer records if s is a valid IPv6 address.
// Otherwise, mustAnsAAAA panics.
func mustAnsAAAA(respHdr dns.RR_Header, s string) (ans []dns.RR) {
ip, err := netip.ParseAddr(s)
if err != nil || !ip.Is6() {
panic(fmt.Errorf("aghtest: bad AAAA answer: %+v", s))
}
return []dns.RR{&dns.AAAA{
Hdr: respHdr,
AAAA: ip.AsSlice(),
}}
}
// NewUpstreamMock returns an [*UpstreamMock], fields OnAddress and OnClose of
// which are set to stubs that return "upstream.example" and nil respectively.
// The field OnExchange is set to onExc.
func NewUpstreamMock(onExc func(req *dns.Msg) (resp *dns.Msg, err error)) (u *UpstreamMock) {
return &UpstreamMock{
OnAddress: func() (addr string) { return "upstream.example" },
OnExchange: onExc,
OnClose: func() (err error) { return nil },
}
}
// NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that
// supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is
// true, hostname's actual hash is returned, blocking it. Otherwise, it returns
// a different hash.
func NewBlockUpstream(hostname string, shouldBlock bool) (u *UpstreamMock) {
hash := sha256.Sum256([]byte(hostname))
hashStr := hex.EncodeToString(hash[:])
if !shouldBlock {
hashStr = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
}
ans := &dns.TXT{
Hdr: dns.RR_Header{
Name: "",
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 60,
},
Txt: []string{hashStr},
}
respTmpl := &dns.Msg{
Answer: []dns.RR{ans},
}
return m, nil
return &UpstreamMock{
OnAddress: func() (addr string) { return "sbpc.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = respTmpl.Copy()
resp.SetReply(req)
resp.Answer[0].(*dns.TXT).Hdr.Name = req.Question[0].Name
return resp, nil
},
OnClose: func() (err error) { return nil },
}
}
// Address always returns an empty string.
func (u *TestBlockUpstream) Address() string {
return ""
}
// ErrUpstream is the error returned from the [*UpstreamMock] created by
// [NewErrorUpstream].
const ErrUpstream errors.Error = "test upstream error"
// RequestsCount returns the number of handled requests. It's safe for
// concurrent use.
func (u *TestBlockUpstream) RequestsCount() int {
u.lock.Lock()
defer u.lock.Unlock()
return u.reqNum
}
// TestErrUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestErrUpstream struct {
// The error returned by Exchange may be unwrapped to the Err.
Err error
}
// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, fmt.Errorf("errupstream: %w", u.Err)
}
// Address always returns an empty string.
func (u *TestErrUpstream) Address() string {
return ""
// NewErrorUpstream returns an [*UpstreamMock] that returns [ErrUpstream] from
// its Exchange method.
func NewErrorUpstream() (u *UpstreamMock) {
return &UpstreamMock{
OnAddress: func() (addr string) { return "error.upstream.example" },
OnExchange: func(_ *dns.Msg) (resp *dns.Msg, err error) {
return nil, errors.Error("test upstream error")
},
OnClose: func() (err error) { return nil },
}
}

View File

@@ -1,7 +1,50 @@
// Package aghtls contains utilities for work with TLS.
package aghtls
import "crypto/tls"
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/netip"
"github.com/AdguardTeam/golibs/log"
)
// init makes sure that the cipher name map is filled.
//
// TODO(a.garipov): Propose a similar API to crypto/tls.
func init() {
suites := tls.CipherSuites()
cipherSuites = make(map[string]uint16, len(suites))
for _, s := range suites {
cipherSuites[s.Name] = s.ID
}
log.Debug("tls: known ciphers: %q", cipherSuites)
}
// cipherSuites are a name-to-ID mapping of cipher suites from crypto/tls. It
// is filled by init. It must not be modified.
var cipherSuites map[string]uint16
// ParseCiphers parses a slice of cipher suites from cipher names.
func ParseCiphers(cipherNames []string) (cipherIDs []uint16, err error) {
if cipherNames == nil {
return nil, nil
}
cipherIDs = make([]uint16, 0, len(cipherNames))
for _, name := range cipherNames {
id, ok := cipherSuites[name]
if !ok {
return nil, fmt.Errorf("unknown cipher %q", name)
}
cipherIDs = append(cipherIDs, id)
}
return cipherIDs, nil
}
// SaferCipherSuites returns a set of default cipher suites with vulnerable and
// weak cipher suites removed.
@@ -28,3 +71,19 @@ func SaferCipherSuites() (safe []uint16) {
return safe
}
// CertificateHasIP returns true if cert has at least a single IP address among
// its subjectAltNames.
func CertificateHasIP(cert *x509.Certificate) (ok bool) {
if len(cert.IPAddresses) > 0 {
return true
}
for _, name := range cert.DNSNames {
if _, err := netip.ParseAddr(name); err == nil {
return true
}
}
return false
}

View File

@@ -0,0 +1,56 @@
package aghtls_test
import (
"crypto/tls"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
func TestParseCiphers(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
want []uint16
in []string
}{{
name: "nil",
wantErrMsg: "",
want: nil,
in: nil,
}, {
name: "empty",
wantErrMsg: "",
want: []uint16{},
in: []string{},
}, {}, {
name: "one",
wantErrMsg: "",
want: []uint16{tls.TLS_AES_128_GCM_SHA256},
in: []string{"TLS_AES_128_GCM_SHA256"},
}, {
name: "several",
wantErrMsg: "",
want: []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384},
in: []string{"TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384"},
}, {
name: "bad",
wantErrMsg: `unknown cipher "bad_cipher"`,
want: nil,
in: []string{"bad_cipher"},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := aghtls.ParseCiphers(tc.in)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
}

14
internal/aghtls/root.go Normal file
View File

@@ -0,0 +1,14 @@
package aghtls
import (
"crypto/x509"
)
// SystemRootCAs tries to load root certificates from the operating system. It
// returns nil in case nothing is found so that Go' crypto/x509 can use its
// default algorithm to find system root CA list.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/1311.
func SystemRootCAs() (roots *x509.CertPool) {
return rootCAs()
}

View File

@@ -0,0 +1,56 @@
//go:build linux
package aghtls
import (
"crypto/x509"
"os"
"path/filepath"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
func rootCAs() (roots *x509.CertPool) {
// Directories with the system root certificates, which aren't supported by
// Go's crypto/x509.
dirs := []string{
// Entware.
"/opt/etc/ssl/certs",
}
roots = x509.NewCertPool()
for _, dir := range dirs {
dirEnts, err := os.ReadDir(dir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
continue
}
// TODO(a.garipov): Improve error handling here and in other places.
log.Error("aghtls: opening directory %q: %s", dir, err)
}
var rootsAdded bool
for _, de := range dirEnts {
var certData []byte
rootFile := filepath.Join(dir, de.Name())
certData, err = os.ReadFile(rootFile)
if err != nil {
log.Error("aghtls: reading root cert: %s", err)
} else {
if roots.AppendCertsFromPEM(certData) {
rootsAdded = true
} else {
log.Error("aghtls: could not add root from %q", rootFile)
}
}
}
if rootsAdded {
return roots
}
}
return nil
}

View File

@@ -0,0 +1,9 @@
//go:build !linux
package aghtls
import "crypto/x509"
func rootCAs() (roots *x509.CertPool) {
return nil
}

View File

@@ -1,5 +1,4 @@
//go:build freebsd || openbsd
// +build freebsd openbsd
package dhcpd
@@ -10,11 +9,10 @@ import (
// broadcast sends resp to the broadcast address specific for network interface.
func (c *dhcpConn) broadcast(respData []byte, peer *net.UDPAddr) (n int, err error) {
// Despite the fact that server4.NewIPv4UDPConn explicitly sets socket
// options to allow broadcasting, it also binds the connection to a
// specific interface. On FreeBSD and OpenBSD net.UDPConn.WriteTo
// causes errors while writing to the addresses that belong to another
// interface. So, use the broadcast address specific for the interface
// bound.
// options to allow broadcasting, it also binds the connection to a specific
// interface. On FreeBSD and OpenBSD net.UDPConn.WriteTo causes errors
// while writing to the addresses that belong to another interface. So, use
// the broadcast address specific for the interface bound.
peer.IP = c.bcastIP
return c.udpConn.WriteTo(respData, peer)

View File

@@ -1,5 +1,4 @@
//go:build freebsd || openbsd
// +build freebsd openbsd
package dhcpd

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || linux || netbsd || solaris
// +build aix darwin dragonfly linux netbsd solaris
//go:build darwin || linux
package dhcpd

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || linux || netbsd || solaris
// +build aix darwin dragonfly linux netbsd solaris
//go:build darwin || linux
package dhcpd

222
internal/dhcpd/config.go Normal file
View File

@@ -0,0 +1,222 @@
package dhcpd
import (
"fmt"
"net"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
)
// ServerConfig is the configuration for the DHCP server. The order of YAML
// fields is important, since the YAML configuration file follows it.
type ServerConfig struct {
// Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"`
// Register an HTTP handler
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
Enabled bool `yaml:"enabled"`
InterfaceName string `yaml:"interface_name"`
// LocalDomainName is the domain name used for DHCP hosts. For example,
// a DHCP client with the hostname "myhost" can be addressed as "myhost.lan"
// when LocalDomainName is "lan".
LocalDomainName string `yaml:"local_domain_name"`
Conf4 V4ServerConf `yaml:"dhcpv4"`
Conf6 V6ServerConf `yaml:"dhcpv6"`
WorkDir string `yaml:"-"`
DBFilePath string `yaml:"-"`
}
// DHCPServer - DHCP server interface
type DHCPServer interface {
// ResetLeases resets leases.
ResetLeases(leases []*Lease) (err error)
// GetLeases returns deep clones of the current leases.
GetLeases(flags GetLeasesFlags) (leases []*Lease)
// AddStaticLease - add a static lease
AddStaticLease(l *Lease) (err error)
// RemoveStaticLease - remove a static lease
RemoveStaticLease(l *Lease) (err error)
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
FindMACbyIP(ip net.IP) net.HardwareAddr
// WriteDiskConfig4 - copy disk configuration
WriteDiskConfig4(c *V4ServerConf)
// WriteDiskConfig6 - copy disk configuration
WriteDiskConfig6(c *V6ServerConf)
// Start - start server
Start() (err error)
// Stop - stop server
Stop() (err error)
getLeasesRef() []*Lease
}
// V4ServerConf - server configuration
type V4ServerConf struct {
Enabled bool `yaml:"-" json:"-"`
InterfaceName string `yaml:"-" json:"-"`
GatewayIP netip.Addr `yaml:"gateway_ip" json:"gateway_ip"`
SubnetMask netip.Addr `yaml:"subnet_mask" json:"subnet_mask"`
// broadcastIP is the broadcasting address pre-calculated from the
// configured gateway IP and subnet mask.
broadcastIP netip.Addr
// The first & the last IP address for dynamic leases
// Bytes [0..2] of the last allowed IP address must match the first IP
RangeStart netip.Addr `yaml:"range_start" json:"range_start"`
RangeEnd netip.Addr `yaml:"range_end" json:"range_end"`
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
// IP conflict detector: time (ms) to wait for ICMP reply
// 0: disable
ICMPTimeout uint32 `yaml:"icmp_timeout_msec" json:"-"`
// Custom Options.
//
// Option with arbitrary hexadecimal data:
// DEC_CODE hex HEX_DATA
// where DEC_CODE is a decimal DHCPv4 option code in range [1..255]
//
// Option with IP data (only 1 IP is supported):
// DEC_CODE ip IP_ADDR
Options []string `yaml:"options" json:"-"`
ipRange *ipRange
leaseTime time.Duration // the time during which a dynamic lease is considered valid
dnsIPAddrs []netip.Addr // IPv4 addresses to return to DHCP clients as DNS server addresses
// subnet contains the DHCP server's subnet. The IP is the IP of the
// gateway.
subnet netip.Prefix
// notify is a way to signal to other components that leases have been
// changed. notify must be called outside of locked sections, since the
// clients might want to get the new data.
//
// TODO(a.garipov): This is utter madness and must be refactored. It just
// begs for deadlock bugs and other nastiness.
notify func(uint32)
}
// errNilConfig is an error returned by validation method if the config is nil.
const errNilConfig errors.Error = "nil config"
// ensureV4 returns an unmapped version of ip. An error is returned if the
// passed ip is not an IPv4.
func ensureV4(ip netip.Addr, kind string) (ip4 netip.Addr, err error) {
ip4 = ip.Unmap()
if !ip4.IsValid() || !ip4.Is4() {
return netip.Addr{}, fmt.Errorf("%v is not an IPv4 %s", ip, kind)
}
return ip4, nil
}
// Validate returns an error if c is not a valid configuration.
//
// TODO(e.burkov): Don't set the config fields when the server itself will stop
// containing the config.
func (c *V4ServerConf) Validate() (err error) {
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
if c == nil {
return errNilConfig
}
gatewayIP, err := ensureV4(c.GatewayIP, "address")
if err != nil {
// Don't wrap the error since it's informative enough as is and there is
// an annotation deferred already.
return err
}
subnetMask, err := ensureV4(c.SubnetMask, "subnet mask")
if err != nil {
// Don't wrap the error since it's informative enough as is and there is
// an annotation deferred already.
return err
}
maskLen, _ := net.IPMask(subnetMask.AsSlice()).Size()
c.subnet = netip.PrefixFrom(gatewayIP, maskLen)
c.broadcastIP = aghnet.BroadcastFromPref(c.subnet)
rangeStart, err := ensureV4(c.RangeStart, "address")
if err != nil {
// Don't wrap the error since it's informative enough as is and there is
// an annotation deferred already.
return err
}
rangeEnd, err := ensureV4(c.RangeEnd, "address")
if err != nil {
// Don't wrap the error since it's informative enough as is and there is
// an annotation deferred already.
return err
}
c.ipRange, err = newIPRange(rangeStart.AsSlice(), rangeEnd.AsSlice())
if err != nil {
// Don't wrap the error since it's informative enough as is and there is
// an annotation deferred already.
return err
}
if c.ipRange.contains(gatewayIP.AsSlice()) {
return fmt.Errorf("gateway ip %v in the ip range: %v-%v",
gatewayIP,
c.RangeStart,
c.RangeEnd,
)
}
if !c.subnet.Contains(rangeStart) {
return fmt.Errorf("range start %v is outside network %v",
c.RangeStart,
c.subnet,
)
}
if !c.subnet.Contains(rangeEnd) {
return fmt.Errorf("range end %v is outside network %v",
c.RangeEnd,
c.subnet,
)
}
return nil
}
// V6ServerConf - server configuration
type V6ServerConf struct {
Enabled bool `yaml:"-" json:"-"`
InterfaceName string `yaml:"-" json:"-"`
// The first IP address for dynamic leases
// The last allowed IP address ends with 0xff byte
RangeStart net.IP `yaml:"range_start" json:"range_start"`
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
RASLAACOnly bool `yaml:"ra_slaac_only" json:"-"` // send ICMPv6.RA packets without MO flags
RAAllowSLAAC bool `yaml:"ra_allow_slaac" json:"-"` // send ICMPv6.RA packets with MO flags
ipStart net.IP // starting IP address for dynamic leases
leaseTime time.Duration // the time during which a dynamic lease is considered valid
dnsIPAddrs []net.IP // IPv6 addresses to return to DHCP clients as DNS server addresses
// Server calls this function when leases data changes
notify func(uint32)
}

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd
@@ -16,6 +15,8 @@ import (
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4"
"github.com/mdlayher/ethernet"
//lint:ignore SA1019 See the TODO in go.mod.
"github.com/mdlayher/raw"
)
@@ -49,16 +50,15 @@ type dhcpConn struct {
}
// newDHCPConn creates the special connection for DHCP server.
func (s *v4Server) newDHCPConn(ifi *net.Interface) (c net.PacketConn, err error) {
// Create the raw connection.
func (s *v4Server) newDHCPConn(iface *net.Interface) (c net.PacketConn, err error) {
var ucast net.PacketConn
if ucast, err = raw.ListenPacket(ifi, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
if ucast, err = raw.ListenPacket(iface, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
return nil, fmt.Errorf("creating raw udp connection: %w", err)
}
// Create the UDP connection.
var bcast net.PacketConn
bcast, err = server4.NewIPv4UDPConn(ifi.Name, &net.UDPAddr{
bcast, err = server4.NewIPv4UDPConn(iface.Name, &net.UDPAddr{
// TODO(e.burkov): Listening on zeroes makes the server handle
// requests from all the interfaces. Inspect the ways to
// specify the interface-specific listening addresses.
@@ -73,16 +73,16 @@ func (s *v4Server) newDHCPConn(ifi *net.Interface) (c net.PacketConn, err error)
return &dhcpConn{
udpConn: bcast,
bcastIP: s.conf.broadcastIP,
bcastIP: s.conf.broadcastIP.AsSlice(),
rawConn: ucast,
srcMAC: ifi.HardwareAddr,
srcIP: s.conf.dnsIPAddrs[0],
srcMAC: iface.HardwareAddr,
srcIP: s.conf.dnsIPAddrs[0].AsSlice(),
}, nil
}
// wrapErrs is a helper to wrap the errors from two independent underlying
// connections.
func (c *dhcpConn) wrapErrs(action string, udpConnErr, rawConnErr error) (err error) {
func (*dhcpConn) wrapErrs(action string, udpConnErr, rawConnErr error) (err error) {
switch {
case udpConnErr != nil && rawConnErr != nil:
return errors.List(fmt.Sprintf("%s both connections", action), udpConnErr, rawConnErr)
@@ -128,7 +128,7 @@ func (c *dhcpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
// connection.
return c.udpConn.WriteTo(p, addr)
default:
return 0, fmt.Errorf("peer is of unexpected type %T", addr)
return 0, fmt.Errorf("addr has an unexpected type %T", addr)
}
}
@@ -187,32 +187,20 @@ func (c *dhcpConn) SetWriteDeadline(t time.Time) error {
)
}
// ipv4DefaultTTL is the default Time to Live value as recommended by
// RFC-1700 (https://datatracker.ietf.org/doc/html/rfc1700) in seconds.
// ipv4DefaultTTL is the default Time to Live value in seconds as recommended by
// RFC-1700.
//
// See https://datatracker.ietf.org/doc/html/rfc1700.
const ipv4DefaultTTL = 64
// errInvalidPktDHCP is returned when the provided payload is not a valid DHCP
// packet.
const errInvalidPktDHCP errors.Error = "packet is not a valid dhcp packet"
// buildEtherPkt wraps the payload with IPv4, UDP and Ethernet frames. The
// payload is expected to be an encoded DHCP packet.
// buildEtherPkt wraps the payload with IPv4, UDP and Ethernet frames.
// Validation of the payload is a caller's responsibility.
func (c *dhcpConn) buildEtherPkt(payload []byte, peer *dhcpUnicastAddr) (pkt []byte, err error) {
dhcpLayer := gopacket.NewPacket(payload, layers.LayerTypeDHCPv4, gopacket.DecodeOptions{
NoCopy: true,
}).Layer(layers.LayerTypeDHCPv4)
// Check if the decoding succeeded and the resulting layer doesn't
// contain any errors. It should guarantee panic-safe converting of the
// layer into gopacket.SerializableLayer.
if dhcpLayer == nil || dhcpLayer.LayerType() != layers.LayerTypeDHCPv4 {
return nil, errInvalidPktDHCP
}
udpLayer := &layers.UDP{
SrcPort: dhcpv4.ServerPort,
DstPort: dhcpv4.ClientPort,
}
ipv4Layer := &layers.IPv4{
Version: uint8(layers.IPProtocolIPv4),
Flags: layers.IPv4DontFragment,
@@ -225,6 +213,7 @@ func (c *dhcpConn) buildEtherPkt(payload []byte, peer *dhcpUnicastAddr) (pkt []b
// Ignore the error since it's only returned for invalid network layer's
// type.
_ = udpLayer.SetNetworkLayerForChecksum(ipv4Layer)
ethLayer := &layers.Ethernet{
SrcMAC: c.srcMAC,
DstMAC: peer.HardwareAddr,
@@ -232,10 +221,19 @@ func (c *dhcpConn) buildEtherPkt(payload []byte, peer *dhcpUnicastAddr) (pkt []b
}
buf := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
setts := gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}, ethLayer, ipv4Layer, udpLayer, dhcpLayer.(gopacket.SerializableLayer))
}
err = gopacket.SerializeLayers(
buf,
setts,
ethLayer,
ipv4Layer,
udpLayer,
gopacket.Payload(payload),
)
if err != nil {
return nil, fmt.Errorf("serializing layers: %w", err)
}

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd
@@ -11,9 +10,11 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/mdlayher/raw"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
//lint:ignore SA1019 See the TODO in go.mod.
"github.com/mdlayher/raw"
)
func TestDHCPConn_WriteTo_common(t *testing.T) {
@@ -45,7 +46,7 @@ func TestDHCPConn_WriteTo_common(t *testing.T) {
n, err := conn.WriteTo(nil, &unexpectedAddrType{})
require.Error(t, err)
testutil.AssertErrorMsg(t, "peer is of unexpected type *dhcpd.unexpectedAddrType", err)
testutil.AssertErrorMsg(t, "addr has an unexpected type *dhcpd.unexpectedAddrType", err)
assert.Zero(t, n)
})
}
@@ -89,14 +90,13 @@ func TestBuildEtherPkt(t *testing.T) {
}
})
t.Run("non-serializable", func(t *testing.T) {
t.Run("bad_payload", func(t *testing.T) {
// Create an invalid DHCP packet.
invalidPayload := []byte{1, 2, 3, 4}
pkt, err := conn.buildEtherPkt(invalidPayload, nil)
require.Error(t, err)
pkt, err := conn.buildEtherPkt(invalidPayload, peer)
require.NoError(t, err)
assert.ErrorIs(t, err, errInvalidPktDHCP)
assert.Empty(t, pkt)
assert.NotEmpty(t, pkt)
})
t.Run("serializing_error", func(t *testing.T) {

View File

@@ -32,7 +32,7 @@ func normalizeIP(ip net.IP) net.IP {
}
// Load lease table from DB
func (s *Server) dbLoad() (err error) {
func (s *server) dbLoad() (err error) {
dynLeases := []*Lease{}
staticLeases := []*Lease{}
v6StaticLeases := []*Lease{}
@@ -132,7 +132,7 @@ func normalizeLeases(staticLeases, dynLeases []*Lease) []*Lease {
}
// Store lease table in DB
func (s *Server) dbStore() (err error) {
func (s *server) dbStore() (err error) {
// Use an empty slice here as opposed to nil so that it doesn't write
// "null" into the database file if leases are empty.
leases := []leaseJSON{}

View File

@@ -5,13 +5,12 @@ import (
"encoding/json"
"fmt"
"net"
"net/http"
"path/filepath"
"runtime"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
"golang.org/x/exp/slices"
)
const (
@@ -21,9 +20,19 @@ const (
// TODO(e.burkov): Remove it when static leases determining mechanism
// will be improved.
leaseExpireStatic = 1
// DefaultDHCPLeaseTTL is the default time-to-live for leases.
DefaultDHCPLeaseTTL = uint32(timeutil.Day / time.Second)
// DefaultDHCPTimeoutICMP is the default timeout for waiting ICMP responses.
DefaultDHCPTimeoutICMP = 1000
)
var webHandlersRegistered = false
// Currently used defaults for ifaceDNSAddrs.
const (
defaultMaxAttempts int = 10
defaultBackoff time.Duration = 500 * time.Millisecond
)
// Lease contains the necessary information about a DHCP lease
type Lease struct {
@@ -45,8 +54,8 @@ func (l *Lease) Clone() (clone *Lease) {
return &Lease{
Expiry: l.Expiry,
Hostname: l.Hostname,
HWAddr: netutil.CloneMAC(l.HWAddr),
IP: netutil.CloneIP(l.IP),
HWAddr: slices.Clone(l.HWAddr),
IP: slices.Clone(l.IP),
}
}
@@ -119,30 +128,6 @@ func (l *Lease) UnmarshalJSON(data []byte) (err error) {
return nil
}
// ServerConfig is the configuration for the DHCP server. The order of YAML
// fields is important, since the YAML configuration file follows it.
type ServerConfig struct {
// Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"`
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
Enabled bool `yaml:"enabled"`
InterfaceName string `yaml:"interface_name"`
// LocalDomainName is the domain name used for DHCP hosts. For example,
// a DHCP client with the hostname "myhost" can be addressed as "myhost.lan"
// when LocalDomainName is "lan".
LocalDomainName string `yaml:"local_domain_name"`
Conf4 V4ServerConf `yaml:"dhcpv4"`
Conf6 V6ServerConf `yaml:"dhcpv6"`
WorkDir string `yaml:"-"`
DBFilePath string `yaml:"-"`
}
// OnLeaseChangedT is a callback for lease changes.
type OnLeaseChangedT func(flags int)
@@ -156,8 +141,68 @@ const (
LeaseChangedDBStore
)
// Server - the current state of the DHCP server
type Server struct {
// GetLeasesFlags are the flags for GetLeases.
type GetLeasesFlags uint8
// GetLeasesFlags values
const (
LeasesDynamic GetLeasesFlags = 0b01
LeasesStatic GetLeasesFlags = 0b10
LeasesAll = LeasesDynamic | LeasesStatic
)
// Interface is the DHCP server that deals with both IP address families.
type Interface interface {
Start() (err error)
Stop() (err error)
Enabled() (ok bool)
Leases(flags GetLeasesFlags) (leases []*Lease)
SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT)
FindMACbyIP(ip net.IP) (mac net.HardwareAddr)
WriteDiskConfig(c *ServerConfig)
}
// MockInterface is a mock Interface implementation.
//
// TODO(e.burkov): Move to aghtest when the API stabilized.
type MockInterface struct {
OnStart func() (err error)
OnStop func() (err error)
OnEnabled func() (ok bool)
OnLeases func(flags GetLeasesFlags) (leases []*Lease)
OnSetOnLeaseChanged func(f OnLeaseChangedT)
OnFindMACbyIP func(ip net.IP) (mac net.HardwareAddr)
OnWriteDiskConfig func(c *ServerConfig)
}
var _ Interface = (*MockInterface)(nil)
// Start implements the Interface for *MockInterface.
func (s *MockInterface) Start() (err error) { return s.OnStart() }
// Stop implements the Interface for *MockInterface.
func (s *MockInterface) Stop() (err error) { return s.OnStop() }
// Enabled implements the Interface for *MockInterface.
func (s *MockInterface) Enabled() (ok bool) { return s.OnEnabled() }
// Leases implements the Interface for *MockInterface.
func (s *MockInterface) Leases(flags GetLeasesFlags) (ls []*Lease) { return s.OnLeases(flags) }
// SetOnLeaseChanged implements the Interface for *MockInterface.
func (s *MockInterface) SetOnLeaseChanged(f OnLeaseChangedT) { s.OnSetOnLeaseChanged(f) }
// FindMACbyIP implements the Interface for *MockInterface.
func (s *MockInterface) FindMACbyIP(ip net.IP) (mac net.HardwareAddr) { return s.OnFindMACbyIP(ip) }
// WriteDiskConfig implements the Interface for *MockInterface.
func (s *MockInterface) WriteDiskConfig(c *ServerConfig) { s.OnWriteDiskConfig(c) }
// server is the DHCP service that handles DHCPv4, DHCPv6, and HTTP API.
type server struct {
srv4 DHCPServer
srv6 DHCPServer
@@ -169,27 +214,13 @@ type Server struct {
onLeaseChanged []OnLeaseChangedT
}
// GetLeasesFlags are the flags for GetLeases.
type GetLeasesFlags uint8
// type check
var _ Interface = (*server)(nil)
// GetLeasesFlags values
const (
LeasesDynamic GetLeasesFlags = 0b0001
LeasesStatic GetLeasesFlags = 0b0010
LeasesAll = LeasesDynamic | LeasesStatic
)
// ServerInterface is an interface for servers.
type ServerInterface interface {
Enabled() (ok bool)
Leases(flags GetLeasesFlags) (leases []*Lease)
SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT)
}
// Create - create object
func Create(conf *ServerConfig) (s *Server, err error) {
s = &Server{
// Create initializes and returns the DHCP server handling both address
// families. It also registers the corresponding HTTP API endpoints.
func Create(conf *ServerConfig) (s *server, err error) {
s = &server{
conf: &ServerConfig{
ConfigModified: conf.ConfigModified,
@@ -204,35 +235,22 @@ func Create(conf *ServerConfig) (s *Server, err error) {
},
}
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
if runtime.GOOS == "windows" {
// Our DHCP server doesn't work on Windows yet, so
// signal that to the front with an HTTP 501.
//
// TODO(a.garipov): This needs refactoring. We
// shouldn't even try and initialize a DHCP server on
// Windows, but there are currently too many
// interconnected parts--such as HTTP handlers and
// frontend--to make that work properly.
s.registerNotImplementedHandlers()
} else {
s.registerHandlers()
}
webHandlersRegistered = true
}
// TODO(e.burkov): Don't register handlers, see TODO on
// [aghhttp.RegisterFunc].
s.registerHandlers()
v4conf := conf.Conf4
v4conf.Enabled = s.conf.Enabled
if len(v4conf.RangeStart) == 0 {
v4conf.Enabled = false
}
v4conf.InterfaceName = s.conf.InterfaceName
v4conf.notify = s.onNotify
s.srv4, err = v4Create(v4conf)
v4conf.Enabled = s.conf.Enabled && v4conf.RangeStart.IsValid()
s.srv4, err = v4Create(&v4conf)
if err != nil {
return nil, fmt.Errorf("creating dhcpv4 srv: %w", err)
if v4conf.Enabled {
return nil, fmt.Errorf("creating dhcpv4 srv: %w", err)
}
log.Debug("dhcpd: warning: creating dhcpv4 srv: %s", err)
}
v6conf := conf.Conf6
@@ -265,12 +283,12 @@ func Create(conf *ServerConfig) (s *Server, err error) {
}
// Enabled returns true when the server is enabled.
func (s *Server) Enabled() (ok bool) {
func (s *server) Enabled() (ok bool) {
return s.conf.Enabled
}
// resetLeases resets all leases in the lease database.
func (s *Server) resetLeases() (err error) {
func (s *server) resetLeases() (err error) {
err = s.srv4.ResetLeases(nil)
if err != nil {
return err
@@ -287,7 +305,7 @@ func (s *Server) resetLeases() (err error) {
}
// server calls this function after DB is updated
func (s *Server) onNotify(flags uint32) {
func (s *server) onNotify(flags uint32) {
if flags == LeaseChangedDBStore {
err := s.dbStore()
if err != nil {
@@ -301,31 +319,28 @@ func (s *Server) onNotify(flags uint32) {
}
// SetOnLeaseChanged - set callback
func (s *Server) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) {
func (s *server) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) {
s.onLeaseChanged = append(s.onLeaseChanged, onLeaseChanged)
}
func (s *Server) notify(flags int) {
if len(s.onLeaseChanged) == 0 {
return
}
func (s *server) notify(flags int) {
for _, f := range s.onLeaseChanged {
f(flags)
}
}
// WriteDiskConfig - write configuration
func (s *Server) WriteDiskConfig(c *ServerConfig) {
func (s *server) WriteDiskConfig(c *ServerConfig) {
c.Enabled = s.conf.Enabled
c.InterfaceName = s.conf.InterfaceName
c.LocalDomainName = s.conf.LocalDomainName
s.srv4.WriteDiskConfig4(&c.Conf4)
s.srv6.WriteDiskConfig6(&c.Conf6)
}
// Start will listen on port 67 and serve DHCP requests.
func (s *Server) Start() (err error) {
func (s *server) Start() (err error) {
err = s.srv4.Start()
if err != nil {
return err
@@ -340,7 +355,7 @@ func (s *Server) Start() (err error) {
}
// Stop closes the listening UDP socket
func (s *Server) Stop() (err error) {
func (s *server) Stop() (err error) {
err = s.srv4.Stop()
if err != nil {
return err
@@ -356,12 +371,12 @@ func (s *Server) Stop() (err error) {
// Leases returns the list of active IPv4 and IPv6 DHCP leases. It's safe for
// concurrent use.
func (s *Server) Leases(flags GetLeasesFlags) (leases []*Lease) {
func (s *server) Leases(flags GetLeasesFlags) (leases []*Lease) {
return append(s.srv4.GetLeases(flags), s.srv6.GetLeases(flags)...)
}
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
func (s *Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
func (s *server) FindMACbyIP(ip net.IP) net.HardwareAddr {
if ip.To4() != nil {
return s.srv4.FindMACbyIP(ip)
}
@@ -369,6 +384,6 @@ func (s *Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
}
// AddStaticLease - add static v4 lease
func (s *Server) AddStaticLease(l *Lease) error {
func (s *server) AddStaticLease(l *Lease) error {
return s.srv4.AddStaticLease(l)
}

View File

@@ -1,23 +1,22 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd
import (
"net"
"net/netip"
"os"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
testutil.DiscardLogOutput(m)
}
func testNotify(flags uint32) {
@@ -26,18 +25,18 @@ func testNotify(flags uint32) {
// Leases database store/load.
func TestDB(t *testing.T) {
var err error
s := Server{
s := server{
conf: &ServerConfig{
DBFilePath: dbFilename,
},
}
s.srv4, err = v4Create(V4ServerConf{
s.srv4, err = v4Create(&V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
RangeStart: netip.MustParseAddr("192.168.10.100"),
RangeEnd: netip.MustParseAddr("192.168.10.200"),
GatewayIP: netip.MustParseAddr("192.168.10.1"),
SubnetMask: netip.MustParseAddr("255.255.255.0"),
notify: testNotify,
})
require.NoError(t, err)
@@ -88,32 +87,6 @@ func TestDB(t *testing.T) {
assert.Equal(t, leases[0].Expiry.Unix(), ll[1].Expiry.Unix())
}
func TestIsValidSubnetMask(t *testing.T) {
testCases := []struct {
mask net.IP
want bool
}{{
mask: net.IP{255, 255, 255, 0},
want: true,
}, {
mask: net.IP{255, 255, 254, 0},
want: true,
}, {
mask: net.IP{255, 255, 252, 0},
want: true,
}, {
mask: net.IP{255, 255, 253, 0},
}, {
mask: net.IP{255, 255, 255, 1},
}}
for _, tc := range testCases {
t.Run(tc.mask.String(), func(t *testing.T) {
assert.Equal(t, tc.want, isValidSubnetMask(tc.mask))
})
}
}
func TestNormalizeLeases(t *testing.T) {
dynLeases := []*Lease{{
HWAddr: net.HardwareAddr{1, 2, 3, 4},
@@ -140,41 +113,41 @@ func TestNormalizeLeases(t *testing.T) {
func TestV4Server_badRange(t *testing.T) {
testCases := []struct {
name string
gatewayIP netip.Addr
subnetMask netip.Addr
wantErrMsg string
gatewayIP net.IP
subnetMask net.IP
}{{
name: "gateway_in_range",
name: "gateway_in_range",
gatewayIP: netip.MustParseAddr("192.168.10.120"),
subnetMask: netip.MustParseAddr("255.255.255.0"),
wantErrMsg: "dhcpv4: gateway ip 192.168.10.120 in the ip range: " +
"192.168.10.20-192.168.10.200",
gatewayIP: net.IP{192, 168, 10, 120},
subnetMask: net.IP{255, 255, 255, 0},
}, {
name: "outside_range_start",
name: "outside_range_start",
gatewayIP: netip.MustParseAddr("192.168.10.1"),
subnetMask: netip.MustParseAddr("255.255.255.240"),
wantErrMsg: "dhcpv4: range start 192.168.10.20 is outside network " +
"192.168.10.1/28",
gatewayIP: net.IP{192, 168, 10, 1},
subnetMask: net.IP{255, 255, 255, 240},
}, {
name: "outside_range_end",
name: "outside_range_end",
gatewayIP: netip.MustParseAddr("192.168.10.1"),
subnetMask: netip.MustParseAddr("255.255.255.224"),
wantErrMsg: "dhcpv4: range end 192.168.10.200 is outside network " +
"192.168.10.1/27",
gatewayIP: net.IP{192, 168, 10, 1},
subnetMask: net.IP{255, 255, 255, 224},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 20},
RangeEnd: net.IP{192, 168, 10, 200},
RangeStart: netip.MustParseAddr("192.168.10.20"),
RangeEnd: netip.MustParseAddr("192.168.10.200"),
GatewayIP: tc.gatewayIP,
SubnetMask: tc.subnetMask,
notify: testNotify,
}
_, err := v4Create(conf)
_, err := v4Create(&conf)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
@@ -183,7 +156,7 @@ func TestV4Server_badRange(t *testing.T) {
// cloneUDPAddr returns a deep copy of a.
func cloneUDPAddr(a *net.UDPAddr) (clone *net.UDPAddr) {
return &net.UDPAddr{
IP: netutil.CloneIP(a.IP),
IP: slices.Clone(a.IP),
Port: a.Port,
Zone: a.Zone,
}

View File

@@ -1,36 +0,0 @@
package dhcpd
import (
"encoding/binary"
"fmt"
"net"
)
func tryTo4(ip net.IP) (ip4 net.IP, err error) {
if ip == nil {
return nil, fmt.Errorf("%v is not an IP address", ip)
}
ip4 = ip.To4()
if ip4 == nil {
return nil, fmt.Errorf("%v is not an IPv4 address", ip)
}
return ip4, nil
}
// Return TRUE if subnet mask is correct (e.g. 255.255.255.0)
func isValidSubnetMask(mask net.IP) bool {
var n uint32
n = binary.BigEndian.Uint32(mask)
for i := 0; i != 32; i++ {
if n == 0 {
break
}
if (n & 0x80000000) == 0 {
return false
}
n <<= 1
}
return true
}

View File

@@ -1,37 +1,36 @@
//go:build darwin || freebsd || linux || openbsd
package dhcpd
import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
)
type v4ServerConfJSON struct {
GatewayIP net.IP `json:"gateway_ip"`
SubnetMask net.IP `json:"subnet_mask"`
RangeStart net.IP `json:"range_start"`
RangeEnd net.IP `json:"range_end"`
LeaseDuration uint32 `json:"lease_duration"`
GatewayIP netip.Addr `json:"gateway_ip"`
SubnetMask netip.Addr `json:"subnet_mask"`
RangeStart netip.Addr `json:"range_start"`
RangeEnd netip.Addr `json:"range_end"`
LeaseDuration uint32 `json:"lease_duration"`
}
func v4JSONToServerConf(j *v4ServerConfJSON) V4ServerConf {
func (j *v4ServerConfJSON) toServerConf() *V4ServerConf {
if j == nil {
return V4ServerConf{}
return &V4ServerConf{}
}
return V4ServerConf{
return &V4ServerConf{
GatewayIP: j.GatewayIP,
SubnetMask: j.SubnetMask,
RangeStart: j.RangeStart,
@@ -41,8 +40,8 @@ func v4JSONToServerConf(j *v4ServerConfJSON) V4ServerConf {
}
type v6ServerConfJSON struct {
RangeStart net.IP `json:"range_start"`
LeaseDuration uint32 `json:"lease_duration"`
RangeStart netip.Addr `json:"range_start"`
LeaseDuration uint32 `json:"lease_duration"`
}
func v6JSONToServerConf(j *v6ServerConfJSON) V6ServerConf {
@@ -51,7 +50,7 @@ func v6JSONToServerConf(j *v6ServerConfJSON) V6ServerConf {
}
return V6ServerConf{
RangeStart: j.RangeStart,
RangeStart: j.RangeStart.AsSlice(),
LeaseDuration: j.LeaseDuration,
}
}
@@ -66,7 +65,7 @@ type dhcpStatusResponse struct {
Enabled bool `json:"enabled"`
}
func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
status := &dhcpStatusResponse{
Enabled: s.conf.Enabled,
IfaceName: s.conf.InterfaceName,
@@ -80,41 +79,29 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
status.Leases = s.Leases(LeasesDynamic)
status.StaticLeases = s.Leases(LeasesStatic)
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(status)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to marshal DHCP status json: %s",
err,
)
}
_ = aghhttp.WriteJSONResponse(w, r, status)
}
func (s *Server) enableDHCP(ifaceName string) (code int, err error) {
func (s *server) enableDHCP(ifaceName string) (code int, err error) {
var hasStaticIP bool
hasStaticIP, err = aghnet.IfaceHasStaticIP(ifaceName)
if err != nil {
if errors.Is(err, os.ErrPermission) {
// ErrPermission may happen here on Linux systems where
// AdGuard Home is installed using Snap. That doesn't
// necessarily mean that the machine doesn't have
// a static IP, so we can assume that it has and go on.
// If the machine doesn't, we'll get an error later.
// ErrPermission may happen here on Linux systems where AdGuard Home
// is installed using Snap. That doesn't necessarily mean that the
// machine doesn't have a static IP, so we can assume that it has
// and go on. If the machine doesn't, we'll get an error later.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2667.
//
// TODO(a.garipov): I was thinking about moving this
// into IfaceHasStaticIP, but then we wouldn't be able
// to log it. Think about it more.
// TODO(a.garipov): I was thinking about moving this into
// IfaceHasStaticIP, but then we wouldn't be able to log it. Think
// about it more.
log.Info("error while checking static ip: %s; "+
"assuming machine has static ip and going on", err)
hasStaticIP = true
} else if errors.Is(err, aghnet.ErrNoStaticIPInfo) {
// Couldn't obtain a definitive answer. Assume static
// IP an go on.
// Couldn't obtain a definitive answer. Assume static IP an go on.
log.Info("can't check for static ip; " +
"assuming machine has static ip and going on")
hasStaticIP = true
@@ -149,34 +136,39 @@ type dhcpServerConfigJSON struct {
Enabled aghalg.NullBool `json:"enabled"`
}
func (s *Server) handleDHCPSetConfigV4(
func (s *server) handleDHCPSetConfigV4(
conf *dhcpServerConfigJSON,
) (srv4 DHCPServer, enabled bool, err error) {
) (srv DHCPServer, enabled bool, err error) {
if conf.V4 == nil {
return nil, false, nil
}
v4Conf := v4JSONToServerConf(conf.V4)
v4Conf := conf.V4.toServerConf()
v4Conf.Enabled = conf.Enabled == aghalg.NBTrue
if len(v4Conf.RangeStart) == 0 {
if !v4Conf.RangeStart.IsValid() {
v4Conf.Enabled = false
}
enabled = v4Conf.Enabled
v4Conf.InterfaceName = conf.InterfaceName
c4 := V4ServerConf{}
s.srv4.WriteDiskConfig4(&c4)
// Set the default values for the fields not configurable via web API.
c4 := &V4ServerConf{
notify: s.onNotify,
ICMPTimeout: s.conf.Conf4.ICMPTimeout,
Options: s.conf.Conf4.Options,
}
s.srv4.WriteDiskConfig4(c4)
v4Conf.notify = c4.notify
v4Conf.ICMPTimeout = c4.ICMPTimeout
v4Conf.Options = c4.Options
srv4, err = v4Create(v4Conf)
srv4, err := v4Create(v4Conf)
return srv4, enabled, err
return srv4, srv4.enabled(), err
}
func (s *Server) handleDHCPSetConfigV6(
func (s *server) handleDHCPSetConfigV6(
conf *dhcpServerConfigJSON,
) (srv6 DHCPServer, enabled bool, err error) {
if conf.V6 == nil {
@@ -205,7 +197,7 @@ func (s *Server) handleDHCPSetConfigV6(
return srv6, enabled, err
}
func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
conf := &dhcpServerConfigJSON{}
conf.Enabled = aghalg.BoolToNullBool(s.conf.Enabled)
conf.InterfaceName = s.conf.InterfaceName
@@ -244,22 +236,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
return
}
if conf.Enabled != aghalg.NBNull {
s.conf.Enabled = conf.Enabled == aghalg.NBTrue
}
if conf.InterfaceName != "" {
s.conf.InterfaceName = conf.InterfaceName
}
if srv4 != nil {
s.srv4 = srv4
}
if srv6 != nil {
s.srv6 = srv6
}
s.setConfFromJSON(conf, srv4, srv6)
s.conf.ConfigModified()
err = s.dbLoad()
@@ -278,16 +255,36 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
}
}
type netInterfaceJSON struct {
Name string `json:"name"`
HardwareAddr string `json:"hardware_address"`
Flags string `json:"flags"`
GatewayIP net.IP `json:"gateway_ip"`
Addrs4 []net.IP `json:"ipv4_addresses"`
Addrs6 []net.IP `json:"ipv6_addresses"`
// setConfFromJSON sets configuration parameters in s from the new configuration
// decoded from JSON.
func (s *server) setConfFromJSON(conf *dhcpServerConfigJSON, srv4, srv6 DHCPServer) {
if conf.Enabled != aghalg.NBNull {
s.conf.Enabled = conf.Enabled == aghalg.NBTrue
}
if conf.InterfaceName != "" {
s.conf.InterfaceName = conf.InterfaceName
}
if srv4 != nil {
s.srv4 = srv4
}
if srv6 != nil {
s.srv6 = srv6
}
}
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
type netInterfaceJSON struct {
Name string `json:"name"`
HardwareAddr string `json:"hardware_address"`
Flags string `json:"flags"`
GatewayIP netip.Addr `json:"gateway_ip"`
Addrs4 []netip.Addr `json:"ipv4_addresses"`
Addrs6 []netip.Addr `json:"ipv6_addresses"`
}
func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
response := map[string]netInterfaceJSON{}
ifaces, err := net.Interfaces()
@@ -345,13 +342,18 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
return
}
// ignore link-local
//
// TODO(e.burkov): Try to listen DHCP on LLA as well.
if ipnet.IP.IsLinkLocalUnicast() {
continue
}
if ipnet.IP.To4() != nil {
jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP)
if ip4 := ipnet.IP.To4(); ip4 != nil {
addr := netip.AddrFrom4(*(*[4]byte)(ip4))
jsonIface.Addrs4 = append(jsonIface.Addrs4, addr)
} else {
jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP)
addr := netip.AddrFrom16(*(*[16]byte)(ipnet.IP))
jsonIface.Addrs6 = append(jsonIface.Addrs6, addr)
}
}
if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 {
@@ -406,31 +408,37 @@ type dhcpSearchResult struct {
V6 dhcpSearchV6Result `json:"v6"`
}
// Perform the following tasks:
// . Search for another DHCP server running
// . Check if a static IP is configured for the network interface
// Respond with results
func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited.
body, err := io.ReadAll(r.Body)
// findActiveServerReq is the JSON structure for the request to find active DHCP
// servers.
type findActiveServerReq struct {
Interface string `json:"interface"`
}
// handleDHCPFindActiveServer performs the following tasks:
// 1. searches for another DHCP server in the network;
// 2. check if a static IP is configured for the network interface;
// 3. responds with the results.
func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
if aghhttp.WriteTextPlainDeprecated(w, r) {
return
}
req := &findActiveServerReq{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
msg := fmt.Sprintf("failed to read request body: %s", err)
log.Error(msg)
http.Error(w, msg, http.StatusBadRequest)
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
return
}
ifaceName := strings.TrimSpace(string(body))
ifaceName := req.Interface
if ifaceName == "" {
msg := "empty interface name specified"
log.Error(msg)
http.Error(w, msg, http.StatusBadRequest)
aghhttp.Error(r, w, http.StatusBadRequest, "empty interface name")
return
}
result := dhcpSearchResult{
result := &dhcpSearchResult{
V4: dhcpSearchV4Result{
OtherServer: dhcpSearchOtherResult{
Found: "no",
@@ -455,6 +463,14 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
result.V4.StaticIP.IP = aghnet.GetSubnet(ifaceName).String()
}
setOtherDHCPResult(ifaceName, result)
_ = aghhttp.WriteJSONResponse(w, r, result)
}
// setOtherDHCPResult sets the results of the check for another DHCP server in
// result.
func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) {
found4, found6, err4, err6 := aghnet.CheckOtherDHCP(ifaceName)
if err4 != nil {
result.V4.OtherServer.Found = "error"
@@ -462,27 +478,16 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
} else if found4 {
result.V4.OtherServer.Found = "yes"
}
if err6 != nil {
result.V6.OtherServer.Found = "error"
result.V6.OtherServer.Error = err6.Error()
} else if found6 {
result.V6.OtherServer.Found = "yes"
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(result)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to marshal DHCP found json: %s",
err,
)
}
}
func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
l := &Lease{}
err := json.NewDecoder(r.Body).Decode(l)
if err != nil {
@@ -497,21 +502,16 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
return
}
ip4 := l.IP.To4()
if ip4 == nil {
var srv DHCPServer
if ip4 := l.IP.To4(); ip4 != nil {
l.IP = ip4
srv = s.srv4
} else {
l.IP = l.IP.To16()
err = s.srv6.AddStaticLease(l)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
}
return
srv = s.srv6
}
l.IP = ip4
err = s.srv4.AddStaticLease(l)
err = srv.AddStaticLease(l)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -519,7 +519,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
}
}
func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
l := &Lease{}
err := json.NewDecoder(r.Body).Decode(l)
if err != nil {
@@ -556,14 +556,7 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
}
}
const (
// DefaultDHCPLeaseTTL is the default time-to-live for leases.
DefaultDHCPLeaseTTL = uint32(timeutil.Day / time.Second)
// DefaultDHCPTimeoutICMP is the default timeout for waiting ICMP responses.
DefaultDHCPTimeoutICMP = 1000
)
func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
err := s.Stop()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
@@ -587,7 +580,7 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
DBFilePath: s.conf.DBFilePath,
}
v4conf := V4ServerConf{
v4conf := &V4ServerConf{
LeaseDuration: DefaultDHCPLeaseTTL,
ICMPTimeout: DefaultDHCPTimeoutICMP,
notify: s.onNotify,
@@ -603,7 +596,7 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
s.conf.ConfigModified()
}
func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
func (s *server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
err := s.resetLeases()
if err != nil {
msg := "resetting leases: %s"
@@ -613,7 +606,11 @@ func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
}
}
func (s *Server) registerHandlers() {
func (s *server) registerHandlers() {
if s.conf.HTTPRegister == nil {
return
}
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.handleDHCPStatus)
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.handleDHCPInterfaces)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.handleDHCPSetConfig)
@@ -623,44 +620,3 @@ func (s *Server) registerHandlers() {
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.handleReset)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.handleResetLeases)
}
// jsonError is a generic JSON error response.
//
// TODO(a.garipov): Merge together with the implementations in .../home and
// other packages after refactoring the web handler registering.
type jsonError struct {
// Message is the error message, an opaque string.
Message string `json:"message"`
}
// notImplemented returns a handler that replies to any request with an HTTP 501
// Not Implemented status and a JSON error with the provided message msg.
//
// TODO(a.garipov): Either take the logger from the server after we've
// refactored logging or make this not a method of *Server.
func (s *Server) notImplemented(msg string) (f func(http.ResponseWriter, *http.Request)) {
return func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotImplemented)
err := json.NewEncoder(w).Encode(&jsonError{
Message: msg,
})
if err != nil {
log.Debug("writing 501 json response: %s", err)
}
}
}
func (s *Server) registerNotImplementedHandlers() {
h := s.notImplemented("dhcp is not supported on windows")
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", h)
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", h)
}

View File

@@ -0,0 +1,48 @@
//go:build windows
package dhcpd
import (
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
// jsonError is a generic JSON error response.
//
// TODO(a.garipov): Merge together with the implementations in .../home and
// other packages after refactoring the web handler registering.
type jsonError struct {
// Message is the error message, an opaque string.
Message string `json:"message"`
}
// notImplemented is a handler that replies to any request with an HTTP 501 Not
// Implemented status and a JSON error with the provided message msg.
//
// TODO(a.garipov): Either take the logger from the server after we've
// refactored logging or make this not a method of *Server.
func (s *server) notImplemented(w http.ResponseWriter, r *http.Request) {
_ = aghhttp.WriteJSONResponseCode(w, r, http.StatusNotImplemented, &jsonError{
Message: aghos.Unsupported("dhcp").Error(),
})
}
// registerHandlers sets the handlers for DHCP HTTP API that always respond with
// an HTTP 501, since DHCP server doesn't work on Windows yet.
//
// TODO(a.garipov): This needs refactoring. We shouldn't even try and
// initialize a DHCP server on Windows, but there are currently too many
// interconnected parts--such as HTTP handlers and frontend--to make that work
// properly.
func (s *server) registerHandlers() {
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.notImplemented)
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.notImplemented)
}

View File

@@ -1,23 +1,28 @@
//go:build windows
package dhcpd
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServer_notImplemented(t *testing.T) {
s := &Server{}
h := s.notImplemented("never!")
s := &server{}
w := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, "/unsupported", nil)
require.NoError(t, err)
h(w, r)
s.notImplemented(w, r)
assert.Equal(t, http.StatusNotImplemented, w.Code)
assert.Equal(t, `{"message":"never!"}`+"\n", w.Body.String())
wantStr := fmt.Sprintf("{%q:%q}", "message", aghos.Unsupported("dhcp"))
assert.JSONEq(t, wantStr, w.Body.String())
}

View File

@@ -27,6 +27,8 @@ const maxRangeLen = math.MaxUint32
// newIPRange creates a new IP address range. start must be less than end. The
// resulting range must not be greater than maxRangeLen.
//
// TODO(e.burkov): Use netip.Addr.
func newIPRange(start, end net.IP) (r *ipRange, err error) {
defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd
@@ -9,26 +8,31 @@ import (
"net"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/insomniacslk/dhcp/dhcpv4"
)
// The aliases for DHCP option types available for explicit declaration.
//
// TODO(e.burkov): Add an option for classless routes.
const (
hexTyp = "hex"
ipTyp = "ip"
ipsTyp = "ips"
textTyp = "text"
typDel = "del"
typBool = "bool"
typDur = "dur"
typHex = "hex"
typIP = "ip"
typIPs = "ips"
typText = "text"
typU8 = "u8"
typU16 = "u16"
)
// parseDHCPOptionHex parses a DHCP option as a hex-encoded string. For
// example:
//
// 252 hex 736f636b733a2f2f70726f78792e6578616d706c652e6f7267
//
// parseDHCPOptionHex parses a DHCP option as a hex-encoded string.
func parseDHCPOptionHex(s string) (val dhcpv4.OptionValue, err error) {
var data []byte
data, err = hex.DecodeString(s)
@@ -39,15 +43,12 @@ func parseDHCPOptionHex(s string) (val dhcpv4.OptionValue, err error) {
return dhcpv4.OptionGeneric{Data: data}, nil
}
// parseDHCPOptionIP parses a DHCP option as a single IP address. For example:
//
// 6 ip 192.168.1.1
//
// parseDHCPOptionIP parses a DHCP option as a single IP address.
func parseDHCPOptionIP(s string) (val dhcpv4.OptionValue, err error) {
var ip net.IP
// All DHCPv4 options require IPv4, so don't put the 16-byte version.
// Otherwise, the clients will receive weird data that looks like four
// IPv4 addresses.
// Otherwise, the clients will receive weird data that looks like four IPv4
// addresses.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2688.
if ip, err = netutil.ParseIPv4(s); err != nil {
@@ -58,133 +59,343 @@ func parseDHCPOptionIP(s string) (val dhcpv4.OptionValue, err error) {
}
// parseDHCPOptionIPs parses a DHCP option as a comma-separates list of IP
// addresses. For example:
//
// 6 ips 192.168.1.1,192.168.1.2
//
// addresses.
func parseDHCPOptionIPs(s string) (val dhcpv4.OptionValue, err error) {
var ips dhcpv4.IPs
var ip net.IP
var ip dhcpv4.OptionValue
for i, ipStr := range strings.Split(s, ",") {
// See notes in the ipDHCPOptionParserHandler.
if ip, err = netutil.ParseIPv4(ipStr); err != nil {
ip, err = parseDHCPOptionIP(ipStr)
if err != nil {
return nil, fmt.Errorf("parsing ip at index %d: %w", i, err)
}
ips = append(ips, ip)
ips = append(ips, net.IP(ip.(dhcpv4.IP)))
}
return ips, nil
}
// parseDHCPOptionText parses a DHCP option as a simple UTF-8 encoded
// text. For example:
//
// 252 text http://192.168.1.1/wpad.dat
//
func parseDHCPOptionText(s string) (val dhcpv4.OptionValue) {
return dhcpv4.OptionGeneric{Data: []byte(s)}
// parseDHCPOptionDur parses a DHCP option as a duration in a human-readable
// form.
func parseDHCPOptionDur(s string) (val dhcpv4.OptionValue, err error) {
var v timeutil.Duration
err = v.UnmarshalText([]byte(s))
if err != nil {
return nil, fmt.Errorf("decoding dur: %w", err)
}
return dhcpv4.Duration(v.Duration), nil
}
// parseDHCPOption parses an option. See the documentation of parseDHCPOption*
// for more info.
func parseDHCPOption(s string) (opt dhcpv4.Option, err error) {
// parseDHCPOptionUint parses a DHCP option as an unsigned integer. bitSize is
// expected to be 8 or 16.
func parseDHCPOptionUint(s string, bitSize int) (val dhcpv4.OptionValue, err error) {
var v uint64
v, err = strconv.ParseUint(s, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("decoding u%d: %w", bitSize, err)
}
switch bitSize {
case 8:
return dhcpv4.OptionGeneric{Data: []byte{uint8(v)}}, nil
case 16:
return dhcpv4.Uint16(v), nil
default:
return nil, fmt.Errorf("unsupported size of integer %d", bitSize)
}
}
// parseDHCPOptionBool parses a DHCP option as a boolean value. See
// [strconv.ParseBool] for available values.
func parseDHCPOptionBool(s string) (val dhcpv4.OptionValue, err error) {
var v bool
v, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("decoding bool: %w", err)
}
rawVal := [1]byte{}
if v {
rawVal[0] = 1
}
return dhcpv4.OptionGeneric{Data: rawVal[:]}, nil
}
// parseDHCPOptionVal parses a DHCP option value considering typ.
func parseDHCPOptionVal(typ, valStr string) (val dhcpv4.OptionValue, err error) {
switch typ {
case typBool:
val, err = parseDHCPOptionBool(valStr)
case typDel:
val = dhcpv4.OptionGeneric{Data: nil}
case typDur:
val, err = parseDHCPOptionDur(valStr)
case typHex:
val, err = parseDHCPOptionHex(valStr)
case typIP:
val, err = parseDHCPOptionIP(valStr)
case typIPs:
val, err = parseDHCPOptionIPs(valStr)
case typText:
val = dhcpv4.String(valStr)
case typU8:
val, err = parseDHCPOptionUint(valStr, 8)
case typU16:
val, err = parseDHCPOptionUint(valStr, 16)
default:
err = fmt.Errorf("unknown option type %q", typ)
}
return val, err
}
// parseDHCPOption parses an option. For the del option value is ignored. The
// examples of possible option strings:
//
// - 1 bool true
// - 2 del
// - 3 dur 2h5s
// - 4 hex 736f636b733a2f2f70726f78792e6578616d706c652e6f7267
// - 5 ip 192.168.1.1
// - 6 ips 192.168.1.1,192.168.1.2
// - 7 text http://192.168.1.1/wpad.dat
// - 8 u8 255
// - 9 u16 65535
func parseDHCPOption(s string) (code dhcpv4.OptionCode, val dhcpv4.OptionValue, err error) {
defer func() { err = errors.Annotate(err, "invalid option string %q: %w", s) }()
s = strings.TrimSpace(s)
parts := strings.SplitN(s, " ", 3)
if len(parts) < 3 {
return opt, errors.Error("need at least three fields")
var valStr string
if pl := len(parts); pl < 3 {
if pl < 2 || parts[1] != typDel {
return nil, nil, errors.Error("bad option format")
}
} else {
valStr = parts[2]
}
var code64 uint64
code64, err = strconv.ParseUint(parts[0], 10, 8)
if err != nil {
return opt, fmt.Errorf("parsing option code: %w", err)
}
var optVal dhcpv4.OptionValue
switch typ, val := parts[1], parts[2]; typ {
case hexTyp:
optVal, err = parseDHCPOptionHex(val)
case ipTyp:
optVal, err = parseDHCPOptionIP(val)
case ipsTyp:
optVal, err = parseDHCPOptionIPs(val)
case textTyp:
optVal = parseDHCPOptionText(val)
default:
return opt, fmt.Errorf("unknown option type %q", typ)
return nil, nil, fmt.Errorf("parsing option code: %w", err)
}
val, err = parseDHCPOptionVal(parts[1], valStr)
if err != nil {
return opt, err
// Don't wrap an error since it's informative enough as is and there
// also the deferred annotation.
return nil, nil, err
}
return dhcpv4.Option{
Code: dhcpv4.GenericOptionCode(code64),
Value: optVal,
}, nil
return dhcpv4.GenericOptionCode(code64), val, nil
}
// prepareOptions builds the set of DHCP options according to host requirements
// document and values from conf.
func prepareOptions(conf V4ServerConf) (opts dhcpv4.Options) {
// Set default values for host configuration parameters listed in Appendix
// A of RFC-2131. Those parameters, if requested by client, should be
// returned with values defined by Host Requirements Document.
//
// See https://datatracker.ietf.org/doc/html/rfc2131#appendix-A.
//
// See also https://datatracker.ietf.org/doc/html/rfc1122,
// https://datatracker.ietf.org/doc/html/rfc1123, and
// https://datatracker.ietf.org/doc/html/rfc2132.
opts = dhcpv4.Options{
func (s *v4Server) prepareOptions() {
// Set default values of host configuration parameters listed in Appendix A
// of RFC-2131.
s.implicitOpts = dhcpv4.OptionsFromList(
// IP-Layer Per Host
dhcpv4.OptionNonLocalSourceRouting.Code(): []byte{0},
// Set the current recommended default time to live for the
// Internet Protocol which is 64, see
// https://datatracker.ietf.org/doc/html/rfc1700.
dhcpv4.OptionDefaultIPTTL.Code(): []byte{64},
// An Internet host that includes embedded gateway code MUST have a
// configuration switch to disable the gateway function, and this switch
// MUST default to the non-gateway mode.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.3.5.
dhcpv4.OptGeneric(dhcpv4.OptionIPForwarding, []byte{0x0}),
// A host that supports non-local source-routing MUST have a
// configurable switch to disable forwarding, and this switch MUST
// default to disabled.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.3.5.
dhcpv4.OptGeneric(dhcpv4.OptionNonLocalSourceRouting, []byte{0x0}),
// Do not set the Policy Filter Option since it only makes sense when
// the non-local source routing is enabled.
// The minimum legal value is 576.
//
// See https://datatracker.ietf.org/doc/html/rfc2132#section-4.4.
dhcpv4.Option{
Code: dhcpv4.OptionMaximumDatagramAssemblySize,
Value: dhcpv4.Uint16(576),
},
// Set the current recommended default time to live for the Internet
// Protocol which is 64.
//
// See https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml#ip-parameters-2.
dhcpv4.OptGeneric(dhcpv4.OptionDefaultIPTTL, []byte{0x40}),
// For example, after the PTMU estimate is decreased, the timeout should
// be set to 10 minutes; once this timer expires and a larger MTU is
// attempted, the timeout can be set to a much smaller value.
//
// See https://datatracker.ietf.org/doc/html/rfc1191#section-6.6.
dhcpv4.Option{
Code: dhcpv4.OptionPathMTUAgingTimeout,
Value: dhcpv4.Duration(10 * time.Minute),
},
// There is a table describing the MTU values representing all major
// data-link technologies in use in the Internet so that each set of
// similar MTUs is associated with a plateau value equal to the lowest
// MTU in the group.
//
// See https://datatracker.ietf.org/doc/html/rfc1191#section-7.
dhcpv4.OptGeneric(dhcpv4.OptionPathMTUPlateauTable, []byte{
0x0, 0x44,
0x1, 0x28,
0x1, 0xFC,
0x3, 0xEE,
0x5, 0xD4,
0x7, 0xD2,
0x11, 0x0,
0x1F, 0xE6,
0x45, 0xFA,
}),
// IP-Layer Per Interface
dhcpv4.OptionPerformMaskDiscovery.Code(): []byte{0},
dhcpv4.OptionMaskSupplier.Code(): []byte{0},
dhcpv4.OptionPerformRouterDiscovery.Code(): []byte{1},
// The all-routers address is preferred wherever possible, see
// https://datatracker.ietf.org/doc/html/rfc1256#section-5.1.
dhcpv4.OptionRouterSolicitationAddress.Code(): netutil.IPv4allrouter(),
dhcpv4.OptionBroadcastAddress.Code(): netutil.IPv4bcast(),
// Since nearly all networks in the Internet currently support an MTU of
// 576 or greater, we strongly recommend the use of 576 for datagrams
// sent to non-local networks.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.3.3.
dhcpv4.Option{
Code: dhcpv4.OptionInterfaceMTU,
Value: dhcpv4.Uint16(576),
},
// Set the All Subnets Are Local Option to false since commonly the
// connected hosts aren't expected to be multihomed.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.3.3.
dhcpv4.OptGeneric(dhcpv4.OptionAllSubnetsAreLocal, []byte{0x00}),
// Set the Perform Mask Discovery Option to false to provide the subnet
// mask by options only.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.2.9.
dhcpv4.OptGeneric(dhcpv4.OptionPerformMaskDiscovery, []byte{0x00}),
// A system MUST NOT send an Address Mask Reply unless it is an
// authoritative agent for address masks. An authoritative agent may be
// a host or a gateway, but it MUST be explicitly configured as a
// address mask agent.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.2.9.
dhcpv4.OptGeneric(dhcpv4.OptionMaskSupplier, []byte{0x00}),
// Set the Perform Router Discovery Option to true as per Router
// Discovery Document.
//
// See https://datatracker.ietf.org/doc/html/rfc1256#section-5.1.
dhcpv4.OptGeneric(dhcpv4.OptionPerformRouterDiscovery, []byte{0x01}),
// The all-routers address is preferred wherever possible.
//
// See https://datatracker.ietf.org/doc/html/rfc1256#section-5.1.
dhcpv4.Option{
Code: dhcpv4.OptionRouterSolicitationAddress,
Value: dhcpv4.IP(netutil.IPv4allrouter()),
},
// Don't set the Static Routes Option since it should be set up by
// system administrator.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.3.1.2.
// A datagram with the destination address of limited broadcast will be
// received by every host on the connected physical network but will not
// be forwarded outside that network.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.1.3.
dhcpv4.OptBroadcastAddress(netutil.IPv4bcast()),
// Link-Layer Per Interface
dhcpv4.OptionTrailerEncapsulation.Code(): []byte{0},
dhcpv4.OptionEthernetEncapsulation.Code(): []byte{0},
// If the system does not dynamically negotiate use of the trailer
// protocol on a per-destination basis, the default configuration MUST
// disable the protocol.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-2.3.1.
dhcpv4.OptGeneric(dhcpv4.OptionTrailerEncapsulation, []byte{0x00}),
// For proxy ARP situations, the timeout needs to be on the order of a
// minute.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-2.3.2.1.
dhcpv4.Option{
Code: dhcpv4.OptionArpCacheTimeout,
Value: dhcpv4.Duration(time.Minute),
},
// An Internet host that implements sending both the RFC-894 and the
// RFC-1042 encapsulations MUST provide a configuration switch to select
// which is sent, and this switch MUST default to RFC-894.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-2.3.3.
dhcpv4.OptGeneric(dhcpv4.OptionEthernetEncapsulation, []byte{0x00}),
// TCP Per Host
dhcpv4.OptionTCPKeepaliveInterval.Code(): dhcpv4.Duration(0).ToBytes(),
dhcpv4.OptionTCPKeepaliveGarbage.Code(): []byte{0},
// A fixed value must be at least big enough for the Internet diameter,
// i.e., the longest possible path. A reasonable value is about twice
// the diameter, to allow for continued Internet growth.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.1.7.
dhcpv4.Option{
Code: dhcpv4.OptionDefaulTCPTTL,
Value: dhcpv4.Duration(60 * time.Second),
},
// The interval MUST be configurable and MUST default to no less than
// two hours.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-4.2.3.6.
dhcpv4.Option{
Code: dhcpv4.OptionTCPKeepaliveInterval,
Value: dhcpv4.Duration(2 * time.Hour),
},
// Unfortunately, some misbehaved TCP implementations fail to respond to
// a probe segment unless it contains data.
//
// See https://datatracker.ietf.org/doc/html/rfc1122#section-4.2.3.6.
dhcpv4.OptGeneric(dhcpv4.OptionTCPKeepaliveGarbage, []byte{0x01}),
// Values From Configuration
dhcpv4.OptRouter(s.conf.GatewayIP.AsSlice()),
dhcpv4.OptionRouter.Code(): netutil.CloneIP(conf.subnet.IP),
dhcpv4.OptionSubnetMask.Code(): dhcpv4.IPMask(conf.subnet.Mask).ToBytes(),
}
dhcpv4.OptSubnetMask(s.conf.SubnetMask.AsSlice()),
)
// Set values for explicitly configured options.
for i, o := range conf.Options {
opt, err := parseDHCPOption(o)
s.explicitOpts = dhcpv4.Options{}
for i, o := range s.conf.Options {
code, val, err := parseDHCPOption(o)
if err != nil {
log.Error("dhcpv4: bad option string at index %d: %s", i, err)
continue
}
opts.Update(opt)
s.explicitOpts.Update(dhcpv4.Option{Code: code, Value: val})
// Remove those from the implicit options.
delete(s.implicitOpts, code.Code())
}
return opts
log.Debug("dhcpv4: implicit options:\n%s", s.implicitOpts.Summary(nil))
log.Debug("dhcpv4: explicit options:\n%s", s.explicitOpts.Summary(nil))
if len(s.explicitOpts) == 0 {
s.explicitOpts = nil
}
}

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd
@@ -7,171 +6,262 @@ import (
"fmt"
"net"
"testing"
"time"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseOpt(t *testing.T) {
testCases := []struct {
name string
in string
wantCode dhcpv4.OptionCode
wantVal dhcpv4.OptionValue
wantErrMsg string
wantOpt dhcpv4.Option
}{{
name: "hex_success",
in: "6 hex c0a80101c0a80102",
name: "hex_success",
in: "6 hex c0a80101c0a80102",
wantCode: dhcpv4.GenericOptionCode(6),
wantVal: dhcpv4.OptionGeneric{Data: []byte{
0xC0, 0xA8, 0x01, 0x01,
0xC0, 0xA8, 0x01, 0x02,
}},
wantErrMsg: "",
wantOpt: dhcpv4.OptDNS(
net.IP{0xC0, 0xA8, 0x01, 0x01},
net.IP{0xC0, 0xA8, 0x01, 0x02},
),
}, {
name: "ip_success",
in: "6 ip 1.2.3.4",
wantCode: dhcpv4.GenericOptionCode(6),
wantVal: dhcpv4.IP(net.IP{0x01, 0x02, 0x03, 0x04}),
wantErrMsg: "",
wantOpt: dhcpv4.OptDNS(
net.IP{0x01, 0x02, 0x03, 0x04},
),
}, {
name: "ip_fail_v6",
in: "6 ip ::1234",
wantErrMsg: "invalid option string \"6 ip ::1234\": bad ipv4 address \"::1234\"",
wantOpt: dhcpv4.Option{},
}, {
name: "ips_success",
in: "6 ips 192.168.1.1,192.168.1.2",
name: "ips_success",
in: "6 ips 192.168.1.1,192.168.1.2",
wantCode: dhcpv4.GenericOptionCode(6),
wantVal: dhcpv4.IPs([]net.IP{
{0xC0, 0xA8, 0x01, 0x01},
{0xC0, 0xA8, 0x01, 0x02},
}),
wantErrMsg: "",
wantOpt: dhcpv4.OptDNS(
net.IP{0xC0, 0xA8, 0x01, 0x01},
net.IP{0xC0, 0xA8, 0x01, 0x02},
),
}, {
name: "text_success",
in: "252 text http://192.168.1.1/",
wantCode: dhcpv4.GenericOptionCode(252),
wantVal: dhcpv4.String("http://192.168.1.1/"),
wantErrMsg: "",
}, {
name: "del_success",
in: "61 del",
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionClientIdentifier),
wantVal: dhcpv4.OptionGeneric{Data: nil},
wantErrMsg: "",
}, {
name: "bool_success",
in: "19 bool true",
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionIPForwarding),
wantVal: dhcpv4.OptionGeneric{Data: []byte{0x01}},
wantErrMsg: "",
}, {
name: "bool_success_false",
in: "19 bool F",
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionIPForwarding),
wantVal: dhcpv4.OptionGeneric{Data: []byte{0x00}},
wantErrMsg: "",
}, {
name: "dur_success",
in: "24 dur 2h5s",
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionPathMTUAgingTimeout),
wantVal: dhcpv4.Duration(2*time.Hour + 5*time.Second),
wantErrMsg: "",
}, {
name: "u8_success",
in: "23 u8 64",
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionDefaultIPTTL),
wantVal: dhcpv4.OptionGeneric{Data: []byte{0x40}},
wantErrMsg: "",
}, {
name: "u16_success",
in: "22 u16 1234",
wantCode: dhcpv4.GenericOptionCode(dhcpv4.OptionMaximumDatagramAssemblySize),
wantVal: dhcpv4.Uint16(1234),
wantErrMsg: "",
wantOpt: dhcpv4.OptGeneric(
dhcpv4.GenericOptionCode(252),
[]byte("http://192.168.1.1/"),
),
}, {
name: "bad_parts",
in: "6 ip",
wantErrMsg: `invalid option string "6 ip": need at least three fields`,
wantOpt: dhcpv4.Option{},
wantCode: nil,
wantVal: nil,
wantErrMsg: `invalid option string "6 ip": bad option format`,
}, {
name: "bad_code",
in: "256 ip 1.1.1.1",
name: "bad_code",
in: "256 ip 1.1.1.1",
wantCode: nil,
wantVal: nil,
wantErrMsg: `invalid option string "256 ip 1.1.1.1": parsing option code: ` +
`strconv.ParseUint: parsing "256": value out of range`,
wantOpt: dhcpv4.Option{},
}, {
name: "bad_type",
in: "6 bad 1.1.1.1",
wantCode: nil,
wantVal: nil,
wantErrMsg: `invalid option string "6 bad 1.1.1.1": unknown option type "bad"`,
wantOpt: dhcpv4.Option{},
}, {
name: "hex_error",
in: "6 hex ZZZ",
name: "hex_error",
in: "6 hex ZZZ",
wantCode: nil,
wantVal: nil,
wantErrMsg: `invalid option string "6 hex ZZZ": decoding hex: ` +
`encoding/hex: invalid byte: U+005A 'Z'`,
wantOpt: dhcpv4.Option{},
}, {
name: "ip_error",
in: "6 ip 1.2.3.x",
wantCode: nil,
wantVal: nil,
wantErrMsg: "invalid option string \"6 ip 1.2.3.x\": bad ipv4 address \"1.2.3.x\"",
wantOpt: dhcpv4.Option{},
}, {
name: "ips_error",
in: "6 ips 192.168.1.1,192.168.1.x",
name: "ip_error_v6",
in: "6 ip ::1234",
wantCode: nil,
wantVal: nil,
wantErrMsg: "invalid option string \"6 ip ::1234\": bad ipv4 address \"::1234\"",
}, {
name: "ips_error",
in: "6 ips 192.168.1.1,192.168.1.x",
wantCode: nil,
wantVal: nil,
wantErrMsg: "invalid option string \"6 ips 192.168.1.1,192.168.1.x\": " +
"parsing ip at index 1: bad ipv4 address \"192.168.1.x\"",
wantOpt: dhcpv4.Option{},
}, {
name: "bool_error",
in: "19 bool yes",
wantCode: nil,
wantVal: nil,
wantErrMsg: "invalid option string \"19 bool yes\": decoding bool: " +
"strconv.ParseBool: parsing \"yes\": invalid syntax",
}, {
name: "dur_error",
in: "24 dur 3y",
wantCode: nil,
wantVal: nil,
wantErrMsg: "invalid option string \"24 dur 3y\": decoding dur: " +
"unmarshaling duration: time: unknown unit \"y\" in duration \"3y\"",
}, {
name: "u8_error",
in: "23 u8 256",
wantCode: nil,
wantVal: nil,
wantErrMsg: "invalid option string \"23 u8 256\": decoding u8: " +
"strconv.ParseUint: parsing \"256\": value out of range",
}, {
name: "u16_error",
in: "23 u16 65536",
wantCode: nil,
wantVal: nil,
wantErrMsg: "invalid option string \"23 u16 65536\": decoding u16: " +
"strconv.ParseUint: parsing \"65536\": value out of range",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
opt, err := parseDHCPOption(tc.in)
if tc.wantErrMsg != "" {
require.Error(t, err)
code, val, err := parseDHCPOption(tc.in)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
return
}
require.NoError(t, err)
assert.Equal(t, tc.wantOpt.Code.Code(), opt.Code.Code())
assert.Equal(t, tc.wantOpt.Value.ToBytes(), opt.Value.ToBytes())
assert.Equal(t, tc.wantCode, code)
assert.Equal(t, tc.wantVal, val)
})
}
}
func TestPrepareOptions(t *testing.T) {
allDefault := dhcpv4.Options{
dhcpv4.OptionNonLocalSourceRouting.Code(): []byte{0},
dhcpv4.OptionDefaultIPTTL.Code(): []byte{64},
dhcpv4.OptionPerformMaskDiscovery.Code(): []byte{0},
dhcpv4.OptionMaskSupplier.Code(): []byte{0},
dhcpv4.OptionPerformRouterDiscovery.Code(): []byte{1},
dhcpv4.OptionRouterSolicitationAddress.Code(): []byte{224, 0, 0, 2},
dhcpv4.OptionBroadcastAddress.Code(): []byte{255, 255, 255, 255},
dhcpv4.OptionTrailerEncapsulation.Code(): []byte{0},
dhcpv4.OptionEthernetEncapsulation.Code(): []byte{0},
dhcpv4.OptionTCPKeepaliveInterval.Code(): []byte{0, 0, 0, 0},
dhcpv4.OptionTCPKeepaliveGarbage.Code(): []byte{0},
}
oneIP, otherIP := net.IP{1, 2, 3, 4}, net.IP{5, 6, 7, 8}
testCases := []struct {
name string
opts []string
checks dhcpv4.Options
name string
wantExplicit dhcpv4.Options
opts []string
}{{
name: "all_default",
checks: allDefault,
name: "all_default",
wantExplicit: nil,
opts: nil,
}, {
name: "configured_ip",
wantExplicit: dhcpv4.OptionsFromList(
dhcpv4.OptBroadcastAddress(oneIP),
),
opts: []string{
fmt.Sprintf("%d ip %s", dhcpv4.OptionBroadcastAddress, oneIP),
},
checks: dhcpv4.Options{
dhcpv4.OptionBroadcastAddress.Code(): oneIP,
},
}, {
name: "configured_ips",
wantExplicit: dhcpv4.OptionsFromList(
dhcpv4.Option{
Code: dhcpv4.OptionDomainNameServer,
Value: dhcpv4.IPs{oneIP, otherIP},
},
),
opts: []string{
fmt.Sprintf("%d ips %s,%s", dhcpv4.OptionDomainNameServer, oneIP, otherIP),
},
checks: dhcpv4.Options{
dhcpv4.OptionDomainNameServer.Code(): append(oneIP, otherIP...),
},
}, {
name: "configured_bad",
name: "configured_bad",
wantExplicit: nil,
opts: []string{
"19 bool yes",
"24 dur 3y",
"23 u8 256",
"23 u16 65536",
"20 hex",
"23 hex abc",
"32 ips 1,2,3,4",
"28 256.256.256.256",
},
checks: allDefault,
}, {
name: "configured_del",
wantExplicit: dhcpv4.OptionsFromList(
dhcpv4.OptBroadcastAddress(nil),
),
opts: []string{
"28 del",
},
}, {
name: "rewritten_del",
wantExplicit: dhcpv4.OptionsFromList(
dhcpv4.OptBroadcastAddress(netutil.IPv4bcast()),
),
opts: []string{
"28 del",
"28 ip 255.255.255.255",
},
}, {
name: "configured_and_del",
wantExplicit: dhcpv4.OptionsFromList(
dhcpv4.Option{
Code: dhcpv4.OptionGeoConf,
Value: dhcpv4.String("cba"),
},
),
opts: []string{
"123 text abc",
"123 del",
"123 text cba",
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
opts := prepareOptions(V4ServerConf{
// Just to avoid nil pointer dereference.
subnet: &net.IPNet{},
s := &v4Server{
conf: &V4ServerConf{
Options: tc.opts,
})
for c, v := range tc.checks {
optVal := opts.Get(dhcpv4.GenericOptionCode(c))
require.NotNil(t, optVal)
},
}
assert.Len(t, optVal, len(v))
assert.Equal(t, v, optVal)
t.Run(tc.name, func(t *testing.T) {
s.prepareOptions()
assert.Equal(t, tc.wantExplicit, s.explicitOpts)
for c := range s.explicitOpts {
assert.NotContains(t, s.implicitOpts, c)
}
})
}

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package dhcpd

View File

@@ -65,39 +65,42 @@ func hwAddrToLinkLayerAddr(hwa net.HardwareAddr) (lla []byte, err error) {
}
// Create an ICMPv6.RouterAdvertisement packet with all necessary options.
// Data scheme:
//
// ICMPv6:
// type[1]
// code[1]
// chksum[2]
// body (RouterAdvertisement):
// Cur Hop Limit[1]
// Flags[1]: MO......
// Router Lifetime[2]
// Reachable Time[4]
// Retrans Timer[4]
// Option=Prefix Information(3):
// Type[1]
// Length * 8bytes[1]
// Prefix Length[1]
// Flags[1]: LA......
// Valid Lifetime[4]
// Preferred Lifetime[4]
// Reserved[4]
// Prefix[16]
// Option=MTU(5):
// Type[1]
// Length * 8bytes[1]
// Reserved[2]
// MTU[4]
// Option=Source link-layer address(1):
// Link-Layer Address[8/24]
// Option=Recursive DNS Server(25):
// Type[1]
// Length * 8bytes[1]
// Reserved[2]
// Lifetime[4]
// Addresses of IPv6 Recursive DNS Servers[16]
// ICMPv6:
// - type[1]
// - code[1]
// - chksum[2]
// - body (RouterAdvertisement):
// - Cur Hop Limit[1]
// - Flags[1]: MO......
// - Router Lifetime[2]
// - Reachable Time[4]
// - Retrans Timer[4]
// - Option=Prefix Information(3):
// - Type[1]
// - Length * 8bytes[1]
// - Prefix Length[1]
// - Flags[1]: LA......
// - Valid Lifetime[4]
// - Preferred Lifetime[4]
// - Reserved[4]
// - Prefix[16]
// - Option=MTU(5):
// - Type[1]
// - Length * 8bytes[1]
// - Reserved[2]
// - MTU[4]
// - Option=Source link-layer address(1):
// - Link-Layer Address[8/24]
// - Option=Recursive DNS Server(25):
// - Type[1]
// - Length * 8bytes[1]
// - Reserved[2]
// - Lifetime[4]
// - Addresses of IPv6 Recursive DNS Servers[16]
//
// TODO(a.garipov): Replace with an existing implementation from a dependency.
func createICMPv6RAPacket(params icmpv6RA) (data []byte, err error) {
var lla []byte
lla, err = hwAddrToLinkLayerAddr(params.sourceLinkLayerAddress)

View File

@@ -1,103 +0,0 @@
package dhcpd
import (
"net"
"time"
)
// DHCPServer - DHCP server interface
type DHCPServer interface {
// ResetLeases resets leases.
ResetLeases(leases []*Lease) (err error)
// GetLeases returns deep clones of the current leases.
GetLeases(flags GetLeasesFlags) (leases []*Lease)
// AddStaticLease - add a static lease
AddStaticLease(l *Lease) (err error)
// RemoveStaticLease - remove a static lease
RemoveStaticLease(l *Lease) (err error)
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
FindMACbyIP(ip net.IP) net.HardwareAddr
// WriteDiskConfig4 - copy disk configuration
WriteDiskConfig4(c *V4ServerConf)
// WriteDiskConfig6 - copy disk configuration
WriteDiskConfig6(c *V6ServerConf)
// Start - start server
Start() (err error)
// Stop - stop server
Stop() (err error)
getLeasesRef() []*Lease
}
// V4ServerConf - server configuration
type V4ServerConf struct {
Enabled bool `yaml:"-" json:"-"`
InterfaceName string `yaml:"-" json:"-"`
GatewayIP net.IP `yaml:"gateway_ip" json:"gateway_ip"`
SubnetMask net.IP `yaml:"subnet_mask" json:"subnet_mask"`
// broadcastIP is the broadcasting address pre-calculated from the
// configured gateway IP and subnet mask.
broadcastIP net.IP
// The first & the last IP address for dynamic leases
// Bytes [0..2] of the last allowed IP address must match the first IP
RangeStart net.IP `yaml:"range_start" json:"range_start"`
RangeEnd net.IP `yaml:"range_end" json:"range_end"`
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
// IP conflict detector: time (ms) to wait for ICMP reply
// 0: disable
ICMPTimeout uint32 `yaml:"icmp_timeout_msec" json:"-"`
// Custom Options.
//
// Option with arbitrary hexadecimal data:
// DEC_CODE hex HEX_DATA
// where DEC_CODE is a decimal DHCPv4 option code in range [1..255]
//
// Option with IP data (only 1 IP is supported):
// DEC_CODE ip IP_ADDR
Options []string `yaml:"options" json:"-"`
ipRange *ipRange
leaseTime time.Duration // the time during which a dynamic lease is considered valid
dnsIPAddrs []net.IP // IPv4 addresses to return to DHCP clients as DNS server addresses
// subnet contains the DHCP server's subnet. The IP is the IP of the
// gateway.
subnet *net.IPNet
// notify is a way to signal to other components that leases have
// change. notify must be called outside of locked sections, since the
// clients might want to get the new data.
//
// TODO(a.garipov): This is utter madness and must be refactored. It
// just begs for deadlock bugs and other nastiness.
notify func(uint32)
}
// V6ServerConf - server configuration
type V6ServerConf struct {
Enabled bool `yaml:"-" json:"-"`
InterfaceName string `yaml:"-" json:"-"`
// The first IP address for dynamic leases
// The last allowed IP address ends with 0xff byte
RangeStart net.IP `yaml:"range_start" json:"range_start"`
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
RASLAACOnly bool `yaml:"ra_slaac_only" json:"-"` // send ICMPv6.RA packets without MO flags
RAAllowSLAAC bool `yaml:"ra_allow_slaac" json:"-"` // send ICMPv6.RA packets with MO flags
ipStart net.IP // starting IP address for dynamic leases
leaseTime time.Duration // the time during which a dynamic lease is considered valid
dnsIPAddrs []net.IP // IPv6 addresses to return to DHCP clients as DNS server addresses
// Server calls this function when leases data changes
notify func(uint32)
}

View File

@@ -1,12 +0,0 @@
package dhcpd
import (
"time"
)
// Currently used defaults for ifaceDNSAddrs.
const (
defaultMaxAttempts int = 10
defaultBackoff time.Duration = 500 * time.Millisecond
)

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package dhcpd
@@ -9,15 +8,19 @@ import "net"
type winServer struct{}
func (s *winServer) ResetLeases(_ []*Lease) (err error) { return nil }
func (s *winServer) GetLeases(_ GetLeasesFlags) (leases []*Lease) { return nil }
func (s *winServer) getLeasesRef() []*Lease { return nil }
func (s *winServer) AddStaticLease(_ *Lease) (err error) { return nil }
func (s *winServer) RemoveStaticLease(_ *Lease) (err error) { return nil }
func (s *winServer) FindMACbyIP(ip net.IP) (mac net.HardwareAddr) { return nil }
func (s *winServer) WriteDiskConfig4(c *V4ServerConf) {}
func (s *winServer) WriteDiskConfig6(c *V6ServerConf) {}
func (s *winServer) Start() (err error) { return nil }
func (s *winServer) Stop() (err error) { return nil }
func v4Create(conf V4ServerConf) (DHCPServer, error) { return &winServer{}, nil }
func v6Create(conf V6ServerConf) (DHCPServer, error) { return &winServer{}, nil }
// type check
var _ DHCPServer = winServer{}
func (winServer) ResetLeases(_ []*Lease) (err error) { return nil }
func (winServer) GetLeases(_ GetLeasesFlags) (leases []*Lease) { return nil }
func (winServer) getLeasesRef() []*Lease { return nil }
func (winServer) AddStaticLease(_ *Lease) (err error) { return nil }
func (winServer) RemoveStaticLease(_ *Lease) (err error) { return nil }
func (winServer) FindMACbyIP(_ net.IP) (mac net.HardwareAddr) { return nil }
func (winServer) WriteDiskConfig4(_ *V4ServerConf) {}
func (winServer) WriteDiskConfig6(_ *V6ServerConf) {}
func (winServer) Start() (err error) { return nil }
func (winServer) Stop() (err error) { return nil }
func v4Create(_ *V4ServerConf) (s DHCPServer, err error) { return winServer{}, nil }
func v6Create(_ V6ServerConf) (s DHCPServer, err error) { return winServer{}, nil }

View File

@@ -1,558 +0,0 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package dhcpd
import (
"net"
"strings"
"testing"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/mdlayher/raw"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func notify4(flags uint32) {
}
// defaultV4ServerConf returns the default configuration for *v4Server to use in
// tests.
func defaultV4ServerConf() (conf V4ServerConf) {
return V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
}
}
// defaultSrv prepares the default DHCPServer to use in tests. The underlying
// type of s is *v4Server.
func defaultSrv(t *testing.T) (s DHCPServer) {
t.Helper()
var err error
s, err = v4Create(defaultV4ServerConf())
require.NoError(t, err)
return s
}
func TestV4_AddRemove_static(t *testing.T) {
s := defaultSrv(t)
ls := s.GetLeases(LeasesStatic)
assert.Empty(t, ls)
// Add static lease.
l := &Lease{
Hostname: "static-1.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
}
err := s.AddStaticLease(l)
require.NoError(t, err)
err = s.AddStaticLease(l)
assert.Error(t, err)
ls = s.GetLeases(LeasesStatic)
require.Len(t, ls, 1)
assert.True(t, l.IP.Equal(ls[0].IP))
assert.Equal(t, l.HWAddr, ls[0].HWAddr)
assert.True(t, ls[0].IsStatic())
// Try to remove static lease.
err = s.RemoveStaticLease(&Lease{
IP: net.IP{192, 168, 10, 110},
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
})
assert.Error(t, err)
// Remove static lease.
err = s.RemoveStaticLease(l)
require.NoError(t, err)
ls = s.GetLeases(LeasesStatic)
assert.Empty(t, ls)
}
func TestV4_AddReplace(t *testing.T) {
sIface := defaultSrv(t)
s, ok := sIface.(*v4Server)
require.True(t, ok)
dynLeases := []Lease{{
Hostname: "dynamic-1.local",
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
}, {
Hostname: "dynamic-2.local",
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 151},
}}
for i := range dynLeases {
err := s.addLease(&dynLeases[i])
require.NoError(t, err)
}
stLeases := []*Lease{{
Hostname: "static-1.local",
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
}, {
Hostname: "static-2.local",
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 152},
}}
for _, l := range stLeases {
err := s.AddStaticLease(l)
require.NoError(t, err)
}
ls := s.GetLeases(LeasesStatic)
require.Len(t, ls, 2)
for i, l := range ls {
assert.True(t, stLeases[i].IP.Equal(l.IP))
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
assert.True(t, l.IsStatic())
}
}
func TestV4Server_Process_optionsPriority(t *testing.T) {
defaultIP := net.IP{192, 168, 1, 1}
knownIP := net.IP{1, 2, 3, 4}
// prepareSrv creates a *v4Server and sets the opt6IPs in the initial
// configuration of the server as the value for DHCP option 6.
prepareSrv := func(t *testing.T, opt6IPs []net.IP) (s *v4Server) {
t.Helper()
conf := defaultV4ServerConf()
if len(opt6IPs) > 0 {
b := &strings.Builder{}
stringutil.WriteToBuilder(b, "6 ips ", opt6IPs[0].String())
for _, ip := range opt6IPs[1:] {
stringutil.WriteToBuilder(b, ",", ip.String())
}
conf.Options = []string{b.String()}
}
ss, err := v4Create(conf)
require.NoError(t, err)
var ok bool
s, ok = ss.(*v4Server)
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{defaultIP}
return s
}
// checkResp creates a discovery message with DHCP option 6 requested amd
// asserts the response to contain wantIPs in this option.
checkResp := func(t *testing.T, s *v4Server, wantIPs []net.IP) {
t.Helper()
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
req, err := dhcpv4.NewDiscovery(mac, dhcpv4.WithRequestedOptions(
dhcpv4.OptionDomainNameServer,
))
require.NoError(t, err)
var resp *dhcpv4.DHCPv4
resp, err = dhcpv4.NewReplyFromRequest(req)
require.NoError(t, err)
res := s.process(req, resp)
require.Equal(t, 1, res)
o := resp.GetOneOption(dhcpv4.OptionDomainNameServer)
require.NotEmpty(t, o)
wantData := []byte{}
for _, ip := range wantIPs {
wantData = append(wantData, ip...)
}
assert.Equal(t, o, wantData)
}
t.Run("default", func(t *testing.T) {
s := prepareSrv(t, nil)
checkResp(t, s, []net.IP{defaultIP})
})
t.Run("explicitly_configured", func(t *testing.T) {
s := prepareSrv(t, []net.IP{knownIP, knownIP})
checkResp(t, s, []net.IP{knownIP, knownIP})
})
}
func TestV4StaticLease_Get(t *testing.T) {
sIface := defaultSrv(t)
s, ok := sIface.(*v4Server)
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
l := &Lease{
Hostname: "static-1.local",
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150},
}
err := s.AddStaticLease(l)
require.NoError(t, err)
var req, resp *dhcpv4.DHCPv4
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
t.Run("discover", func(t *testing.T) {
req, err = dhcpv4.NewDiscovery(mac, dhcpv4.WithRequestedOptions(
dhcpv4.OptionDomainNameServer,
))
require.NoError(t, err)
resp, err = dhcpv4.NewReplyFromRequest(req)
require.NoError(t, err)
assert.Equal(t, 1, s.process(req, resp))
})
// Don't continue if we got any errors in the previous subtest.
require.NoError(t, err)
t.Run("offer", func(t *testing.T) {
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, mac, resp.ClientHWAddr)
assert.True(t, l.IP.Equal(resp.YourIPAddr))
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
})
t.Run("request", func(t *testing.T) {
req, err = dhcpv4.NewRequestFromOffer(resp)
require.NoError(t, err)
resp, err = dhcpv4.NewReplyFromRequest(req)
require.NoError(t, err)
assert.Equal(t, 1, s.process(req, resp))
})
require.NoError(t, err)
t.Run("ack", func(t *testing.T) {
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, mac, resp.ClientHWAddr)
assert.True(t, l.IP.Equal(resp.YourIPAddr))
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
})
dnsAddrs := resp.DNS()
require.Len(t, dnsAddrs, 1)
assert.True(t, s.conf.GatewayIP.Equal(dnsAddrs[0]))
t.Run("check_lease", func(t *testing.T) {
ls := s.GetLeases(LeasesStatic)
require.Len(t, ls, 1)
assert.True(t, l.IP.Equal(ls[0].IP))
assert.Equal(t, mac, ls[0].HWAddr)
})
}
func TestV4DynamicLease_Get(t *testing.T) {
conf := defaultV4ServerConf()
conf.Options = []string{
"81 hex 303132",
"82 ip 1.2.3.4",
}
var err error
sIface, err := v4Create(conf)
require.NoError(t, err)
s, ok := sIface.(*v4Server)
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
var req, resp *dhcpv4.DHCPv4
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
t.Run("discover", func(t *testing.T) {
req, err = dhcpv4.NewDiscovery(mac, dhcpv4.WithRequestedOptions(
dhcpv4.OptionFQDN,
dhcpv4.OptionRelayAgentInformation,
))
require.NoError(t, err)
resp, err = dhcpv4.NewReplyFromRequest(req)
require.NoError(t, err)
assert.Equal(t, 1, s.process(req, resp))
})
// Don't continue if we got any errors in the previous subtest.
require.NoError(t, err)
t.Run("offer", func(t *testing.T) {
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, mac, resp.ClientHWAddr)
assert.Equal(t, s.conf.RangeStart, resp.YourIPAddr)
assert.Equal(t, s.conf.GatewayIP, resp.ServerIdentifier())
router := resp.Router()
require.Len(t, router, 1)
assert.Equal(t, s.conf.GatewayIP, router[0])
assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
assert.Equal(t, []byte("012"), resp.Options.Get(dhcpv4.OptionFQDN))
rai := resp.RelayAgentInfo()
require.NotNil(t, rai)
assert.Equal(t, net.IP{1, 2, 3, 4}, net.IP(rai.ToBytes()))
})
t.Run("request", func(t *testing.T) {
req, err = dhcpv4.NewRequestFromOffer(resp)
require.NoError(t, err)
resp, err = dhcpv4.NewReplyFromRequest(req)
require.NoError(t, err)
assert.Equal(t, 1, s.process(req, resp))
})
require.NoError(t, err)
t.Run("ack", func(t *testing.T) {
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, mac, resp.ClientHWAddr)
assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr))
router := resp.Router()
require.Len(t, router, 1)
assert.Equal(t, s.conf.GatewayIP, router[0])
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
})
dnsAddrs := resp.DNS()
require.Len(t, dnsAddrs, 1)
assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0]))
// check lease
t.Run("check_lease", func(t *testing.T) {
ls := s.GetLeases(LeasesDynamic)
require.Len(t, ls, 1)
assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP))
assert.Equal(t, mac, ls[0].HWAddr)
})
}
func TestNormalizeHostname(t *testing.T) {
testCases := []struct {
name string
hostname string
wantErrMsg string
want string
}{{
name: "success",
hostname: "example.com",
wantErrMsg: "",
want: "example.com",
}, {
name: "success_empty",
hostname: "",
wantErrMsg: "",
want: "",
}, {
name: "success_spaces",
hostname: "my device 01",
wantErrMsg: "",
want: "my-device-01",
}, {
name: "success_underscores",
hostname: "my_device_01",
wantErrMsg: "",
want: "my-device-01",
}, {
name: "error_part",
hostname: "device !!!",
wantErrMsg: "",
want: "device",
}, {
name: "error_part_spaces",
hostname: "device ! ! !",
wantErrMsg: "",
want: "device",
}, {
name: "error",
hostname: "!!!",
wantErrMsg: `normalizing "!!!": no valid parts`,
want: "",
}, {
name: "error_spaces",
hostname: "! ! !",
wantErrMsg: `normalizing "! ! !": no valid parts`,
want: "",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := normalizeHostname(tc.hostname)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
}
// fakePacketConn is a mock implementation of net.PacketConn to simplify
// testing.
type fakePacketConn struct {
// writeTo is used to substitute net.PacketConn's WriteTo method.
writeTo func(p []byte, addr net.Addr) (n int, err error)
// net.PacketConn is embedded here simply to make *fakePacketConn a
// net.PacketConn without actually implementing all methods.
net.PacketConn
}
// WriteTo implements net.PacketConn interface for *fakePacketConn.
func (fc *fakePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return fc.writeTo(p, addr)
}
func TestV4Server_Send(t *testing.T) {
s := &v4Server{}
var (
defaultIP = net.IP{99, 99, 99, 99}
knownIP = net.IP{4, 2, 4, 2}
knownMAC = net.HardwareAddr{6, 5, 4, 3, 2, 1}
)
defaultPeer := &net.UDPAddr{
IP: defaultIP,
// Use neither client nor server port to check it actually
// changed.
Port: dhcpv4.ClientPort + dhcpv4.ServerPort,
}
defaultResp := &dhcpv4.DHCPv4{}
testCases := []struct {
want net.Addr
req *dhcpv4.DHCPv4
resp *dhcpv4.DHCPv4
name string
}{{
name: "giaddr",
req: &dhcpv4.DHCPv4{GatewayIPAddr: knownIP},
resp: defaultResp,
want: &net.UDPAddr{
IP: knownIP,
Port: dhcpv4.ServerPort,
},
}, {
name: "nak",
req: &dhcpv4.DHCPv4{},
resp: &dhcpv4.DHCPv4{
Options: dhcpv4.OptionsFromList(
dhcpv4.OptMessageType(dhcpv4.MessageTypeNak),
),
},
want: defaultPeer,
}, {
name: "ciaddr",
req: &dhcpv4.DHCPv4{ClientIPAddr: knownIP},
resp: &dhcpv4.DHCPv4{},
want: &net.UDPAddr{
IP: knownIP,
Port: dhcpv4.ClientPort,
},
}, {
name: "chaddr",
req: &dhcpv4.DHCPv4{ClientHWAddr: knownMAC},
resp: &dhcpv4.DHCPv4{YourIPAddr: knownIP},
want: &dhcpUnicastAddr{
Addr: raw.Addr{HardwareAddr: knownMAC},
yiaddr: knownIP,
},
}, {
name: "who_are_you",
req: &dhcpv4.DHCPv4{},
resp: &dhcpv4.DHCPv4{},
want: defaultPeer,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conn := &fakePacketConn{
writeTo: func(_ []byte, addr net.Addr) (_ int, _ error) {
assert.Equal(t, tc.want, addr)
return 0, nil
},
}
s.send(cloneUDPAddr(defaultPeer), conn, tc.req, tc.resp)
})
}
t.Run("giaddr_nak", func(t *testing.T) {
req := &dhcpv4.DHCPv4{
GatewayIPAddr: knownIP,
}
// Ensure the request is for unicast.
req.SetUnicast()
resp := &dhcpv4.DHCPv4{
Options: dhcpv4.OptionsFromList(
dhcpv4.OptMessageType(dhcpv4.MessageTypeNak),
),
}
want := &net.UDPAddr{
IP: req.GatewayIPAddr,
Port: dhcpv4.ServerPort,
}
conn := &fakePacketConn{
writeTo: func(_ []byte, addr net.Addr) (n int, err error) {
assert.Equal(t, want, addr)
return 0, nil
},
}
s.send(cloneUDPAddr(defaultPeer), conn, req, resp)
assert.True(t, resp.IsBroadcast())
})
}

Some files were not shown because too many files have changed in this diff Show More