Merge branch 'master' into 5615-rm-raw
This commit is contained in:
@@ -1,47 +1,14 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// The maximum lengths of generated hostnames for different IP versions.
|
||||
const (
|
||||
ipv4HostnameMaxLen = len("192-168-100-100")
|
||||
ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010")
|
||||
)
|
||||
|
||||
// generateIPv4Hostname generates the hostname by IP address version 4.
|
||||
func generateIPv4Hostname(ipv4 net.IP) (hostname string) {
|
||||
hnData := make([]byte, 0, ipv4HostnameMaxLen)
|
||||
for i, part := range ipv4 {
|
||||
if i > 0 {
|
||||
hnData = append(hnData, '-')
|
||||
}
|
||||
hnData = strconv.AppendUint(hnData, uint64(part), 10)
|
||||
}
|
||||
|
||||
return string(hnData)
|
||||
}
|
||||
|
||||
// generateIPv6Hostname generates the hostname by IP address version 6.
|
||||
func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
|
||||
hnData := make([]byte, 0, ipv6HostnameMaxLen)
|
||||
for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ {
|
||||
if i > 0 {
|
||||
hnData = append(hnData, '-')
|
||||
}
|
||||
for _, val := range ipv6[i*2 : i*2+2] {
|
||||
if val < 10 {
|
||||
hnData = append(hnData, '0')
|
||||
}
|
||||
hnData = strconv.AppendUint(hnData, uint64(val), 16)
|
||||
}
|
||||
}
|
||||
|
||||
return string(hnData)
|
||||
}
|
||||
|
||||
// GenerateHostname generates the hostname from ip. In case of using IPv4 the
|
||||
// result should be like:
|
||||
//
|
||||
@@ -52,10 +19,42 @@ func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
|
||||
// ff80-f076-0000-0000-0000-0000-0000-0010
|
||||
//
|
||||
// ip must be either an IPv4 or an IPv6.
|
||||
func GenerateHostname(ip net.IP) (hostname string) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return generateIPv4Hostname(ipv4)
|
||||
func GenerateHostname(ip netip.Addr) (hostname string) {
|
||||
if !ip.IsValid() {
|
||||
// TODO(s.chzhen): Get rid of it.
|
||||
panic("aghnet generate hostname: invalid ip")
|
||||
}
|
||||
|
||||
return generateIPv6Hostname(ip)
|
||||
ip = ip.Unmap()
|
||||
hostname = ip.StringExpanded()
|
||||
|
||||
if ip.Is4() {
|
||||
return strings.Replace(hostname, ".", "-", -1)
|
||||
}
|
||||
|
||||
return strings.Replace(hostname, ":", "-", -1)
|
||||
}
|
||||
|
||||
// NewDomainNameSet returns nil and error, if list has duplicate or empty
|
||||
// domain name. Otherwise returns a set, which contains non-FQDN domain names,
|
||||
// and nil error.
|
||||
func NewDomainNameSet(list []string) (set *stringutil.Set, err error) {
|
||||
set = stringutil.NewSet()
|
||||
|
||||
for i, v := range list {
|
||||
host := strings.ToLower(strings.TrimSuffix(v, "."))
|
||||
// TODO(a.garipov): Think about ignoring empty (".") names in the
|
||||
// future.
|
||||
if host == "" {
|
||||
return nil, errors.Error("host name is empty")
|
||||
}
|
||||
|
||||
if set.Has(host) {
|
||||
return nil, fmt.Errorf("duplicate host name %q at index %d", host, i)
|
||||
}
|
||||
|
||||
set.Add(host)
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -12,19 +12,19 @@ func TestGenerateHostName(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
ip net.IP
|
||||
ip netip.Addr
|
||||
}{{
|
||||
name: "good_ipv4",
|
||||
want: "127-0-0-1",
|
||||
ip: net.IP{127, 0, 0, 1},
|
||||
ip: netip.MustParseAddr("127.0.0.1"),
|
||||
}, {
|
||||
name: "good_ipv6",
|
||||
want: "fe00-0000-0000-0000-0000-0000-0000-0001",
|
||||
ip: net.ParseIP("fe00::1"),
|
||||
ip: netip.MustParseAddr("fe00::1"),
|
||||
}, {
|
||||
name: "4to6",
|
||||
want: "1-2-3-4",
|
||||
ip: net.ParseIP("::ffff:1.2.3.4"),
|
||||
ip: netip.MustParseAddr("::ffff:1.2.3.4"),
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -36,29 +36,6 @@ func TestGenerateHostName(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
ip net.IP
|
||||
}{{
|
||||
name: "bad_ipv4",
|
||||
ip: net.IP{127, 0, 0, 1, 0},
|
||||
}, {
|
||||
name: "bad_ipv6",
|
||||
ip: net.IP{
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff,
|
||||
},
|
||||
}, {
|
||||
name: "nil",
|
||||
ip: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Panics(t, func() { GenerateHostname(tc.ip) })
|
||||
})
|
||||
}
|
||||
assert.Panics(t, func() { GenerateHostname(netip.Addr{}) })
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -32,6 +33,8 @@ func normalizeIP(ip net.IP) net.IP {
|
||||
}
|
||||
|
||||
// Load lease table from DB
|
||||
//
|
||||
// TODO(s.chzhen): Decrease complexity.
|
||||
func (s *server) dbLoad() (err error) {
|
||||
dynLeases := []*Lease{}
|
||||
staticLeases := []*Lease{}
|
||||
@@ -57,14 +60,15 @@ func (s *server) dbLoad() (err error) {
|
||||
for i := range obj {
|
||||
obj[i].IP = normalizeIP(obj[i].IP)
|
||||
|
||||
if !(len(obj[i].IP) == 4 || len(obj[i].IP) == 16) {
|
||||
ip, ok := netip.AddrFromSlice(obj[i].IP)
|
||||
if !ok {
|
||||
log.Info("dhcp: invalid IP: %s", obj[i].IP)
|
||||
continue
|
||||
}
|
||||
|
||||
lease := Lease{
|
||||
HWAddr: obj[i].HWAddr,
|
||||
IP: obj[i].IP,
|
||||
IP: ip,
|
||||
Hostname: obj[i].Hostname,
|
||||
Expiry: time.Unix(obj[i].Expiry, 0),
|
||||
}
|
||||
@@ -145,7 +149,7 @@ func (s *server) dbStore() (err error) {
|
||||
|
||||
lease := leaseJSON{
|
||||
HWAddr: l.HWAddr,
|
||||
IP: l.IP,
|
||||
IP: l.IP.AsSlice(),
|
||||
Hostname: l.Hostname,
|
||||
Expiry: l.Expiry.Unix(),
|
||||
}
|
||||
@@ -162,7 +166,7 @@ func (s *server) dbStore() (err error) {
|
||||
|
||||
lease := leaseJSON{
|
||||
HWAddr: l.HWAddr,
|
||||
IP: l.IP,
|
||||
IP: l.IP.AsSlice(),
|
||||
Hostname: l.Hostname,
|
||||
Expiry: l.Expiry.Unix(),
|
||||
}
|
||||
|
||||
@@ -41,13 +41,16 @@ type Lease struct {
|
||||
// of 1 means that this is a static lease.
|
||||
Expiry time.Time `json:"expires"`
|
||||
|
||||
Hostname string `json:"hostname"`
|
||||
HWAddr net.HardwareAddr `json:"mac"`
|
||||
// Hostname of the client.
|
||||
Hostname string `json:"hostname"`
|
||||
|
||||
// HWAddr is the physical hardware address (MAC address).
|
||||
HWAddr net.HardwareAddr `json:"mac"`
|
||||
|
||||
// IP is the IP address leased to the client.
|
||||
//
|
||||
// TODO(a.garipov): Migrate leases.db and use netip.Addr.
|
||||
IP net.IP `json:"ip"`
|
||||
// TODO(a.garipov): Migrate leases.db.
|
||||
IP netip.Addr `json:"ip"`
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of l.
|
||||
@@ -60,7 +63,7 @@ func (l *Lease) Clone() (clone *Lease) {
|
||||
Expiry: l.Expiry,
|
||||
Hostname: l.Hostname,
|
||||
HWAddr: slices.Clone(l.HWAddr),
|
||||
IP: slices.Clone(l.IP),
|
||||
IP: l.IP,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -48,11 +48,11 @@ func TestDB(t *testing.T) {
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
Hostname: "static-1.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 100},
|
||||
IP: netip.MustParseAddr("192.168.10.100"),
|
||||
}, {
|
||||
Hostname: "static-2.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xBB},
|
||||
IP: net.IP{192, 168, 10, 101},
|
||||
IP: netip.MustParseAddr("192.168.10.101"),
|
||||
}}
|
||||
|
||||
srv4, ok := s.srv4.(*v4Server)
|
||||
@@ -96,7 +96,7 @@ func TestNormalizeLeases(t *testing.T) {
|
||||
|
||||
staticLeases := []*Lease{{
|
||||
HWAddr: net.HardwareAddr{1, 2, 3, 4},
|
||||
IP: net.IP{0, 2, 3, 4},
|
||||
IP: netip.MustParseAddr("0.2.3.4"),
|
||||
}, {
|
||||
HWAddr: net.HardwareAddr{2, 2, 3, 4},
|
||||
}}
|
||||
|
||||
@@ -496,18 +496,18 @@ func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
if l.IP == nil {
|
||||
if !l.IP.IsValid() {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
l.IP = l.IP.Unmap()
|
||||
|
||||
var srv DHCPServer
|
||||
if ip4 := l.IP.To4(); ip4 != nil {
|
||||
l.IP = ip4
|
||||
if l.IP.Is4() {
|
||||
srv = s.srv4
|
||||
} else {
|
||||
l.IP = l.IP.To16()
|
||||
srv = s.srv6
|
||||
}
|
||||
|
||||
@@ -528,27 +528,22 @@ func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
if l.IP == nil {
|
||||
if !l.IP.IsValid() {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ip4 := l.IP.To4()
|
||||
l.IP = l.IP.Unmap()
|
||||
|
||||
if ip4 == nil {
|
||||
l.IP = l.IP.To16()
|
||||
|
||||
err = s.srv6.RemoveStaticLease(l)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
}
|
||||
|
||||
return
|
||||
var srv DHCPServer
|
||||
if l.IP.Is4() {
|
||||
srv = s.srv4
|
||||
} else {
|
||||
srv = s.srv6
|
||||
}
|
||||
|
||||
l.IP = ip4
|
||||
err = s.srv4.RemoveStaticLease(l)
|
||||
err = srv.RemoveStaticLease(l)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
|
||||
161
internal/dhcpd/http_unix_test.go
Normal file
161
internal/dhcpd/http_unix_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_handleDHCPStatus(t *testing.T) {
|
||||
const staticName = "static-client"
|
||||
|
||||
staticIP := netip.MustParseAddr("192.168.10.10")
|
||||
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
staticLease := &Lease{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
}
|
||||
|
||||
s, err := Create(&ServerConfig{
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
WorkDir: t.TempDir(),
|
||||
DBFilePath: dbFilename,
|
||||
ConfigModified: func() {},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// checkStatus is a helper that asserts the response of
|
||||
// [*server.handleDHCPStatus].
|
||||
checkStatus := func(t *testing.T, want *dhcpStatusResponse) {
|
||||
w := httptest.NewRecorder()
|
||||
var req *http.Request
|
||||
req, err = http.NewRequest(http.MethodGet, "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(&want)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPStatus(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
assert.JSONEq(t, b.String(), w.Body.String())
|
||||
}
|
||||
|
||||
// defaultResponse is a helper that returs the response with default
|
||||
// configuration.
|
||||
defaultResponse := func() *dhcpStatusResponse {
|
||||
conf4 := defaultV4ServerConf()
|
||||
conf4.LeaseDuration = 86400
|
||||
|
||||
resp := &dhcpStatusResponse{
|
||||
V4: *conf4,
|
||||
V6: V6ServerConf{},
|
||||
Leases: []*Lease{},
|
||||
StaticLeases: []*Lease{},
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
ok := t.Run("status", func(t *testing.T) {
|
||||
resp := defaultResponse()
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("add_static_lease", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(staticLease)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPAddStaticLease(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp := defaultResponse()
|
||||
resp.StaticLeases = []*Lease{staticLease}
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("add_invalid_lease", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
|
||||
err = json.NewEncoder(b).Encode(&Lease{})
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPAddStaticLease(w, r)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("remove_static_lease", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(staticLease)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPRemoveStaticLease(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp := defaultResponse()
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("set_config", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
resp := defaultResponse()
|
||||
resp.Enabled = false
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPSetConfig(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
}
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
//
|
||||
// TODO(a.garipov): Perhaps create an optimized version with uint32 for IPv4
|
||||
// ranges? Or use one of uint128 packages?
|
||||
//
|
||||
// TODO(e.burkov): Use netip.Addr.
|
||||
type ipRange struct {
|
||||
start *big.Int
|
||||
end *big.Int
|
||||
@@ -27,8 +29,6 @@ const maxRangeLen = math.MaxUint32
|
||||
|
||||
// newIPRange creates a new IP address range. start must be less than end. The
|
||||
// resulting range must not be greater than maxRangeLen.
|
||||
//
|
||||
// TODO(e.burkov): Use netip.Addr.
|
||||
func newIPRange(start, end net.IP) (r *ipRange, err error) {
|
||||
defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ func normalizeHostname(hostname string) (norm string, err error) {
|
||||
// validHostnameForClient accepts the hostname sent by the client and its IP and
|
||||
// returns either a normalized version of that hostname, or a new hostname
|
||||
// generated from the IP address, or an empty string.
|
||||
func (s *v4Server) validHostnameForClient(cliHostname string, ip net.IP) (hostname string) {
|
||||
func (s *v4Server) validHostnameForClient(cliHostname string, ip netip.Addr) (hostname string) {
|
||||
hostname, err := normalizeHostname(cliHostname)
|
||||
if err != nil {
|
||||
log.Info("dhcpv4: %s", err)
|
||||
@@ -209,9 +209,8 @@ func (s *v4Server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
netIP := ip.AsSlice()
|
||||
for _, l := range s.leases {
|
||||
if l.IP.Equal(netIP) {
|
||||
if l.IP == ip {
|
||||
if l.Expiry.After(now) || l.IsStatic() {
|
||||
return l.HWAddr
|
||||
}
|
||||
@@ -245,7 +244,8 @@ func (s *v4Server) rmLeaseByIndex(i int) {
|
||||
s.leases = append(s.leases[:i], s.leases[i+1:]...)
|
||||
|
||||
r := s.conf.ipRange
|
||||
offset, ok := r.offset(l.IP)
|
||||
leaseIP := net.IP(l.IP.AsSlice())
|
||||
offset, ok := r.offset(leaseIP)
|
||||
if ok {
|
||||
s.leasedOffsets.set(offset, false)
|
||||
}
|
||||
@@ -261,7 +261,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
for i, l := range s.leases {
|
||||
isStatic := l.IsStatic()
|
||||
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) || l.IP.Equal(lease.IP) {
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) || l.IP == lease.IP {
|
||||
if isStatic {
|
||||
return errors.Error("static lease already exists")
|
||||
}
|
||||
@@ -289,13 +289,13 @@ const ErrDupHostname = errors.Error("hostname is not unique")
|
||||
// addLease adds a dynamic or static lease.
|
||||
func (s *v4Server) addLease(l *Lease) (err error) {
|
||||
r := s.conf.ipRange
|
||||
offset, inOffset := r.offset(l.IP)
|
||||
leaseIP := net.IP(l.IP.AsSlice())
|
||||
offset, inOffset := r.offset(leaseIP)
|
||||
|
||||
if l.IsStatic() {
|
||||
// TODO(a.garipov, d.seregin): Subnet can be nil when dhcp server is
|
||||
// disabled.
|
||||
addr := netip.AddrFrom4(*(*[4]byte)(l.IP.To4()))
|
||||
if sn := s.conf.subnet; !sn.Contains(addr) {
|
||||
if sn := s.conf.subnet; !sn.Contains(l.IP) {
|
||||
return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP)
|
||||
}
|
||||
} else if !inOffset {
|
||||
@@ -323,7 +323,7 @@ func (s *v4Server) rmLease(lease *Lease) (err error) {
|
||||
}
|
||||
|
||||
for i, l := range s.leases {
|
||||
if l.IP.Equal(lease.IP) {
|
||||
if l.IP == lease.IP {
|
||||
if !bytes.Equal(l.HWAddr, lease.HWAddr) || l.Hostname != lease.Hostname {
|
||||
return fmt.Errorf("lease for ip %s is different: %+v", lease.IP, l)
|
||||
}
|
||||
@@ -350,10 +350,11 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
return ErrUnconfigured
|
||||
}
|
||||
|
||||
ip := l.IP.To4()
|
||||
if ip == nil {
|
||||
l.IP = l.IP.Unmap()
|
||||
|
||||
if !l.IP.Is4() {
|
||||
return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
|
||||
} else if gwIP := s.conf.GatewayIP; gwIP == netip.AddrFrom4(*(*[4]byte)(ip)) {
|
||||
} else if gwIP := s.conf.GatewayIP; gwIP == l.IP {
|
||||
return fmt.Errorf("can't assign the gateway IP %s to the lease", gwIP)
|
||||
}
|
||||
|
||||
@@ -394,7 +395,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
if err != nil {
|
||||
err = fmt.Errorf(
|
||||
"removing dynamic leases for %s (%s): %w",
|
||||
ip,
|
||||
l.IP,
|
||||
l.HWAddr,
|
||||
err,
|
||||
)
|
||||
@@ -404,7 +405,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
|
||||
err = s.addLease(l)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("adding static lease for %s (%s): %w", ip, l.HWAddr, err)
|
||||
err = fmt.Errorf("adding static lease for %s (%s): %w", l.IP, l.HWAddr, err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -427,7 +428,7 @@ func (s *v4Server) RemoveStaticLease(l *Lease) (err error) {
|
||||
return ErrUnconfigured
|
||||
}
|
||||
|
||||
if len(l.IP) != 4 {
|
||||
if !l.IP.Is4() {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
|
||||
@@ -540,8 +541,8 @@ func (s *v4Server) findExpiredLease() int {
|
||||
func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease, err error) {
|
||||
l = &Lease{HWAddr: slices.Clone(mac)}
|
||||
|
||||
l.IP = s.nextIP()
|
||||
if l.IP == nil {
|
||||
nextIP := s.nextIP()
|
||||
if nextIP == nil {
|
||||
i := s.findExpiredLease()
|
||||
if i < 0 {
|
||||
return nil, nil
|
||||
@@ -552,6 +553,13 @@ func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease, err error) {
|
||||
return s.leases[i], nil
|
||||
}
|
||||
|
||||
netIP, ok := netip.AddrFromSlice(nextIP)
|
||||
if !ok {
|
||||
return nil, errors.Error("invalid ip")
|
||||
}
|
||||
|
||||
l.IP = netIP
|
||||
|
||||
err = s.addLease(l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -601,7 +609,8 @@ func (s *v4Server) allocateLease(mac net.HardwareAddr) (l *Lease, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if s.addrAvailable(l.IP) {
|
||||
leaseIP := l.IP.AsSlice()
|
||||
if s.addrAvailable(leaseIP) {
|
||||
return l, nil
|
||||
}
|
||||
|
||||
@@ -621,8 +630,9 @@ func (s *v4Server) handleDiscover(req, resp *dhcpv4.DHCPv4) (l *Lease, err error
|
||||
l = s.findLease(mac)
|
||||
if l != nil {
|
||||
reqIP := req.RequestedIPAddress()
|
||||
if len(reqIP) != 0 && !reqIP.Equal(l.IP) {
|
||||
log.Debug("dhcpv4: different RequestedIP: %s != %s", reqIP, l.IP)
|
||||
leaseIP := net.IP(l.IP.AsSlice())
|
||||
if len(reqIP) != 0 && !reqIP.Equal(leaseIP) {
|
||||
log.Debug("dhcpv4: different RequestedIP: %s != %s", reqIP, leaseIP)
|
||||
}
|
||||
|
||||
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeOffer))
|
||||
@@ -672,12 +682,19 @@ func (s *v4Server) checkLease(mac net.HardwareAddr, ip net.IP) (lease *Lease, mi
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
netIP, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
log.Info("check lease: invalid IP: %s", ip)
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, l := range s.leases {
|
||||
if !bytes.Equal(l.HWAddr, mac) {
|
||||
continue
|
||||
}
|
||||
|
||||
if l.IP.Equal(ip) {
|
||||
if l.IP == netIP {
|
||||
return l, false
|
||||
}
|
||||
|
||||
@@ -876,9 +893,16 @@ func (s *v4Server) handleDecline(req, resp *dhcpv4.DHCPv4) (err error) {
|
||||
reqIP = req.ClientIPAddr
|
||||
}
|
||||
|
||||
netIP, ok := netip.AddrFromSlice(reqIP)
|
||||
if !ok {
|
||||
log.Info("dhcpv4: invalid IP: %s", reqIP)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var oldLease *Lease
|
||||
for _, l := range s.leases {
|
||||
if bytes.Equal(l.HWAddr, mac) && l.IP.Equal(reqIP) {
|
||||
if bytes.Equal(l.HWAddr, mac) && l.IP == netIP {
|
||||
oldLease = l
|
||||
|
||||
break
|
||||
@@ -918,8 +942,7 @@ func (s *v4Server) handleDecline(req, resp *dhcpv4.DHCPv4) (err error) {
|
||||
|
||||
log.Info("dhcpv4: changed ip from %s to %s for %s", reqIP, newLease.IP, mac)
|
||||
|
||||
resp.YourIPAddr = make([]byte, 4)
|
||||
copy(resp.YourIPAddr, newLease.IP)
|
||||
resp.YourIPAddr = net.IP(newLease.IP.AsSlice())
|
||||
|
||||
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck))
|
||||
|
||||
@@ -942,8 +965,15 @@ func (s *v4Server) handleRelease(req, resp *dhcpv4.DHCPv4) (err error) {
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
netIP, ok := netip.AddrFromSlice(reqIP)
|
||||
if !ok {
|
||||
log.Info("dhcpv4: invalid IP: %s", reqIP)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, l := range s.leases {
|
||||
if !bytes.Equal(l.HWAddr, mac) || !l.IP.Equal(reqIP) {
|
||||
if !bytes.Equal(l.HWAddr, mac) || l.IP != netIP {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1016,7 +1046,7 @@ func (s *v4Server) handle(req, resp *dhcpv4.DHCPv4) int {
|
||||
}
|
||||
|
||||
if l != nil {
|
||||
resp.YourIPAddr = slices.Clone(l.IP)
|
||||
resp.YourIPAddr = net.IP(l.IP.AsSlice())
|
||||
}
|
||||
|
||||
s.updateOptions(req, resp)
|
||||
|
||||
@@ -60,7 +60,7 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
staticIP := net.IP{192, 168, 10, 10}
|
||||
staticIP := netip.MustParseAddr("192.168.10.10")
|
||||
anotherIP := DefaultRangeStart
|
||||
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
|
||||
@@ -81,7 +81,7 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: anotherIP.AsSlice(),
|
||||
IP: anotherIP,
|
||||
})
|
||||
assert.ErrorIs(t, err, ErrDupHostname)
|
||||
})
|
||||
@@ -95,7 +95,7 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: staticMAC,
|
||||
IP: anotherIP.AsSlice(),
|
||||
IP: anotherIP,
|
||||
})
|
||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||
})
|
||||
@@ -122,13 +122,14 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
discoverAnOffer := func(
|
||||
t *testing.T,
|
||||
name string,
|
||||
ip net.IP,
|
||||
netIP netip.Addr,
|
||||
mac net.HardwareAddr,
|
||||
) (resp *dhcpv4.DHCPv4) {
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return s.ResetLeases(s.GetLeases(LeasesStatic))
|
||||
})
|
||||
|
||||
ip := net.IP(netIP.AsSlice())
|
||||
req, err := dhcpv4.NewDiscovery(
|
||||
mac,
|
||||
dhcpv4.WithOption(dhcpv4.OptHostName(name)),
|
||||
@@ -149,7 +150,7 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("same_name", func(t *testing.T) {
|
||||
resp := discoverAnOffer(t, staticName, anotherIP.AsSlice(), anotherMAC)
|
||||
resp := discoverAnOffer(t, staticName, anotherIP, anotherMAC)
|
||||
|
||||
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
|
||||
dhcpv4.OptHostName(staticName),
|
||||
@@ -159,11 +160,15 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
res := s4.handle(req, resp)
|
||||
require.Positive(t, res)
|
||||
|
||||
assert.Equal(t, aghnet.GenerateHostname(resp.YourIPAddr), resp.HostName())
|
||||
var netIP netip.Addr
|
||||
netIP, ok = netip.AddrFromSlice(resp.YourIPAddr)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, aghnet.GenerateHostname(netIP), resp.HostName())
|
||||
})
|
||||
|
||||
t.Run("same_mac", func(t *testing.T) {
|
||||
resp := discoverAnOffer(t, anotherName, anotherIP.AsSlice(), staticMAC)
|
||||
resp := discoverAnOffer(t, anotherName, anotherIP, staticMAC)
|
||||
|
||||
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
|
||||
dhcpv4.OptHostName(anotherName),
|
||||
@@ -177,7 +182,8 @@ func TestV4Server_leasing(t *testing.T) {
|
||||
require.Len(t, fqdnOptData, 3+len(staticName))
|
||||
assert.Equal(t, []uint8(staticName), fqdnOptData[3:])
|
||||
|
||||
assert.Equal(t, staticIP, resp.YourIPAddr)
|
||||
ip := net.IP(staticIP.AsSlice())
|
||||
assert.Equal(t, ip, resp.YourIPAddr)
|
||||
})
|
||||
|
||||
t.Run("same_ip", func(t *testing.T) {
|
||||
@@ -210,7 +216,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
lease: &Lease{
|
||||
Hostname: "success.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
},
|
||||
name: "success",
|
||||
wantErrMsg: "",
|
||||
@@ -218,7 +224,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
lease: &Lease{
|
||||
Hostname: "probably-router.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: DefaultGatewayIP.AsSlice(),
|
||||
IP: DefaultGatewayIP,
|
||||
},
|
||||
name: "with_gateway_ip",
|
||||
wantErrMsg: "dhcpv4: adding static lease: " +
|
||||
@@ -227,7 +233,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
lease: &Lease{
|
||||
Hostname: "ip6.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.ParseIP("ffff::1"),
|
||||
IP: netip.MustParseAddr("ffff::1"),
|
||||
},
|
||||
name: "ipv6",
|
||||
wantErrMsg: `dhcpv4: adding static lease: ` +
|
||||
@@ -236,7 +242,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
lease: &Lease{
|
||||
Hostname: "bad-mac.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
},
|
||||
name: "bad_mac",
|
||||
wantErrMsg: `dhcpv4: adding static lease: bad mac address "aa:aa": ` +
|
||||
@@ -245,7 +251,7 @@ func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
lease: &Lease{
|
||||
Hostname: "bad-lbl-.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
},
|
||||
name: "bad_hostname",
|
||||
wantErrMsg: `dhcpv4: adding static lease: validating hostname: ` +
|
||||
@@ -287,11 +293,11 @@ func TestV4_AddReplace(t *testing.T) {
|
||||
dynLeases := []Lease{{
|
||||
Hostname: "dynamic-1.local",
|
||||
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
}, {
|
||||
Hostname: "dynamic-2.local",
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 151},
|
||||
IP: netip.MustParseAddr("192.168.10.151"),
|
||||
}}
|
||||
|
||||
for i := range dynLeases {
|
||||
@@ -302,11 +308,11 @@ func TestV4_AddReplace(t *testing.T) {
|
||||
stLeases := []*Lease{{
|
||||
Hostname: "static-1.local",
|
||||
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
}, {
|
||||
Hostname: "static-2.local",
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 152},
|
||||
IP: netip.MustParseAddr("192.168.10.152"),
|
||||
}}
|
||||
|
||||
for _, l := range stLeases {
|
||||
@@ -318,7 +324,7 @@ func TestV4_AddReplace(t *testing.T) {
|
||||
require.Len(t, ls, 2)
|
||||
|
||||
for i, l := range ls {
|
||||
assert.True(t, stLeases[i].IP.Equal(l.IP))
|
||||
assert.Equal(t, stLeases[i].IP, l.IP)
|
||||
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
|
||||
assert.True(t, l.IsStatic())
|
||||
}
|
||||
@@ -511,7 +517,7 @@ func TestV4StaticLease_Get(t *testing.T) {
|
||||
l := &Lease{
|
||||
Hostname: "static-1.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
IP: netip.MustParseAddr("192.168.10.150"),
|
||||
}
|
||||
err := s.AddStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
@@ -537,7 +543,9 @@ func TestV4StaticLease_Get(t *testing.T) {
|
||||
t.Run("offer", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, l.IP.Equal(resp.YourIPAddr))
|
||||
|
||||
ip := net.IP(l.IP.AsSlice())
|
||||
assert.True(t, ip.Equal(resp.YourIPAddr))
|
||||
|
||||
assert.True(t, resp.Router()[0].Equal(s.conf.GatewayIP.AsSlice()))
|
||||
assert.True(t, resp.ServerIdentifier().Equal(s.conf.GatewayIP.AsSlice()))
|
||||
@@ -562,7 +570,9 @@ func TestV4StaticLease_Get(t *testing.T) {
|
||||
t.Run("ack", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, l.IP.Equal(resp.YourIPAddr))
|
||||
|
||||
ip := net.IP(l.IP.AsSlice())
|
||||
assert.True(t, ip.Equal(resp.YourIPAddr))
|
||||
|
||||
assert.True(t, resp.Router()[0].Equal(s.conf.GatewayIP.AsSlice()))
|
||||
assert.True(t, resp.ServerIdentifier().Equal(s.conf.GatewayIP.AsSlice()))
|
||||
@@ -581,7 +591,7 @@ func TestV4StaticLease_Get(t *testing.T) {
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
require.Len(t, ls, 1)
|
||||
|
||||
assert.True(t, l.IP.Equal(ls[0].IP))
|
||||
assert.Equal(t, l.IP, ls[0].IP)
|
||||
assert.Equal(t, mac, ls[0].HWAddr)
|
||||
})
|
||||
}
|
||||
@@ -679,7 +689,8 @@ func TestV4DynamicLease_Get(t *testing.T) {
|
||||
ls := s.GetLeases(LeasesDynamic)
|
||||
require.Len(t, ls, 1)
|
||||
|
||||
assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP))
|
||||
ip := netip.MustParseAddr("192.168.10.100")
|
||||
assert.Equal(t, ip, ls[0].IP)
|
||||
assert.Equal(t, mac, ls[0].HWAddr)
|
||||
})
|
||||
}
|
||||
@@ -860,3 +871,143 @@ func TestV4Server_Send(t *testing.T) {
|
||||
assert.True(t, resp.IsBroadcast())
|
||||
})
|
||||
}
|
||||
|
||||
func TestV4Server_FindMACbyIP(t *testing.T) {
|
||||
const (
|
||||
staticName = "static-client"
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
staticIP := netip.MustParseAddr("192.168.10.10")
|
||||
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
anotherIP := netip.MustParseAddr("192.168.100.100")
|
||||
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
|
||||
|
||||
s := &v4Server{
|
||||
leases: []*Lease{{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
}, {
|
||||
Expiry: time.Unix(10, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: anotherIP,
|
||||
}},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
want net.HardwareAddr
|
||||
ip netip.Addr
|
||||
name string
|
||||
}{{
|
||||
name: "basic",
|
||||
ip: staticIP,
|
||||
want: staticMAC,
|
||||
}, {
|
||||
name: "not_found",
|
||||
ip: netip.MustParseAddr("1.2.3.4"),
|
||||
want: nil,
|
||||
}, {
|
||||
name: "expired",
|
||||
ip: anotherIP,
|
||||
want: nil,
|
||||
}, {
|
||||
name: "v6",
|
||||
ip: netip.MustParseAddr("ffff::1"),
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mac := s.FindMACbyIP(tc.ip)
|
||||
|
||||
require.Equal(t, tc.want, mac)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestV4Server_handleDecline(t *testing.T) {
|
||||
const (
|
||||
dynamicName = "dynamic-client"
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
dynamicIP := netip.MustParseAddr("192.168.10.200")
|
||||
dynamicMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
s := defaultSrv(t)
|
||||
|
||||
s4, ok := s.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
s4.leases = []*Lease{{
|
||||
Hostname: dynamicName,
|
||||
HWAddr: dynamicMAC,
|
||||
IP: dynamicIP,
|
||||
}}
|
||||
|
||||
req, err := dhcpv4.New(
|
||||
dhcpv4.WithOption(dhcpv4.OptRequestedIPAddress(net.IP(dynamicIP.AsSlice()))),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.ClientIPAddr = net.IP(dynamicIP.AsSlice())
|
||||
req.ClientHWAddr = dynamicMAC
|
||||
|
||||
resp := &dhcpv4.DHCPv4{}
|
||||
err = s4.handleDecline(req, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
wantResp := &dhcpv4.DHCPv4{
|
||||
YourIPAddr: net.IP(s4.conf.RangeStart.AsSlice()),
|
||||
Options: dhcpv4.OptionsFromList(
|
||||
dhcpv4.OptMessageType(dhcpv4.MessageTypeAck),
|
||||
),
|
||||
}
|
||||
|
||||
require.Equal(t, wantResp, resp)
|
||||
}
|
||||
|
||||
func TestV4Server_handleRelease(t *testing.T) {
|
||||
const (
|
||||
dynamicName = "dymamic-client"
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
dynamicIP := netip.MustParseAddr("192.168.10.200")
|
||||
dynamicMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
s := defaultSrv(t)
|
||||
|
||||
s4, ok := s.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
s4.leases = []*Lease{{
|
||||
Hostname: dynamicName,
|
||||
HWAddr: dynamicMAC,
|
||||
IP: dynamicIP,
|
||||
}}
|
||||
|
||||
req, err := dhcpv4.New(
|
||||
dhcpv4.WithOption(dhcpv4.OptRequestedIPAddress(net.IP(dynamicIP.AsSlice()))),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.ClientIPAddr = net.IP(dynamicIP.AsSlice())
|
||||
req.ClientHWAddr = dynamicMAC
|
||||
|
||||
resp := &dhcpv4.DHCPv4{}
|
||||
err = s4.handleRelease(req, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
wantResp := &dhcpv4.DHCPv4{
|
||||
Options: dhcpv4.OptionsFromList(
|
||||
dhcpv4.OptMessageType(dhcpv4.MessageTypeAck),
|
||||
),
|
||||
}
|
||||
|
||||
require.Equal(t, wantResp, resp)
|
||||
}
|
||||
|
||||
@@ -61,13 +61,13 @@ func ip6InRange(start, ip net.IP) bool {
|
||||
|
||||
// ResetLeases resets leases.
|
||||
func (s *v6Server) ResetLeases(leases []*Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
|
||||
s.leases = nil
|
||||
for _, l := range leases {
|
||||
|
||||
ip := net.IP(l.IP.AsSlice())
|
||||
if l.Expiry.Unix() != leaseExpireStatic &&
|
||||
!ip6InRange(s.conf.ipStart, l.IP) {
|
||||
!ip6InRange(s.conf.ipStart, ip) {
|
||||
|
||||
log.Debug("dhcpv6: skipping a lease with IP %v: not within current IP range", l.IP)
|
||||
|
||||
@@ -119,9 +119,8 @@ func (s *v6Server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
netIP := ip.AsSlice()
|
||||
for _, l := range s.leases {
|
||||
if l.IP.Equal(netIP) {
|
||||
if l.IP == ip {
|
||||
if l.Expiry.After(now) || l.IsStatic() {
|
||||
return l.HWAddr
|
||||
}
|
||||
@@ -133,7 +132,8 @@ func (s *v6Server) FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr) {
|
||||
|
||||
// Remove (swap) lease by index
|
||||
func (s *v6Server) leaseRemoveSwapByIndex(i int) {
|
||||
s.ipAddrs[s.leases[i].IP[15]] = 0
|
||||
leaseIP := s.leases[i].IP.As16()
|
||||
s.ipAddrs[leaseIP[15]] = 0
|
||||
log.Debug("dhcpv6: removed lease %s", s.leases[i].HWAddr)
|
||||
|
||||
n := len(s.leases)
|
||||
@@ -162,7 +162,7 @@ func (s *v6Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
l = s.leases[i]
|
||||
}
|
||||
|
||||
if net.IP.Equal(l.IP, lease.IP) {
|
||||
if l.IP == lease.IP {
|
||||
if l.Expiry.Unix() == leaseExpireStatic {
|
||||
return fmt.Errorf("static lease already exists")
|
||||
}
|
||||
@@ -178,7 +178,7 @@ func (s *v6Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
func (s *v6Server) AddStaticLease(l *Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
|
||||
if len(l.IP) != net.IPv6len {
|
||||
if !l.IP.Is6() {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
|
||||
@@ -210,7 +210,7 @@ func (s *v6Server) AddStaticLease(l *Lease) (err error) {
|
||||
func (s *v6Server) RemoveStaticLease(l *Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
|
||||
if len(l.IP) != 16 {
|
||||
if !l.IP.Is6() {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
|
||||
@@ -234,14 +234,15 @@ func (s *v6Server) RemoveStaticLease(l *Lease) (err error) {
|
||||
// Add a lease
|
||||
func (s *v6Server) addLease(l *Lease) {
|
||||
s.leases = append(s.leases, l)
|
||||
s.ipAddrs[l.IP[15]] = 1
|
||||
ip := l.IP.As16()
|
||||
s.ipAddrs[ip[15]] = 1
|
||||
log.Debug("dhcpv6: added lease %s <-> %s", l.IP, l.HWAddr)
|
||||
}
|
||||
|
||||
// Remove a lease with the same properties
|
||||
func (s *v6Server) rmLease(lease *Lease) (err error) {
|
||||
for i, l := range s.leases {
|
||||
if net.IP.Equal(l.IP, lease.IP) {
|
||||
if l.IP == lease.IP {
|
||||
if !bytes.Equal(l.HWAddr, lease.HWAddr) ||
|
||||
l.Hostname != lease.Hostname {
|
||||
return fmt.Errorf("lease not found")
|
||||
@@ -308,18 +309,27 @@ func (s *v6Server) reserveLease(mac net.HardwareAddr) *Lease {
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
copy(l.IP, s.conf.ipStart)
|
||||
l.IP = s.findFreeIP()
|
||||
if l.IP == nil {
|
||||
ip := s.findFreeIP()
|
||||
if ip == nil {
|
||||
i := s.findExpiredLease()
|
||||
if i < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
copy(s.leases[i].HWAddr, mac)
|
||||
|
||||
return s.leases[i]
|
||||
}
|
||||
|
||||
netIP, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
l.IP = netIP
|
||||
|
||||
s.addLease(&l)
|
||||
|
||||
return &l
|
||||
}
|
||||
|
||||
@@ -388,7 +398,8 @@ func (s *v6Server) checkIA(msg *dhcpv6.Message, lease *Lease) error {
|
||||
return fmt.Errorf("no IANA.Addr option in %s", msg.Type().String())
|
||||
}
|
||||
|
||||
if !oiaAddr.IPv6Addr.Equal(lease.IP) {
|
||||
leaseIP := net.IP(lease.IP.AsSlice())
|
||||
if !oiaAddr.IPv6Addr.Equal(leaseIP) {
|
||||
return fmt.Errorf("invalid IANA.Addr option in %s", msg.Type().String())
|
||||
}
|
||||
}
|
||||
@@ -475,7 +486,7 @@ func (s *v6Server) process(msg *dhcpv6.Message, req, resp dhcpv6.DHCPv6) bool {
|
||||
copy(oia.IaId[:], []byte(valueIAID))
|
||||
}
|
||||
oiaAddr := &dhcpv6.OptIAAddress{
|
||||
IPv6Addr: lease.IP,
|
||||
IPv6Addr: net.IP(lease.IP.AsSlice()),
|
||||
PreferredLifetime: lifetime,
|
||||
ValidLifetime: lifetime,
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/iana"
|
||||
@@ -27,7 +29,7 @@ func TestV6_AddRemove_static(t *testing.T) {
|
||||
|
||||
// Add static lease.
|
||||
l := &Lease{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
IP: netip.MustParseAddr("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}
|
||||
err = s.AddStaticLease(l)
|
||||
@@ -46,7 +48,7 @@ func TestV6_AddRemove_static(t *testing.T) {
|
||||
|
||||
// Try to remove non-existent static lease.
|
||||
err = s.RemoveStaticLease(&Lease{
|
||||
IP: net.ParseIP("2001::2"),
|
||||
IP: netip.MustParseAddr("2001::2"),
|
||||
HWAddr: l.HWAddr,
|
||||
})
|
||||
require.Error(t, err)
|
||||
@@ -71,10 +73,10 @@ func TestV6_AddReplace(t *testing.T) {
|
||||
|
||||
// Add dynamic leases.
|
||||
dynLeases := []*Lease{{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
IP: netip.MustParseAddr("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}, {
|
||||
IP: net.ParseIP("2001::2"),
|
||||
IP: netip.MustParseAddr("2001::2"),
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}}
|
||||
|
||||
@@ -83,10 +85,10 @@ func TestV6_AddReplace(t *testing.T) {
|
||||
}
|
||||
|
||||
stLeases := []*Lease{{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
IP: netip.MustParseAddr("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}, {
|
||||
IP: net.ParseIP("2001::3"),
|
||||
IP: netip.MustParseAddr("2001::3"),
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}}
|
||||
|
||||
@@ -99,7 +101,7 @@ func TestV6_AddReplace(t *testing.T) {
|
||||
require.Len(t, ls, 2)
|
||||
|
||||
for i, l := range ls {
|
||||
assert.True(t, stLeases[i].IP.Equal(l.IP))
|
||||
assert.Equal(t, stLeases[i].IP, l.IP)
|
||||
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
|
||||
assert.EqualValues(t, leaseExpireStatic, l.Expiry.Unix())
|
||||
}
|
||||
@@ -126,7 +128,7 @@ func TestV6GetLease(t *testing.T) {
|
||||
}
|
||||
|
||||
l := &Lease{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
IP: netip.MustParseAddr("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}
|
||||
err = s.AddStaticLease(l)
|
||||
@@ -158,7 +160,8 @@ func TestV6GetLease(t *testing.T) {
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
|
||||
assert.Equal(t, l.IP, oiaAddr.IPv6Addr)
|
||||
ip := net.IP(l.IP.AsSlice())
|
||||
assert.Equal(t, ip, oiaAddr.IPv6Addr)
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
|
||||
})
|
||||
|
||||
@@ -182,7 +185,8 @@ func TestV6GetLease(t *testing.T) {
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
|
||||
assert.Equal(t, l.IP, oiaAddr.IPv6Addr)
|
||||
ip := net.IP(l.IP.AsSlice())
|
||||
assert.Equal(t, ip, oiaAddr.IPv6Addr)
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
|
||||
})
|
||||
|
||||
@@ -308,3 +312,72 @@ func TestIP6InRange(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestV6_FindMACbyIP(t *testing.T) {
|
||||
const (
|
||||
staticName = "static-client"
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
staticIP := netip.MustParseAddr("2001::1")
|
||||
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
anotherIP := netip.MustParseAddr("2001::100")
|
||||
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
|
||||
|
||||
s := &v6Server{
|
||||
leases: []*Lease{{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
}, {
|
||||
Expiry: time.Unix(10, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: anotherIP,
|
||||
}},
|
||||
}
|
||||
|
||||
s.leases = []*Lease{{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
}, {
|
||||
Expiry: time.Unix(10, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: anotherIP,
|
||||
}}
|
||||
|
||||
testCases := []struct {
|
||||
want net.HardwareAddr
|
||||
ip netip.Addr
|
||||
name string
|
||||
}{{
|
||||
name: "basic",
|
||||
ip: staticIP,
|
||||
want: staticMAC,
|
||||
}, {
|
||||
name: "not_found",
|
||||
ip: netip.MustParseAddr("ffff::1"),
|
||||
want: nil,
|
||||
}, {
|
||||
name: "expired",
|
||||
ip: anotherIP,
|
||||
want: nil,
|
||||
}, {
|
||||
name: "v4",
|
||||
ip: netip.MustParseAddr("1.2.3.4"),
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mac := s.FindMACbyIP(tc.ip)
|
||||
|
||||
require.Equal(t, tc.want, mac)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -200,7 +200,7 @@ type FilteringConfig struct {
|
||||
// EDNSClientSubnet is the settings list for EDNS Client Subnet.
|
||||
type EDNSClientSubnet struct {
|
||||
// CustomIP for EDNS Client Subnet.
|
||||
CustomIP string `yaml:"custom_ip"`
|
||||
CustomIP netip.Addr `yaml:"custom_ip"`
|
||||
|
||||
// Enabled defines if EDNS Client Subnet is enabled.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
@@ -340,15 +340,8 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
}
|
||||
|
||||
if srvConf.EDNSClientSubnet.UseCustom {
|
||||
// TODO(s.chzhen): Add wrapper around netip.Addr.
|
||||
var ip net.IP
|
||||
ip, err = netutil.ParseIP(srvConf.EDNSClientSubnet.CustomIP)
|
||||
if err != nil {
|
||||
return conf, fmt.Errorf("edns: %w", err)
|
||||
}
|
||||
|
||||
// TODO(s.chzhen): Use netip.Addr instead of net.IP inside dnsproxy.
|
||||
conf.EDNSAddr = ip
|
||||
conf.EDNSAddr = net.IP(srvConf.EDNSClientSubnet.CustomIP.AsSlice())
|
||||
}
|
||||
|
||||
if srvConf.CacheSize != 0 {
|
||||
@@ -377,7 +370,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
|
||||
err = s.prepareTLS(&conf)
|
||||
if err != nil {
|
||||
return conf, fmt.Errorf("validating tls: %w", err)
|
||||
return proxy.Config{}, fmt.Errorf("validating tls: %w", err)
|
||||
}
|
||||
|
||||
if c := srvConf.DNSCryptConfig; c.Enabled {
|
||||
@@ -388,7 +381,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
}
|
||||
|
||||
if conf.UpstreamConfig == nil || len(conf.UpstreamConfig.Upstreams) == 0 {
|
||||
return conf, errors.Error("no default upstream servers configured")
|
||||
return proxy.Config{}, errors.Error("no default upstream servers configured")
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
|
||||
@@ -243,17 +243,16 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
|
||||
|
||||
// Assume that we only process IPv4 now.
|
||||
//
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ip, err := netutil.IPToAddr(l.IP, netutil.AddrFamilyIPv4)
|
||||
if err != nil {
|
||||
log.Debug("dnsforward: skipping invalid ip %v from dhcp: %s", l.IP, err)
|
||||
if !l.IP.Is4() {
|
||||
log.Debug("dnsforward: skipping invalid ip from dhcp: bad ipv4 net.IP %v", l.IP)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ipToHost[ip] = lowhost
|
||||
hostToIP[lowhost] = ip
|
||||
leaseIP := l.IP
|
||||
|
||||
ipToHost[leaseIP] = lowhost
|
||||
hostToIP[lowhost] = leaseIP
|
||||
}
|
||||
|
||||
s.setTableHostToIP(hostToIP)
|
||||
|
||||
@@ -1073,7 +1073,7 @@ var testDHCP = &dhcpd.MockInterface{
|
||||
OnEnabled: func() (ok bool) { return true },
|
||||
OnLeases: func(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
||||
return []*dhcpd.Lease{{
|
||||
IP: net.IP{192, 168, 12, 34},
|
||||
IP: netip.MustParseAddr("192.168.12.34"),
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
Hostname: "myhost",
|
||||
}}
|
||||
|
||||
@@ -23,26 +23,78 @@ import (
|
||||
)
|
||||
|
||||
// jsonDNSConfig is the JSON representation of the DNS server configuration.
|
||||
//
|
||||
// TODO(s.chzhen): Split it into smaller pieces. Use aghalg.NullBool instead
|
||||
// of *bool.
|
||||
type jsonDNSConfig struct {
|
||||
Upstreams *[]string `json:"upstream_dns"`
|
||||
UpstreamsFile *string `json:"upstream_dns_file"`
|
||||
Bootstraps *[]string `json:"bootstrap_dns"`
|
||||
ProtectionEnabled *bool `json:"protection_enabled"`
|
||||
RateLimit *uint32 `json:"ratelimit"`
|
||||
BlockingMode *BlockingMode `json:"blocking_mode"`
|
||||
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
|
||||
DNSSECEnabled *bool `json:"dnssec_enabled"`
|
||||
DisableIPv6 *bool `json:"disable_ipv6"`
|
||||
UpstreamMode *string `json:"upstream_mode"`
|
||||
CacheSize *uint32 `json:"cache_size"`
|
||||
CacheMinTTL *uint32 `json:"cache_ttl_min"`
|
||||
CacheMaxTTL *uint32 `json:"cache_ttl_max"`
|
||||
CacheOptimistic *bool `json:"cache_optimistic"`
|
||||
ResolveClients *bool `json:"resolve_clients"`
|
||||
UsePrivateRDNS *bool `json:"use_private_ptr_resolvers"`
|
||||
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
|
||||
BlockingIPv4 net.IP `json:"blocking_ipv4"`
|
||||
BlockingIPv6 net.IP `json:"blocking_ipv6"`
|
||||
// Upstreams is the list of upstream DNS servers.
|
||||
Upstreams *[]string `json:"upstream_dns"`
|
||||
|
||||
// UpstreamsFile is the file containing upstream DNS servers.
|
||||
UpstreamsFile *string `json:"upstream_dns_file"`
|
||||
|
||||
// Bootstraps is the list of DNS servers resolving IP addresses of the
|
||||
// upstream DoH/DoT resolvers.
|
||||
Bootstraps *[]string `json:"bootstrap_dns"`
|
||||
|
||||
// ProtectionEnabled defines if protection is enabled.
|
||||
ProtectionEnabled *bool `json:"protection_enabled"`
|
||||
|
||||
// RateLimit is the number of requests per second allowed per client.
|
||||
RateLimit *uint32 `json:"ratelimit"`
|
||||
|
||||
// BlockingMode defines the way blocked responses are constructed.
|
||||
BlockingMode *BlockingMode `json:"blocking_mode"`
|
||||
|
||||
// EDNSCSEnabled defines if EDNS Client Subnet is enabled.
|
||||
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
|
||||
|
||||
// EDNSCSUseCustom defines if EDNSCSCustomIP should be used.
|
||||
EDNSCSUseCustom *bool `json:"edns_cs_use_custom"`
|
||||
|
||||
// DNSSECEnabled defines if DNSSEC is enabled.
|
||||
DNSSECEnabled *bool `json:"dnssec_enabled"`
|
||||
|
||||
// DisableIPv6 defines if IPv6 addresses should be dropped.
|
||||
DisableIPv6 *bool `json:"disable_ipv6"`
|
||||
|
||||
// UpstreamMode defines the way DNS requests are constructed.
|
||||
UpstreamMode *string `json:"upstream_mode"`
|
||||
|
||||
// CacheSize in bytes.
|
||||
CacheSize *uint32 `json:"cache_size"`
|
||||
|
||||
// CacheMinTTL is custom minimum TTL for cached DNS responses.
|
||||
CacheMinTTL *uint32 `json:"cache_ttl_min"`
|
||||
|
||||
// CacheMaxTTL is custom maximum TTL for cached DNS responses.
|
||||
CacheMaxTTL *uint32 `json:"cache_ttl_max"`
|
||||
|
||||
// CacheOptimistic defines if expired entries should be served.
|
||||
CacheOptimistic *bool `json:"cache_optimistic"`
|
||||
|
||||
// ResolveClients defines if clients IPs should be resolved into hostnames.
|
||||
ResolveClients *bool `json:"resolve_clients"`
|
||||
|
||||
// UsePrivateRDNS defines if privates DNS resolvers should be used.
|
||||
UsePrivateRDNS *bool `json:"use_private_ptr_resolvers"`
|
||||
|
||||
// LocalPTRUpstreams is the list of local private DNS resolvers.
|
||||
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
|
||||
|
||||
// BlockingIPv4 is custom IPv4 address for blocked A requests.
|
||||
BlockingIPv4 net.IP `json:"blocking_ipv4"`
|
||||
|
||||
// BlockingIPv6 is custom IPv6 address for blocked AAAA requests.
|
||||
BlockingIPv6 net.IP `json:"blocking_ipv6"`
|
||||
|
||||
// EDNSCSCustomIP is custom IP for EDNS Client Subnet.
|
||||
EDNSCSCustomIP netip.Addr `json:"edns_cs_custom_ip"`
|
||||
|
||||
// DefaultLocalPTRUpstreams is used to pass the addresses from
|
||||
// systemResolvers to the front-end. It's not a pointer to the slice since
|
||||
// there is no need to omit it while decoding from JSON.
|
||||
DefaultLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
@@ -57,7 +109,11 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
blockingIPv4 := s.conf.BlockingIPv4
|
||||
blockingIPv6 := s.conf.BlockingIPv6
|
||||
ratelimit := s.conf.Ratelimit
|
||||
|
||||
customIP := s.conf.EDNSClientSubnet.CustomIP
|
||||
enableEDNSClientSubnet := s.conf.EDNSClientSubnet.Enabled
|
||||
useCustom := s.conf.EDNSClientSubnet.UseCustom
|
||||
|
||||
enableDNSSEC := s.conf.EnableDNSSEC
|
||||
aaaaDisabled := s.conf.AAAADisabled
|
||||
cacheSize := s.conf.CacheSize
|
||||
@@ -74,46 +130,40 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
upstreamMode = "parallel"
|
||||
}
|
||||
|
||||
return &jsonDNSConfig{
|
||||
Upstreams: &upstreams,
|
||||
UpstreamsFile: &upstreamFile,
|
||||
Bootstraps: &bootstraps,
|
||||
ProtectionEnabled: &protectionEnabled,
|
||||
BlockingMode: &blockingMode,
|
||||
BlockingIPv4: blockingIPv4,
|
||||
BlockingIPv6: blockingIPv6,
|
||||
RateLimit: &ratelimit,
|
||||
EDNSCSEnabled: &enableEDNSClientSubnet,
|
||||
DNSSECEnabled: &enableDNSSEC,
|
||||
DisableIPv6: &aaaaDisabled,
|
||||
CacheSize: &cacheSize,
|
||||
CacheMinTTL: &cacheMinTTL,
|
||||
CacheMaxTTL: &cacheMaxTTL,
|
||||
CacheOptimistic: &cacheOptimistic,
|
||||
UpstreamMode: &upstreamMode,
|
||||
ResolveClients: &resolveClients,
|
||||
UsePrivateRDNS: &usePrivateRDNS,
|
||||
LocalPTRUpstreams: &localPTRUpstreams,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
defLocalPTRUps, err := s.filterOurDNSAddrs(s.sysResolvers.Get())
|
||||
if err != nil {
|
||||
log.Debug("getting dns configuration: %s", err)
|
||||
}
|
||||
|
||||
resp := struct {
|
||||
jsonDNSConfig
|
||||
// DefautLocalPTRUpstreams is used to pass the addresses from
|
||||
// systemResolvers to the front-end. It's not a pointer to the slice
|
||||
// since there is no need to omit it while decoding from JSON.
|
||||
DefautLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"`
|
||||
}{
|
||||
jsonDNSConfig: *s.getDNSConfig(),
|
||||
DefautLocalPTRUpstreams: defLocalPTRUps,
|
||||
return &jsonDNSConfig{
|
||||
Upstreams: &upstreams,
|
||||
UpstreamsFile: &upstreamFile,
|
||||
Bootstraps: &bootstraps,
|
||||
ProtectionEnabled: &protectionEnabled,
|
||||
BlockingMode: &blockingMode,
|
||||
BlockingIPv4: blockingIPv4,
|
||||
BlockingIPv6: blockingIPv6,
|
||||
RateLimit: &ratelimit,
|
||||
EDNSCSCustomIP: customIP,
|
||||
EDNSCSEnabled: &enableEDNSClientSubnet,
|
||||
EDNSCSUseCustom: &useCustom,
|
||||
DNSSECEnabled: &enableDNSSEC,
|
||||
DisableIPv6: &aaaaDisabled,
|
||||
CacheSize: &cacheSize,
|
||||
CacheMinTTL: &cacheMinTTL,
|
||||
CacheMaxTTL: &cacheMaxTTL,
|
||||
CacheOptimistic: &cacheOptimistic,
|
||||
UpstreamMode: &upstreamMode,
|
||||
ResolveClients: &resolveClients,
|
||||
UsePrivateRDNS: &usePrivateRDNS,
|
||||
LocalPTRUpstreams: &localPTRUpstreams,
|
||||
DefaultLocalPTRUpstreams: defLocalPTRUps,
|
||||
}
|
||||
}
|
||||
|
||||
// handleGetConfig handles requests to the GET /control/dns_info endpoint.
|
||||
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
resp := s.getDNSConfig()
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
@@ -204,6 +254,7 @@ func (req *jsonDNSConfig) checkCacheTTL() bool {
|
||||
return min <= max
|
||||
}
|
||||
|
||||
// handleSetConfig handles requests to the POST /control/dns_config endpoint.
|
||||
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
req := &jsonDNSConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
@@ -231,8 +282,8 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// setConfigRestartable sets the server parameters. shouldRestart is true if
|
||||
// the server should be restarted to apply changes.
|
||||
// setConfig sets the server parameters. shouldRestart is true if the server
|
||||
// should be restarted to apply changes.
|
||||
func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
@@ -250,6 +301,10 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
s.conf.FastestAddr = *dc.UpstreamMode == "fastest_addr"
|
||||
}
|
||||
|
||||
if dc.EDNSCSUseCustom != nil && *dc.EDNSCSUseCustom {
|
||||
s.conf.EDNSClientSubnet.CustomIP = dc.EDNSCSCustomIP
|
||||
}
|
||||
|
||||
setIfNotNil(&s.conf.ProtectionEnabled, dc.ProtectionEnabled)
|
||||
setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled)
|
||||
setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6)
|
||||
@@ -281,6 +336,7 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
setIfNotNil(&s.conf.UpstreamDNSFileName, dc.UpstreamsFile),
|
||||
setIfNotNil(&s.conf.BootstrapDNS, dc.Bootstraps),
|
||||
setIfNotNil(&s.conf.EDNSClientSubnet.Enabled, dc.EDNSCSEnabled),
|
||||
setIfNotNil(&s.conf.EDNSClientSubnet.UseCustom, dc.EDNSCSUseCustom),
|
||||
setIfNotNil(&s.conf.CacheSize, dc.CacheSize),
|
||||
setIfNotNil(&s.conf.CacheMinTTL, dc.CacheMinTTL),
|
||||
setIfNotNil(&s.conf.CacheMaxTTL, dc.CacheMaxTTL),
|
||||
|
||||
@@ -181,6 +181,12 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
}, {
|
||||
name: "edns_cs_enabled",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "edns_cs_use_custom",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "edns_cs_use_custom_bad_ip",
|
||||
wantSet: "decoding request: ParseAddr(\"bad.ip\"): unexpected character (at \"bad.ip\")",
|
||||
}, {
|
||||
name: "dnssec_enabled",
|
||||
wantSet: "",
|
||||
@@ -222,16 +228,20 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
Req json.RawMessage `json:"req"`
|
||||
Want json.RawMessage `json:"want"`
|
||||
}
|
||||
loadTestData(t, t.Name()+jsonExt, &data)
|
||||
|
||||
testData := t.Name() + jsonExt
|
||||
loadTestData(t, testData, &data)
|
||||
|
||||
for _, tc := range testCases {
|
||||
// NOTE: Do not use require.Contains, because the size of the data
|
||||
// prevents it from printing a meaningful error message.
|
||||
caseData, ok := data[tc.name]
|
||||
require.True(t, ok)
|
||||
require.Truef(t, ok, "%q does not contain test data for test case %s", testData, tc.name)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
s.conf = defaultConf
|
||||
s.conf.FilteringConfig.EDNSClientSubnet.Enabled = false
|
||||
s.conf.FilteringConfig.EDNSClientSubnet = &EDNSClientSubnet{}
|
||||
})
|
||||
|
||||
rBody := io.NopCloser(bytes.NewReader(caseData.Req))
|
||||
|
||||
@@ -26,7 +26,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
},
|
||||
"fastest_addr": {
|
||||
"upstream_dns": [
|
||||
@@ -55,7 +57,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
},
|
||||
"parallel": {
|
||||
"upstream_dns": [
|
||||
@@ -84,6 +88,8 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"bootstraps": {
|
||||
@@ -66,7 +68,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"blocking_mode_good": {
|
||||
@@ -100,7 +104,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"blocking_mode_bad": {
|
||||
@@ -134,7 +140,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"ratelimit": {
|
||||
@@ -168,7 +176,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"edns_cs_enabled": {
|
||||
@@ -202,7 +212,85 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"edns_cs_use_custom": {
|
||||
"req": {
|
||||
"edns_cs_enabled": true,
|
||||
"edns_cs_use_custom": true,
|
||||
"edns_cs_custom_ip": "1.2.3.4"
|
||||
},
|
||||
"want": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:53",
|
||||
"8.8.4.4:53"
|
||||
],
|
||||
"upstream_dns_file": "",
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10",
|
||||
"149.112.112.10",
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"edns_cs_enabled": true,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
"upstream_mode": "",
|
||||
"cache_size": 0,
|
||||
"cache_ttl_min": 0,
|
||||
"cache_ttl_max": 0,
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": true,
|
||||
"edns_cs_custom_ip": "1.2.3.4"
|
||||
}
|
||||
},
|
||||
"edns_cs_use_custom_bad_ip": {
|
||||
"req": {
|
||||
"edns_cs_enabled": true,
|
||||
"edns_cs_use_custom": true,
|
||||
"edns_cs_custom_ip": "bad.ip"
|
||||
},
|
||||
"want": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:53",
|
||||
"8.8.4.4:53"
|
||||
],
|
||||
"upstream_dns_file": "",
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10",
|
||||
"149.112.112.10",
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"protection_enabled": true,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
"upstream_mode": "",
|
||||
"cache_size": 0,
|
||||
"cache_ttl_min": 0,
|
||||
"cache_ttl_max": 0,
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"dnssec_enabled": {
|
||||
@@ -236,7 +324,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"cache_size": {
|
||||
@@ -270,7 +360,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"upstream_mode_parallel": {
|
||||
@@ -304,7 +396,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"upstream_mode_fastest_addr": {
|
||||
@@ -338,7 +432,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"upstream_dns_bad": {
|
||||
@@ -374,7 +470,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"bootstraps_bad": {
|
||||
@@ -410,7 +508,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"cache_bad_ttl": {
|
||||
@@ -445,7 +545,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"upstream_mode_bad": {
|
||||
@@ -479,7 +581,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"local_ptr_upstreams_good": {
|
||||
@@ -517,7 +621,9 @@
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": [
|
||||
"123.123.123.123"
|
||||
]
|
||||
],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"local_ptr_upstreams_bad": {
|
||||
@@ -554,7 +660,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"local_ptr_upstreams_null": {
|
||||
@@ -588,7 +696,9 @@
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": []
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -461,6 +461,7 @@ func (d *DNSFilter) RegisterFilteringHandlers() {
|
||||
registerHTTP(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
|
||||
registerHTTP(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
|
||||
registerHTTP(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
|
||||
registerHTTP(http.MethodPut, "/control/safesearch/settings", d.handleSafeSearchSettings)
|
||||
|
||||
registerHTTP(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
|
||||
registerHTTP(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
|
||||
|
||||
@@ -17,7 +17,7 @@ type SafeSearch interface {
|
||||
// SafeSearchConfig is a struct with safe search related settings.
|
||||
type SafeSearchConfig struct {
|
||||
// CustomResolver is the resolver used by safe search.
|
||||
CustomResolver Resolver `yaml:"-"`
|
||||
CustomResolver Resolver `yaml:"-" json:"-"`
|
||||
|
||||
// Enabled indicates if safe search is enabled entirely.
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
|
||||
@@ -1,29 +1,63 @@
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
)
|
||||
|
||||
// TODO(d.kolyshev): Replace handlers below with the new API.
|
||||
|
||||
// handleSafeSearchEnable is the handler for POST /control/safesearch/enable
|
||||
// HTTP API.
|
||||
//
|
||||
// Deprecated: Use handleSafeSearchSettings.
|
||||
func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(&d.confLock, &d.Config.SafeSearchConf.Enabled, true)
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
|
||||
// handleSafeSearchDisable is the handler for POST /control/safesearch/disable
|
||||
// HTTP API.
|
||||
//
|
||||
// Deprecated: Use handleSafeSearchSettings.
|
||||
func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(&d.confLock, &d.Config.SafeSearchConf.Enabled, false)
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
|
||||
// handleSafeSearchStatus is the handler for GET /control/safesearch/status
|
||||
// HTTP API.
|
||||
func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
|
||||
resp := &struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}{
|
||||
Enabled: protectedBool(&d.confLock, &d.Config.SafeSearchConf.Enabled),
|
||||
}
|
||||
var resp SafeSearchConfig
|
||||
func() {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
resp = d.Config.SafeSearchConf
|
||||
}()
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// handleSafeSearchSettings is the handler for PUT /control/safesearch/settings
|
||||
// HTTP API.
|
||||
func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Request) {
|
||||
req := &SafeSearchConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func() {
|
||||
d.confLock.Lock()
|
||||
defer d.confLock.Unlock()
|
||||
|
||||
d.Config.SafeSearchConf = *req
|
||||
}()
|
||||
|
||||
d.Config.ConfigModified()
|
||||
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
@@ -859,15 +858,7 @@ func (clients *clientsContainer) updateFromDHCP(add bool) {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ipAddr, err := netutil.IPToAddrNoMapped(l.IP)
|
||||
if err != nil {
|
||||
log.Error("clients: bad client ip %v from dhcp: %s", l.IP, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ok := clients.addHostLocked(ipAddr, l.Hostname, ClientSourceDHCP)
|
||||
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
|
||||
@@ -275,7 +275,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
t.Skip("skipping dhcp test on windows")
|
||||
}
|
||||
|
||||
ip := net.IP{1, 2, 3, 4}
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
// First, init a DHCP server with a single static lease.
|
||||
config := &dhcpd.ServerConfig{
|
||||
|
||||
@@ -27,7 +27,8 @@ type clientJSON struct {
|
||||
// the allowlist.
|
||||
DisallowedRule *string `json:"disallowed_rule,omitempty"`
|
||||
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info,omitempty"`
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info,omitempty"`
|
||||
SafeSearchConf *filtering.SafeSearchConfig `json:"safe_search"`
|
||||
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -59,7 +60,7 @@ type clientListJSON struct {
|
||||
Tags []string `json:"supported_tags"`
|
||||
}
|
||||
|
||||
// respond with information about configured clients
|
||||
// handleGetClients is the handler for GET /control/clients HTTP API.
|
||||
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http.Request) {
|
||||
data := clientListJSON{}
|
||||
|
||||
@@ -88,32 +89,36 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
}
|
||||
|
||||
// Convert JSON object to Client object
|
||||
// jsonToClient converts JSON object to Client object.
|
||||
func jsonToClient(cj clientJSON) (c *Client) {
|
||||
// TODO(d.kolyshev): Remove after cleaning the deprecated
|
||||
// [clientJSON.SafeSearchEnabled] field.
|
||||
safeSearchConf := filtering.SafeSearchConfig{Enabled: cj.SafeSearchEnabled}
|
||||
var safeSearchConf filtering.SafeSearchConfig
|
||||
if cj.SafeSearchConf != nil {
|
||||
safeSearchConf = *cj.SafeSearchConf
|
||||
} else {
|
||||
// TODO(d.kolyshev): Remove after cleaning the deprecated
|
||||
// [clientJSON.SafeSearchEnabled] field.
|
||||
safeSearchConf = filtering.SafeSearchConfig{Enabled: cj.SafeSearchEnabled}
|
||||
|
||||
// Set default service flags for enabled safesearch.
|
||||
if safeSearchConf.Enabled {
|
||||
safeSearchConf.Bing = true
|
||||
safeSearchConf.DuckDuckGo = true
|
||||
safeSearchConf.Google = true
|
||||
safeSearchConf.Pixabay = true
|
||||
safeSearchConf.Yandex = true
|
||||
safeSearchConf.YouTube = true
|
||||
// Set default service flags for enabled safesearch.
|
||||
if safeSearchConf.Enabled {
|
||||
safeSearchConf.Bing = true
|
||||
safeSearchConf.DuckDuckGo = true
|
||||
safeSearchConf.Google = true
|
||||
safeSearchConf.Pixabay = true
|
||||
safeSearchConf.Yandex = true
|
||||
safeSearchConf.YouTube = true
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
Name: cj.Name,
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
UseOwnSettings: !cj.UseGlobalSettings,
|
||||
FilteringEnabled: cj.FilteringEnabled,
|
||||
ParentalEnabled: cj.ParentalEnabled,
|
||||
safeSearchConf: safeSearchConf,
|
||||
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
|
||||
|
||||
Name: cj.Name,
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
UseOwnSettings: !cj.UseGlobalSettings,
|
||||
FilteringEnabled: cj.FilteringEnabled,
|
||||
ParentalEnabled: cj.ParentalEnabled,
|
||||
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
|
||||
safeSearchConf: safeSearchConf,
|
||||
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
|
||||
BlockedServices: cj.BlockedServices,
|
||||
|
||||
@@ -121,7 +126,7 @@ func jsonToClient(cj clientJSON) (c *Client) {
|
||||
}
|
||||
}
|
||||
|
||||
// Convert Client object to JSON
|
||||
// clientToJSON converts Client object to JSON.
|
||||
func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
// TODO(d.kolyshev): Remove after cleaning the deprecated
|
||||
// [clientJSON.SafeSearchEnabled] field.
|
||||
@@ -136,6 +141,7 @@ func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
FilteringEnabled: c.FilteringEnabled,
|
||||
ParentalEnabled: c.ParentalEnabled,
|
||||
SafeSearchEnabled: safeSearchConf.Enabled,
|
||||
SafeSearchConf: safeSearchConf,
|
||||
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
|
||||
|
||||
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
|
||||
@@ -145,7 +151,7 @@ func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
}
|
||||
}
|
||||
|
||||
// Add a new client
|
||||
// handleAddClient is the handler for POST /control/clients/add HTTP API.
|
||||
func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) {
|
||||
cj := clientJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&cj)
|
||||
@@ -172,7 +178,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
// Remove client
|
||||
// handleDelClient is the handler for POST /control/clients/delete HTTP API.
|
||||
func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) {
|
||||
cj := clientJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&cj)
|
||||
@@ -202,7 +208,7 @@ type updateJSON struct {
|
||||
Data clientJSON `json:"data"`
|
||||
}
|
||||
|
||||
// Update client's properties
|
||||
// handleUpdateClient is the handler for POST /control/clients/update HTTP API.
|
||||
func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) {
|
||||
dj := updateJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&dj)
|
||||
@@ -229,7 +235,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
// Get the list of clients by IP address list
|
||||
// handleFindClient is the handler for GET /control/clients/find HTTP API.
|
||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
data := []map[string]*clientJSON{}
|
||||
|
||||
@@ -228,34 +228,32 @@ type tlsConfigSettings struct {
|
||||
}
|
||||
|
||||
type queryLogConfig struct {
|
||||
// Ignored is the list of host names, which should not be written to log.
|
||||
Ignored []string `yaml:"ignored"`
|
||||
|
||||
// Interval is the interval for query log's files rotation.
|
||||
Interval timeutil.Duration `yaml:"interval"`
|
||||
|
||||
// MemSize is the number of entries kept in memory before they are flushed
|
||||
// to disk.
|
||||
MemSize uint32 `yaml:"size_memory"`
|
||||
|
||||
// Enabled defines if the query log is enabled.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// FileEnabled defines, if the query log is written to the file.
|
||||
FileEnabled bool `yaml:"file_enabled"`
|
||||
|
||||
// Interval is the interval for query log's files rotation.
|
||||
Interval timeutil.Duration `yaml:"interval"`
|
||||
|
||||
// MemSize is the number of entries kept in memory before they are
|
||||
// flushed to disk.
|
||||
MemSize uint32 `yaml:"size_memory"`
|
||||
|
||||
// Ignored is the list of host names, which should not be written to
|
||||
// log.
|
||||
Ignored []string `yaml:"ignored"`
|
||||
}
|
||||
|
||||
type statsConfig struct {
|
||||
// Enabled defines if the statistics are enabled.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// Interval is the time interval for flushing statistics to the disk in
|
||||
// days.
|
||||
Interval uint32 `yaml:"interval"`
|
||||
|
||||
// Ignored is the list of host names, which should not be counted.
|
||||
Ignored []string `yaml:"ignored"`
|
||||
|
||||
// Interval is the retention interval for statistics.
|
||||
Interval timeutil.Duration `yaml:"interval"`
|
||||
|
||||
// Enabled defines if the statistics are enabled.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// config is the global configuration structure.
|
||||
@@ -286,7 +284,7 @@ var config = &configuration{
|
||||
CacheSize: 4 * 1024 * 1024,
|
||||
|
||||
EDNSClientSubnet: &dnsforward.EDNSClientSubnet{
|
||||
CustomIP: "",
|
||||
CustomIP: netip.Addr{},
|
||||
Enabled: false,
|
||||
UseCustom: false,
|
||||
},
|
||||
@@ -322,7 +320,7 @@ var config = &configuration{
|
||||
},
|
||||
Stats: statsConfig{
|
||||
Enabled: true,
|
||||
Interval: 1,
|
||||
Interval: timeutil.Duration{Duration: 1 * timeutil.Day},
|
||||
Ignored: []string{},
|
||||
},
|
||||
// NOTE: Keep these parameters in sync with the one put into
|
||||
@@ -503,7 +501,7 @@ func (c *configuration) write() (err error) {
|
||||
if Context.stats != nil {
|
||||
statsConf := stats.Config{}
|
||||
Context.stats.WriteDiskConfig(&statsConf)
|
||||
config.Stats.Interval = statsConf.LimitDays
|
||||
config.Stats.Interval = timeutil.Duration{Duration: statsConf.Limit}
|
||||
config.Stats.Enabled = statsConf.Enabled
|
||||
config.Stats.Ignored = statsConf.Ignored.Values()
|
||||
slices.Sort(config.Stats.Ignored)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
@@ -22,7 +21,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -54,13 +52,13 @@ func initDNS() (err error) {
|
||||
|
||||
statsConf := stats.Config{
|
||||
Filename: filepath.Join(baseDir, "stats.db"),
|
||||
LimitDays: config.Stats.Interval,
|
||||
Limit: config.Stats.Interval.Duration,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
Enabled: config.Stats.Enabled,
|
||||
}
|
||||
|
||||
set, err := nonDupEmptyHostNames(config.Stats.Ignored)
|
||||
set, err := aghnet.NewDomainNameSet(config.Stats.Ignored)
|
||||
if err != nil {
|
||||
return fmt.Errorf("statistics: ignored list: %w", err)
|
||||
}
|
||||
@@ -84,13 +82,16 @@ func initDNS() (err error) {
|
||||
FileEnabled: config.QueryLog.FileEnabled,
|
||||
}
|
||||
|
||||
set, err = nonDupEmptyHostNames(config.QueryLog.Ignored)
|
||||
set, err = aghnet.NewDomainNameSet(config.QueryLog.Ignored)
|
||||
if err != nil {
|
||||
return fmt.Errorf("querylog: ignored list: %w", err)
|
||||
}
|
||||
|
||||
conf.Ignored = set
|
||||
Context.queryLog = querylog.New(conf)
|
||||
Context.queryLog, err = querylog.New(conf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init querylog: %w", err)
|
||||
}
|
||||
|
||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
|
||||
if err != nil {
|
||||
@@ -535,30 +536,6 @@ func closeDNSServer() {
|
||||
log.Debug("all dns modules are closed")
|
||||
}
|
||||
|
||||
// nonDupEmptyHostNames returns nil and error, if list has duplicate or empty
|
||||
// host name. Otherwise returns a set, which contains lowercase host names
|
||||
// without dot at the end, and nil error.
|
||||
func nonDupEmptyHostNames(list []string) (set *stringutil.Set, err error) {
|
||||
set = stringutil.NewSet()
|
||||
|
||||
for _, v := range list {
|
||||
host := strings.ToLower(strings.TrimSuffix(v, "."))
|
||||
// TODO(a.garipov): Think about ignoring empty (".") names in
|
||||
// the future.
|
||||
if host == "" {
|
||||
return nil, errors.Error("host name is empty")
|
||||
}
|
||||
|
||||
if set.Has(host) {
|
||||
return nil, fmt.Errorf("duplicate host name %q", host)
|
||||
}
|
||||
|
||||
set.Add(host)
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
// safeSearchResolver is a [filtering.Resolver] implementation used for safe
|
||||
// search.
|
||||
type safeSearchResolver struct{}
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
// currentSchemaVersion is the current schema version.
|
||||
const currentSchemaVersion = 19
|
||||
const currentSchemaVersion = 20
|
||||
|
||||
// These aliases are provided for convenience.
|
||||
type (
|
||||
@@ -92,6 +92,7 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
|
||||
upgradeSchema16to17,
|
||||
upgradeSchema17to18,
|
||||
upgradeSchema18to19,
|
||||
upgradeSchema19to20,
|
||||
}
|
||||
|
||||
n := 0
|
||||
@@ -1064,6 +1065,47 @@ func upgradeSchema18to19(diskConf yobj) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema19to20 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'statistics':
|
||||
// 'interval': 1
|
||||
//
|
||||
// # AFTER:
|
||||
// 'statistics':
|
||||
// 'interval': 24h
|
||||
func upgradeSchema19to20(diskConf yobj) (err error) {
|
||||
log.Printf("Upgrade yaml: 19 to 20")
|
||||
diskConf["schema_version"] = 20
|
||||
|
||||
statsVal, ok := diskConf["statistics"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var stats yobj
|
||||
stats, ok = statsVal.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of stats: %T", statsVal)
|
||||
}
|
||||
|
||||
const field = "interval"
|
||||
|
||||
// Set the initial value from the global configuration structure.
|
||||
statsIvl := 1
|
||||
statsIvlVal, ok := stats[field]
|
||||
if ok {
|
||||
statsIvl, ok = statsIvlVal.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of %s: %T", field, statsIvlVal)
|
||||
}
|
||||
}
|
||||
|
||||
stats[field] = timeutil.Duration{Duration: time.Duration(statsIvl) * timeutil.Day}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Replace with log.Output when we port it to our logging
|
||||
// package.
|
||||
func funcName() string {
|
||||
|
||||
@@ -951,3 +951,98 @@ func TestUpgradeSchema18to19(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeSchema19to20(t *testing.T) {
|
||||
testCases := []struct {
|
||||
ivl any
|
||||
want any
|
||||
wantErr string
|
||||
name string
|
||||
}{{
|
||||
ivl: 1,
|
||||
want: timeutil.Duration{Duration: timeutil.Day},
|
||||
wantErr: "",
|
||||
name: "success",
|
||||
}, {
|
||||
ivl: 0.25,
|
||||
want: 0,
|
||||
wantErr: "unexpected type of interval: float64",
|
||||
name: "fail",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
conf := yobj{
|
||||
"statistics": yobj{
|
||||
"interval": tc.ivl,
|
||||
},
|
||||
"schema_version": 19,
|
||||
}
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema19to20(conf)
|
||||
|
||||
if tc.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErr, err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conf["schema_version"], 20)
|
||||
|
||||
statsVal, ok := conf["statistics"]
|
||||
require.True(t, ok)
|
||||
|
||||
var stats yobj
|
||||
stats, ok = statsVal.(yobj)
|
||||
require.True(t, ok)
|
||||
|
||||
var newIvl timeutil.Duration
|
||||
newIvl, ok = stats["interval"].(timeutil.Duration)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, tc.want, newIvl)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("no_stats", func(t *testing.T) {
|
||||
err := upgradeSchema19to20(yobj{})
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("bad_stats", func(t *testing.T) {
|
||||
err := upgradeSchema19to20(yobj{
|
||||
"statistics": 0,
|
||||
})
|
||||
|
||||
testutil.AssertErrorMsg(t, "unexpected type of stats: int", err)
|
||||
})
|
||||
|
||||
t.Run("no_field", func(t *testing.T) {
|
||||
conf := yobj{
|
||||
"statistics": yobj{},
|
||||
}
|
||||
|
||||
err := upgradeSchema19to20(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
statsVal, ok := conf["statistics"]
|
||||
require.True(t, ok)
|
||||
|
||||
var stats yobj
|
||||
stats, ok = statsVal.(yobj)
|
||||
require.True(t, ok)
|
||||
|
||||
var ivl any
|
||||
ivl, ok = stats["interval"]
|
||||
require.True(t, ok)
|
||||
|
||||
var ivlVal timeutil.Duration
|
||||
ivlVal, ok = ivl.(timeutil.Duration)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, 24*time.Hour, ivlVal.Duration)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -13,9 +13,11 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
@@ -25,8 +27,8 @@ type configJSON struct {
|
||||
// fractional numbers and not mess the API users by changing the units.
|
||||
Interval float64 `json:"interval"`
|
||||
|
||||
// Enabled shows if the querylog is enabled. It is an [aghalg.NullBool]
|
||||
// to be able to tell when it's set without using pointers.
|
||||
// Enabled shows if the querylog is enabled. It is an aghalg.NullBool to
|
||||
// be able to tell when it's set without using pointers.
|
||||
Enabled aghalg.NullBool `json:"enabled"`
|
||||
|
||||
// AnonymizeClientIP shows if the clients' IP addresses must be anonymized.
|
||||
@@ -35,12 +37,39 @@ type configJSON struct {
|
||||
AnonymizeClientIP aghalg.NullBool `json:"anonymize_client_ip"`
|
||||
}
|
||||
|
||||
// getConfigResp is the JSON structure for the querylog configuration.
|
||||
type getConfigResp struct {
|
||||
// Ignored is the list of host names, which should not be written to log.
|
||||
Ignored []string `json:"ignored"`
|
||||
|
||||
// Interval is the querylog rotation interval in milliseconds.
|
||||
Interval float64 `json:"interval"`
|
||||
|
||||
// Enabled shows if the querylog is enabled. It is an aghalg.NullBool to
|
||||
// be able to tell when it's set without using pointers.
|
||||
Enabled aghalg.NullBool `json:"enabled"`
|
||||
|
||||
// AnonymizeClientIP shows if the clients' IP addresses must be anonymized.
|
||||
// It is an aghalg.NullBool to be able to tell when it's set without using
|
||||
// pointers.
|
||||
//
|
||||
// TODO(a.garipov): Consider using separate setting for statistics.
|
||||
AnonymizeClientIP aghalg.NullBool `json:"anonymize_client_ip"`
|
||||
}
|
||||
|
||||
// Register web handlers
|
||||
func (l *queryLog) initWeb() {
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog", l.handleQueryLog)
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog_info", l.handleQueryLogInfo)
|
||||
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_clear", l.handleQueryLogClear)
|
||||
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig)
|
||||
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog/config", l.handleGetQueryLogConfig)
|
||||
l.conf.HTTPRegister(
|
||||
http.MethodPut,
|
||||
"/control/querylog/config/update",
|
||||
l.handlePutQueryLogConfig,
|
||||
)
|
||||
}
|
||||
|
||||
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -64,11 +93,41 @@ func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) {
|
||||
l.clear()
|
||||
}
|
||||
|
||||
// Get configuration
|
||||
// handleQueryLogInfo handles requests to the GET /control/querylog_info
|
||||
// endpoint.
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
ivl := l.conf.RotationIvl
|
||||
|
||||
if !checkInterval(ivl) {
|
||||
// NOTE: If interval is custom we set it to 90 days for compatibility
|
||||
// with old API.
|
||||
ivl = timeutil.Day * 90
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, configJSON{
|
||||
Enabled: aghalg.BoolToNullBool(l.conf.Enabled),
|
||||
Interval: l.conf.RotationIvl.Hours() / 24,
|
||||
Interval: ivl.Hours() / 24,
|
||||
AnonymizeClientIP: aghalg.BoolToNullBool(l.conf.AnonymizeClientIP),
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetQueryLogConfig handles requests to the GET /control/querylog/config
|
||||
// endpoint.
|
||||
func (l *queryLog) handleGetQueryLogConfig(w http.ResponseWriter, r *http.Request) {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
ignored := l.conf.Ignored.Values()
|
||||
slices.Sort(ignored)
|
||||
_ = aghhttp.WriteJSONResponse(w, r, getConfigResp{
|
||||
Ignored: ignored,
|
||||
Interval: float64(l.conf.RotationIvl.Milliseconds()),
|
||||
Enabled: aghalg.BoolToNullBool(l.conf.Enabled),
|
||||
AnonymizeClientIP: aghalg.BoolToNullBool(l.conf.AnonymizeClientIP),
|
||||
})
|
||||
}
|
||||
@@ -88,6 +147,8 @@ func AnonymizeIP(ip net.IP) {
|
||||
}
|
||||
|
||||
// handleQueryLogConfig handles the POST /control/querylog_config queries.
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request) {
|
||||
// Set NaN as initial value to be able to know if it changed later by
|
||||
// comparing it to NaN.
|
||||
@@ -103,6 +164,7 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
ivl := time.Duration(float64(timeutil.Day) * newConf.Interval)
|
||||
|
||||
hasIvl := !math.IsNaN(newConf.Interval)
|
||||
if hasIvl && !checkInterval(ivl) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "unsupported interval")
|
||||
@@ -115,8 +177,6 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
// Copy data, modify it, then activate. Other threads (readers) don't need
|
||||
// to use this lock.
|
||||
conf := *l.conf
|
||||
if newConf.Enabled != aghalg.NBNull {
|
||||
conf.Enabled = newConf.Enabled == aghalg.NBTrue
|
||||
@@ -138,6 +198,65 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
|
||||
l.conf = &conf
|
||||
}
|
||||
|
||||
// handlePutQueryLogConfig handles the PUT /control/querylog/config/update
|
||||
// queries.
|
||||
func (l *queryLog) handlePutQueryLogConfig(w http.ResponseWriter, r *http.Request) {
|
||||
newConf := &getConfigResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(newConf)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
set, err := aghnet.NewDomainNameSet(newConf.Ignored)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "ignored: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ivl := time.Duration(newConf.Interval) * time.Millisecond
|
||||
err = validateIvl(ivl)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "unsupported interval: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if newConf.Enabled == aghalg.NBNull {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "enabled is null")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if newConf.AnonymizeClientIP == aghalg.NBNull {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "anonymize_client_ip is null")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer l.conf.ConfigModified()
|
||||
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
conf := *l.conf
|
||||
|
||||
conf.Ignored = set
|
||||
conf.RotationIvl = ivl
|
||||
conf.Enabled = newConf.Enabled == aghalg.NBTrue
|
||||
|
||||
conf.AnonymizeClientIP = newConf.AnonymizeClientIP == aghalg.NBTrue
|
||||
if conf.AnonymizeClientIP {
|
||||
l.anonymizer.Store(AnonymizeIP)
|
||||
} else {
|
||||
l.anonymizer.Store(nil)
|
||||
}
|
||||
|
||||
l.conf = &conf
|
||||
}
|
||||
|
||||
// "value" -> value, return TRUE
|
||||
func getDoubleQuotesEnclosedValue(s *string) bool {
|
||||
t := *s
|
||||
|
||||
@@ -132,6 +132,20 @@ func checkInterval(ivl time.Duration) (ok bool) {
|
||||
return ivl == quarterDay || ivl == day || ivl == week || ivl == month || ivl == threeMonths
|
||||
}
|
||||
|
||||
// validateIvl returns an error if ivl is less than an hour or more than a
|
||||
// year.
|
||||
func validateIvl(ivl time.Duration) (err error) {
|
||||
if ivl < time.Hour {
|
||||
return errors.Error("less than an hour")
|
||||
}
|
||||
|
||||
if ivl > timeutil.Day*365 {
|
||||
return errors.Error("more than a year")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *queryLog) WriteDiskConfig(c *Config) {
|
||||
*c = *l.conf
|
||||
}
|
||||
@@ -258,6 +272,9 @@ func (l *queryLog) Add(params *AddParams) {
|
||||
|
||||
// ShouldLog returns true if request for the host should be logged.
|
||||
func (l *queryLog) ShouldLog(host string, _, _ uint16) bool {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
return !l.isIgnored(host)
|
||||
}
|
||||
|
||||
|
||||
@@ -22,13 +22,14 @@ func TestMain(m *testing.M) {
|
||||
// TestQueryLog tests adding and loading (with filtering) entries from disk and
|
||||
// memory.
|
||||
func TestQueryLog(t *testing.T) {
|
||||
l := newQueryLog(Config{
|
||||
l, err := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
FileEnabled: true,
|
||||
RotationIvl: timeutil.Day,
|
||||
MemSize: 100,
|
||||
BaseDir: t.TempDir(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add disk entries.
|
||||
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
@@ -125,12 +126,13 @@ func TestQueryLog(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryLogOffsetLimit(t *testing.T) {
|
||||
l := newQueryLog(Config{
|
||||
l, err := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
RotationIvl: timeutil.Day,
|
||||
MemSize: 100,
|
||||
BaseDir: t.TempDir(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
entNum = 10
|
||||
@@ -199,13 +201,14 @@ func TestQueryLogOffsetLimit(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryLogMaxFileScanEntries(t *testing.T) {
|
||||
l := newQueryLog(Config{
|
||||
l, err := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
FileEnabled: true,
|
||||
RotationIvl: timeutil.Day,
|
||||
MemSize: 100,
|
||||
BaseDir: t.TempDir(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
const entNum = 10
|
||||
// Add entries to the log.
|
||||
@@ -227,13 +230,14 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryLogFileDisabled(t *testing.T) {
|
||||
l := newQueryLog(Config{
|
||||
l, err := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
FileEnabled: false,
|
||||
RotationIvl: timeutil.Day,
|
||||
MemSize: 2,
|
||||
BaseDir: t.TempDir(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
addEntry(l, "example1.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
addEntry(l, "example2.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
@@ -254,13 +258,14 @@ func TestQueryLogShouldLog(t *testing.T) {
|
||||
)
|
||||
set := stringutil.NewSet(ignored1, ignored2)
|
||||
|
||||
l := newQueryLog(Config{
|
||||
l, err := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
RotationIvl: timeutil.Day,
|
||||
MemSize: 100,
|
||||
BaseDir: t.TempDir(),
|
||||
Ignored: set,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"time"
|
||||
@@ -9,9 +10,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@@ -135,12 +134,12 @@ func (p *AddParams) validate() (err error) {
|
||||
}
|
||||
|
||||
// New creates a new instance of the query log.
|
||||
func New(conf Config) (ql QueryLog) {
|
||||
func New(conf Config) (ql QueryLog, err error) {
|
||||
return newQueryLog(conf)
|
||||
}
|
||||
|
||||
// newQueryLog crates a new queryLog.
|
||||
func newQueryLog(conf Config) (l *queryLog) {
|
||||
func newQueryLog(conf Config) (l *queryLog, err error) {
|
||||
findClient := conf.FindClient
|
||||
if findClient == nil {
|
||||
findClient = func(_ []string) (_ *Client, _ error) {
|
||||
@@ -158,13 +157,10 @@ func newQueryLog(conf Config) (l *queryLog) {
|
||||
l.conf = &Config{}
|
||||
*l.conf = conf
|
||||
|
||||
if !checkInterval(conf.RotationIvl) {
|
||||
log.Info(
|
||||
"querylog: warning: unsupported rotation interval %s, setting to 1 day",
|
||||
conf.RotationIvl,
|
||||
)
|
||||
l.conf.RotationIvl = timeutil.Day
|
||||
err = validateIvl(conf.RotationIvl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unsupported interval: %w", err)
|
||||
}
|
||||
|
||||
return l
|
||||
return l, nil
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
l := newQueryLog(Config{
|
||||
l, err := newQueryLog(Config{
|
||||
FindClient: findClient,
|
||||
BaseDir: t.TempDir(),
|
||||
RotationIvl: timeutil.Day,
|
||||
@@ -44,6 +44,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
|
||||
FileEnabled: true,
|
||||
AnonymizeClientIP: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(l.Close)
|
||||
|
||||
q := &dns.Msg{
|
||||
|
||||
@@ -7,8 +7,12 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// topAddrs is an alias for the types of the TopFoo fields of statsResponse.
|
||||
@@ -44,7 +48,7 @@ func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
defer s.lock.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
resp, ok := s.getData(s.limitHours)
|
||||
resp, ok := s.getData(uint32(s.limit.Hours()))
|
||||
log.Debug("stats: prepared data in %v", time.Since(start))
|
||||
|
||||
if !ok {
|
||||
@@ -63,20 +67,62 @@ type configResp struct {
|
||||
IntervalDays uint32 `json:"interval"`
|
||||
}
|
||||
|
||||
// getConfigResp is the response to the GET /control/stats_info.
|
||||
type getConfigResp struct {
|
||||
// Ignored is the list of host names, which should not be counted.
|
||||
Ignored []string `json:"ignored"`
|
||||
|
||||
// Interval is the statistics rotation interval in milliseconds.
|
||||
Interval float64 `json:"interval"`
|
||||
|
||||
// Enabled shows if statistics are enabled. It is an aghalg.NullBool to be
|
||||
// able to tell when it's set without using pointers.
|
||||
Enabled aghalg.NullBool `json:"enabled"`
|
||||
}
|
||||
|
||||
// handleStatsInfo handles requests to the GET /control/stats_info endpoint.
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
resp := configResp{IntervalDays: s.limitHours / 24}
|
||||
days := uint32(s.limit / timeutil.Day)
|
||||
ok := checkInterval(days)
|
||||
if !ok || (s.enabled && days == 0) {
|
||||
// NOTE: If interval is custom we set it to 90 days for compatibility
|
||||
// with old API.
|
||||
days = 90
|
||||
}
|
||||
|
||||
resp := configResp{IntervalDays: days}
|
||||
if !s.enabled {
|
||||
resp.IntervalDays = 0
|
||||
}
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// handleGetStatsConfig handles requests to the GET /control/stats/config
|
||||
// endpoint.
|
||||
func (s *StatsCtx) handleGetStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
ignored := s.ignored.Values()
|
||||
slices.Sort(ignored)
|
||||
|
||||
resp := getConfigResp{
|
||||
Ignored: ignored,
|
||||
Interval: float64(s.limit.Milliseconds()),
|
||||
Enabled: aghalg.BoolToNullBool(s.enabled),
|
||||
}
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// handleStatsConfig handles requests to the POST /control/stats_config
|
||||
// endpoint.
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
reqData := configResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
@@ -92,8 +138,55 @@ func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
s.setLimit(int(reqData.IntervalDays))
|
||||
s.configModified()
|
||||
defer s.configModified()
|
||||
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
limit := time.Duration(reqData.IntervalDays) * timeutil.Day
|
||||
s.setLimit(limit)
|
||||
}
|
||||
|
||||
// handlePutStatsConfig handles requests to the PUT /control/stats/config/update
|
||||
// endpoint.
|
||||
func (s *StatsCtx) handlePutStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
reqData := getConfigResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
set, err := aghnet.NewDomainNameSet(reqData.Ignored)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "ignored: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ivl := time.Duration(reqData.Interval) * time.Millisecond
|
||||
err = validateIvl(ivl)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "unsupported interval: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if reqData.Enabled == aghalg.NBNull {
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "enabled is null")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer s.configModified()
|
||||
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
s.ignored = set
|
||||
s.limit = ivl
|
||||
s.enabled = reqData.Enabled == aghalg.NBTrue
|
||||
}
|
||||
|
||||
// handleStatsReset handles requests to the POST /control/stats_reset endpoint.
|
||||
@@ -114,4 +207,7 @@ func (s *StatsCtx) initWeb() {
|
||||
s.httpRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
|
||||
s.httpRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
|
||||
s.httpRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
|
||||
|
||||
s.httpRegister(http.MethodGet, "/control/stats/config", s.handleGetStatsConfig)
|
||||
s.httpRegister(http.MethodPut, "/control/stats/config/update", s.handlePutStatsConfig)
|
||||
}
|
||||
|
||||
152
internal/stats/http_test.go
Normal file
152
internal/stats/http_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHandleStatsConfig(t *testing.T) {
|
||||
const (
|
||||
smallIvl = 1 * time.Minute
|
||||
minIvl = 1 * time.Hour
|
||||
maxIvl = 365 * timeutil.Day
|
||||
)
|
||||
|
||||
conf := Config{
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
Limit: time.Hour * 24,
|
||||
Enabled: true,
|
||||
UnitID: func() (id uint32) { return 0 },
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
body getConfigResp
|
||||
wantCode int
|
||||
wantErr string
|
||||
}{{
|
||||
name: "set_ivl_1_minIvl",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(minIvl.Milliseconds()),
|
||||
Ignored: []string{},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantErr: "",
|
||||
}, {
|
||||
name: "small_interval",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(smallIvl.Milliseconds()),
|
||||
Ignored: []string{},
|
||||
},
|
||||
wantCode: http.StatusUnprocessableEntity,
|
||||
wantErr: "unsupported interval: less than an hour\n",
|
||||
}, {
|
||||
name: "big_interval",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(maxIvl.Milliseconds() + minIvl.Milliseconds()),
|
||||
Ignored: []string{},
|
||||
},
|
||||
wantCode: http.StatusUnprocessableEntity,
|
||||
wantErr: "unsupported interval: more than a year\n",
|
||||
}, {
|
||||
name: "set_ignored_ivl_1_maxIvl",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(maxIvl.Milliseconds()),
|
||||
Ignored: []string{
|
||||
"ignor.ed",
|
||||
},
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
wantErr: "",
|
||||
}, {
|
||||
name: "ignored_duplicate",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(minIvl.Milliseconds()),
|
||||
Ignored: []string{
|
||||
"ignor.ed",
|
||||
"ignor.ed",
|
||||
},
|
||||
},
|
||||
wantCode: http.StatusUnprocessableEntity,
|
||||
wantErr: "ignored: duplicate host name \"ignor.ed\" at index 1\n",
|
||||
}, {
|
||||
name: "ignored_empty",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBTrue,
|
||||
Interval: float64(minIvl.Milliseconds()),
|
||||
Ignored: []string{
|
||||
"",
|
||||
},
|
||||
},
|
||||
wantCode: http.StatusUnprocessableEntity,
|
||||
wantErr: "ignored: host name is empty\n",
|
||||
}, {
|
||||
name: "enabled_is_null",
|
||||
body: getConfigResp{
|
||||
Enabled: aghalg.NBNull,
|
||||
Interval: float64(minIvl.Milliseconds()),
|
||||
Ignored: []string{},
|
||||
},
|
||||
wantCode: http.StatusUnprocessableEntity,
|
||||
wantErr: "enabled is null\n",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s, err := New(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.Start()
|
||||
testutil.CleanupAndRequireSuccess(t, s.Close)
|
||||
|
||||
buf, err := json.Marshal(tc.body)
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
configGet = "/control/stats/config"
|
||||
configPut = "/control/stats/config/update"
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPut, configPut, bytes.NewReader(buf))
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
s.handlePutStatsConfig(rw, req)
|
||||
require.Equal(t, tc.wantCode, rw.Code)
|
||||
|
||||
if tc.wantCode != http.StatusOK {
|
||||
assert.Equal(t, tc.wantErr, rw.Body.String())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp := httptest.NewRequest(http.MethodGet, configGet, nil)
|
||||
rw = httptest.NewRecorder()
|
||||
|
||||
s.handleGetStatsConfig(rw, resp)
|
||||
require.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
ans := getConfigResp{}
|
||||
err = json.Unmarshal(rw.Body.Bytes(), &ans)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.body, ans)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
@@ -25,6 +26,20 @@ func checkInterval(days uint32) (ok bool) {
|
||||
return days == 0 || days == 1 || days == 7 || days == 30 || days == 90
|
||||
}
|
||||
|
||||
// validateIvl returns an error if ivl is less than an hour or more than a
|
||||
// year.
|
||||
func validateIvl(ivl time.Duration) (err error) {
|
||||
if ivl < time.Hour {
|
||||
return errors.Error("less than an hour")
|
||||
}
|
||||
|
||||
if ivl > timeutil.Day*365 {
|
||||
return errors.Error("more than a year")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Config is the configuration structure for the statistics collecting.
|
||||
type Config struct {
|
||||
// UnitID is the function to generate the identifier for current unit. If
|
||||
@@ -42,9 +57,8 @@ type Config struct {
|
||||
// Filename is the name of the database file.
|
||||
Filename string
|
||||
|
||||
// LimitDays is the maximum number of days to collect statistics into the
|
||||
// current unit.
|
||||
LimitDays uint32
|
||||
// Limit is an upper limit for collecting statistics.
|
||||
Limit time.Duration
|
||||
|
||||
// Enabled tells if the statistics are enabled.
|
||||
Enabled bool
|
||||
@@ -105,11 +119,8 @@ type StatsCtx struct {
|
||||
// enabled tells if the statistics are enabled.
|
||||
enabled bool
|
||||
|
||||
// limitHours is the maximum number of hours to collect statistics into the
|
||||
// current unit.
|
||||
//
|
||||
// TODO(s.chzhen): Rewrite to use time.Duration.
|
||||
limitHours uint32
|
||||
// limit is an upper limit for collecting statistics.
|
||||
limit time.Duration
|
||||
|
||||
// ignored is the list of host names, which should not be counted.
|
||||
ignored *stringutil.Set
|
||||
@@ -128,9 +139,14 @@ func New(conf Config) (s *StatsCtx, err error) {
|
||||
httpRegister: conf.HTTPRegister,
|
||||
ignored: conf.Ignored,
|
||||
}
|
||||
if s.limitHours = conf.LimitDays * 24; !checkInterval(conf.LimitDays) {
|
||||
s.limitHours = 24
|
||||
|
||||
err = validateIvl(conf.Limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unsupported interval: %w", err)
|
||||
}
|
||||
|
||||
s.limit = conf.Limit
|
||||
|
||||
if s.unitIDGen = newUnitID; conf.UnitID != nil {
|
||||
s.unitIDGen = conf.UnitID
|
||||
}
|
||||
@@ -150,7 +166,7 @@ func New(conf Config) (s *StatsCtx, err error) {
|
||||
return nil, fmt.Errorf("stats: opening a transaction: %w", err)
|
||||
}
|
||||
|
||||
deleted := deleteOldUnits(tx, id-s.limitHours-1)
|
||||
deleted := deleteOldUnits(tx, id-uint32(s.limit.Hours())-1)
|
||||
udb = loadUnitFromDB(tx, id)
|
||||
|
||||
err = finishTxn(tx, deleted > 0)
|
||||
@@ -231,7 +247,7 @@ func (s *StatsCtx) Update(e Entry) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if !s.enabled || s.limitHours == 0 {
|
||||
if !s.enabled || s.limit == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -263,7 +279,7 @@ func (s *StatsCtx) WriteDiskConfig(dc *Config) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
dc.LimitDays = s.limitHours / 24
|
||||
dc.Limit = s.limit
|
||||
dc.Enabled = s.enabled
|
||||
dc.Ignored = s.ignored
|
||||
}
|
||||
@@ -273,7 +289,7 @@ func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []netip.Addr) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
limit := s.limitHours
|
||||
limit := uint32(s.limit.Hours())
|
||||
if !s.enabled || limit == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -377,7 +393,7 @@ func (s *StatsCtx) flush() (cont bool, sleepFor time.Duration) {
|
||||
return false, 0
|
||||
}
|
||||
|
||||
limit := s.limitHours
|
||||
limit := uint32(s.limit.Hours())
|
||||
if limit == 0 || ptr.id == id {
|
||||
return true, time.Second
|
||||
}
|
||||
@@ -436,14 +452,14 @@ func (s *StatsCtx) periodicFlush() {
|
||||
log.Debug("periodic flushing finished")
|
||||
}
|
||||
|
||||
func (s *StatsCtx) setLimit(limitDays int) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if limitDays != 0 {
|
||||
// setLimit sets the limit. s.lock is expected to be locked.
|
||||
//
|
||||
// TODO(s.chzhen): Remove it when migration to the new API is over.
|
||||
func (s *StatsCtx) setLimit(limit time.Duration) {
|
||||
if limit != 0 {
|
||||
s.enabled = true
|
||||
s.limitHours = uint32(24 * limitDays)
|
||||
log.Debug("stats: set limit: %d days", limitDays)
|
||||
s.limit = limit
|
||||
log.Debug("stats: set limit: %d days", limit/timeutil.Day)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -35,9 +36,9 @@ func TestStats_races(t *testing.T) {
|
||||
var r uint32
|
||||
idGen := func() (id uint32) { return atomic.LoadUint32(&r) }
|
||||
conf := Config{
|
||||
UnitID: idGen,
|
||||
Filename: filepath.Join(t.TempDir(), "./stats.db"),
|
||||
LimitDays: 1,
|
||||
UnitID: idGen,
|
||||
Filename: filepath.Join(t.TempDir(), "./stats.db"),
|
||||
Limit: timeutil.Day,
|
||||
}
|
||||
|
||||
s, err := New(conf)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -51,10 +52,10 @@ func TestStats(t *testing.T) {
|
||||
|
||||
handlers := map[string]http.Handler{}
|
||||
conf := stats.Config{
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
LimitDays: 1,
|
||||
Enabled: true,
|
||||
UnitID: constUnitID,
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
Limit: timeutil.Day,
|
||||
Enabled: true,
|
||||
UnitID: constUnitID,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
@@ -158,7 +159,7 @@ func TestLargeNumbers(t *testing.T) {
|
||||
|
||||
conf := stats.Config{
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
LimitDays: 1,
|
||||
Limit: timeutil.Day,
|
||||
Enabled: true,
|
||||
UnitID: func() (id uint32) { return atomic.LoadUint32(&curHour) },
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler },
|
||||
|
||||
Reference in New Issue
Block a user