all: sync with master; upd chlog
This commit is contained in:
102
internal/aghalg/ringbuffer.go
Normal file
102
internal/aghalg/ringbuffer.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package aghalg
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// RingBuffer is the implementation of ring buffer data structure.
|
||||
type RingBuffer[T any] struct {
|
||||
buf []T
|
||||
cur int
|
||||
full bool
|
||||
}
|
||||
|
||||
// NewRingBuffer initializes the new instance of ring buffer. size must be
|
||||
// greater or equal to zero.
|
||||
func NewRingBuffer[T any](size int) (rb *RingBuffer[T]) {
|
||||
if size < 0 {
|
||||
panic(errors.Error("ring buffer: size must be greater or equal to zero"))
|
||||
}
|
||||
|
||||
return &RingBuffer[T]{
|
||||
buf: make([]T, size),
|
||||
}
|
||||
}
|
||||
|
||||
// Append appends an element to the buffer.
|
||||
func (rb *RingBuffer[T]) Append(e T) {
|
||||
if len(rb.buf) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
rb.buf[rb.cur] = e
|
||||
rb.cur = (rb.cur + 1) % cap(rb.buf)
|
||||
if rb.cur == 0 {
|
||||
rb.full = true
|
||||
}
|
||||
}
|
||||
|
||||
// Range calls cb for each element of the buffer. If cb returns false it stops.
|
||||
func (rb *RingBuffer[T]) Range(cb func(T) (cont bool)) {
|
||||
before, after := rb.splitCur()
|
||||
|
||||
for _, e := range before {
|
||||
if !cb(e) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, e := range after {
|
||||
if !cb(e) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReverseRange calls cb for each element of the buffer in reverse order. If
|
||||
// cb returns false it stops.
|
||||
func (rb *RingBuffer[T]) ReverseRange(cb func(T) (cont bool)) {
|
||||
before, after := rb.splitCur()
|
||||
|
||||
for i := len(after) - 1; i >= 0; i-- {
|
||||
if !cb(after[i]) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for i := len(before) - 1; i >= 0; i-- {
|
||||
if !cb(before[i]) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// splitCur splits the buffer in two, before and after current position in
|
||||
// chronological order. If buffer is not full, after is nil.
|
||||
func (rb *RingBuffer[T]) splitCur() (before, after []T) {
|
||||
if len(rb.buf) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cur := rb.cur
|
||||
if !rb.full {
|
||||
return rb.buf[:cur], nil
|
||||
}
|
||||
|
||||
return rb.buf[cur:], rb.buf[:cur]
|
||||
}
|
||||
|
||||
// Len returns a length of the buffer.
|
||||
func (rb *RingBuffer[T]) Len() (l int) {
|
||||
if !rb.full {
|
||||
return rb.cur
|
||||
}
|
||||
|
||||
return cap(rb.buf)
|
||||
}
|
||||
|
||||
// Clear clears the buffer.
|
||||
func (rb *RingBuffer[T]) Clear() {
|
||||
rb.full = false
|
||||
rb.cur = 0
|
||||
}
|
||||
173
internal/aghalg/ringbuffer_test.go
Normal file
173
internal/aghalg/ringbuffer_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package aghalg_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// elements is a helper function that returns n elements of the buffer.
|
||||
func elements(b *aghalg.RingBuffer[int], n int, reverse bool) (es []int) {
|
||||
fn := b.Range
|
||||
if reverse {
|
||||
fn = b.ReverseRange
|
||||
}
|
||||
|
||||
i := 0
|
||||
fn(func(e int) (cont bool) {
|
||||
if i >= n {
|
||||
return false
|
||||
}
|
||||
|
||||
es = append(es, e)
|
||||
i++
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return es
|
||||
}
|
||||
|
||||
func TestNewRingBuffer(t *testing.T) {
|
||||
t.Run("success_and_clear", func(t *testing.T) {
|
||||
b := aghalg.NewRingBuffer[int](5)
|
||||
for i := 0; i < 10; i++ {
|
||||
b.Append(i)
|
||||
}
|
||||
assert.Equal(t, []int{5, 6, 7, 8, 9}, elements(b, b.Len(), false))
|
||||
|
||||
b.Clear()
|
||||
assert.Zero(t, b.Len())
|
||||
})
|
||||
|
||||
t.Run("negative_size", func(t *testing.T) {
|
||||
assert.PanicsWithError(t, "ring buffer: size must be greater or equal to zero", func() {
|
||||
aghalg.NewRingBuffer[int](-5)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
b := aghalg.NewRingBuffer[int](0)
|
||||
for i := 0; i < 10; i++ {
|
||||
b.Append(i)
|
||||
assert.Equal(t, 0, b.Len())
|
||||
assert.Empty(t, elements(b, b.Len(), false))
|
||||
assert.Empty(t, elements(b, b.Len(), true))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("single", func(t *testing.T) {
|
||||
b := aghalg.NewRingBuffer[int](1)
|
||||
for i := 0; i < 10; i++ {
|
||||
b.Append(i)
|
||||
assert.Equal(t, 1, b.Len())
|
||||
assert.Equal(t, []int{i}, elements(b, b.Len(), false))
|
||||
assert.Equal(t, []int{i}, elements(b, b.Len(), true))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRingBuffer_Range(t *testing.T) {
|
||||
const size = 5
|
||||
|
||||
b := aghalg.NewRingBuffer[int](size)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want []int
|
||||
count int
|
||||
length int
|
||||
}{{
|
||||
name: "three",
|
||||
count: 3,
|
||||
length: 3,
|
||||
want: []int{0, 1, 2},
|
||||
}, {
|
||||
name: "ten",
|
||||
count: 10,
|
||||
length: size,
|
||||
want: []int{5, 6, 7, 8, 9},
|
||||
}, {
|
||||
name: "hundred",
|
||||
count: 100,
|
||||
length: size,
|
||||
want: []int{95, 96, 97, 98, 99},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for i := 0; i < tc.count; i++ {
|
||||
b.Append(i)
|
||||
}
|
||||
|
||||
bufLen := b.Len()
|
||||
assert.Equal(t, tc.length, bufLen)
|
||||
|
||||
want := tc.want
|
||||
assert.Equal(t, want, elements(b, bufLen, false))
|
||||
assert.Equal(t, want[:len(want)-1], elements(b, bufLen-1, false))
|
||||
assert.Equal(t, want[:len(want)/2], elements(b, bufLen/2, false))
|
||||
|
||||
want = want[:cap(want)]
|
||||
slices.Reverse(want)
|
||||
|
||||
assert.Equal(t, want, elements(b, bufLen, true))
|
||||
assert.Equal(t, want[:len(want)-1], elements(b, bufLen-1, true))
|
||||
assert.Equal(t, want[:len(want)/2], elements(b, bufLen/2, true))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBuffer_Range_increment(t *testing.T) {
|
||||
const size = 5
|
||||
|
||||
b := aghalg.NewRingBuffer[int](size)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want []int
|
||||
}{{
|
||||
name: "one",
|
||||
want: []int{0},
|
||||
}, {
|
||||
name: "two",
|
||||
want: []int{0, 1},
|
||||
}, {
|
||||
name: "three",
|
||||
want: []int{0, 1, 2},
|
||||
}, {
|
||||
name: "four",
|
||||
want: []int{0, 1, 2, 3},
|
||||
}, {
|
||||
name: "five",
|
||||
want: []int{0, 1, 2, 3, 4},
|
||||
}, {
|
||||
name: "six",
|
||||
want: []int{1, 2, 3, 4, 5},
|
||||
}, {
|
||||
name: "seven",
|
||||
want: []int{2, 3, 4, 5, 6},
|
||||
}, {
|
||||
name: "eight",
|
||||
want: []int{3, 4, 5, 6, 7},
|
||||
}, {
|
||||
name: "nine",
|
||||
want: []int{4, 5, 6, 7, 8},
|
||||
}, {
|
||||
name: "ten",
|
||||
want: []int{5, 6, 7, 8, 9},
|
||||
}}
|
||||
|
||||
for i, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
b.Append(i)
|
||||
|
||||
assert.Equal(t, tc.want, elements(b, b.Len(), false))
|
||||
|
||||
slices.Reverse(tc.want)
|
||||
assert.Equal(t, tc.want, elements(b, b.Len(), true))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
// Package aghio contains extensions for io package's types and methods
|
||||
package aghio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/mathutil"
|
||||
)
|
||||
|
||||
// LimitReachedError records the limit and the operation that caused it.
|
||||
type LimitReachedError struct {
|
||||
Limit int64
|
||||
}
|
||||
|
||||
// Error implements the [error] interface for *LimitReachedError.
|
||||
//
|
||||
// TODO(a.garipov): Think about error string format.
|
||||
func (lre *LimitReachedError) Error() string {
|
||||
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
|
||||
}
|
||||
|
||||
// limitedReader is a wrapper for [io.Reader] limiting the input and dealing
|
||||
// with errors package.
|
||||
type limitedReader struct {
|
||||
r io.Reader
|
||||
limit int64
|
||||
n int64
|
||||
}
|
||||
|
||||
// Read implements the [io.Reader] interface.
|
||||
func (lr *limitedReader) Read(p []byte) (n int, err error) {
|
||||
if lr.n == 0 {
|
||||
return 0, &LimitReachedError{
|
||||
Limit: lr.limit,
|
||||
}
|
||||
}
|
||||
|
||||
p = p[:mathutil.Min(lr.n, int64(len(p)))]
|
||||
|
||||
n, err = lr.r.Read(p)
|
||||
lr.n -= int64(n)
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// LimitReader wraps Reader to make it's Reader stop with ErrLimitReached after
|
||||
// n bytes read.
|
||||
func LimitReader(r io.Reader, n int64) (limited io.Reader, err error) {
|
||||
if n < 0 {
|
||||
return nil, errors.Error("limit must be non-negative")
|
||||
}
|
||||
|
||||
return &limitedReader{
|
||||
r: r,
|
||||
limit: n,
|
||||
n: n,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
package aghio_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLimitReader(t *testing.T) {
|
||||
testCases := []struct {
|
||||
wantErrMsg string
|
||||
name string
|
||||
n int64
|
||||
}{{
|
||||
wantErrMsg: "",
|
||||
name: "positive",
|
||||
n: 1,
|
||||
}, {
|
||||
wantErrMsg: "",
|
||||
name: "zero",
|
||||
n: 0,
|
||||
}, {
|
||||
wantErrMsg: "limit must be non-negative",
|
||||
name: "negative",
|
||||
n: -1,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := aghio.LimitReader(nil, tc.n)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitedReader_Read(t *testing.T) {
|
||||
testCases := []struct {
|
||||
err error
|
||||
name string
|
||||
rStr string
|
||||
limit int64
|
||||
want int
|
||||
}{{
|
||||
err: nil,
|
||||
name: "perfectly_match",
|
||||
rStr: "abc",
|
||||
limit: 3,
|
||||
want: 3,
|
||||
}, {
|
||||
err: io.EOF,
|
||||
name: "eof",
|
||||
rStr: "",
|
||||
limit: 3,
|
||||
want: 0,
|
||||
}, {
|
||||
err: &aghio.LimitReachedError{
|
||||
Limit: 0,
|
||||
},
|
||||
name: "limit_reached",
|
||||
rStr: "abc",
|
||||
limit: 0,
|
||||
want: 0,
|
||||
}, {
|
||||
err: nil,
|
||||
name: "truncated",
|
||||
rStr: "abc",
|
||||
limit: 2,
|
||||
want: 2,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
|
||||
lreader, err := aghio.LimitReader(readCloser, tc.limit)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lreader)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
buf := make([]byte, tc.limit+1)
|
||||
n, rerr := lreader.Read(buf)
|
||||
require.Equal(t, rerr, tc.err)
|
||||
|
||||
assert.Equal(t, tc.want, n)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitedReader_LimitReachedError(t *testing.T) {
|
||||
testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &aghio.LimitReachedError{
|
||||
Limit: 0,
|
||||
})
|
||||
}
|
||||
@@ -12,12 +12,12 @@ import (
|
||||
// listenPacketReusable announces on the local network address additionally
|
||||
// configuring the socket to have a reusable binding.
|
||||
func listenPacketReusable(ifaceName, network, address string) (c net.PacketConn, err error) {
|
||||
var port int
|
||||
var port uint16
|
||||
_, port, err = netutil.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Inspect nclient4.NewRawUDPConn and implement here.
|
||||
return nclient4.NewRawUDPConn(ifaceName, port)
|
||||
return nclient4.NewRawUDPConn(ifaceName, int(port))
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"syscall"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
@@ -263,7 +264,7 @@ func IsAddrInUse(err error) (ok bool) {
|
||||
|
||||
// CollectAllIfacesAddrs returns the slice of all network interfaces IP
|
||||
// addresses without port number.
|
||||
func CollectAllIfacesAddrs() (addrs []string, err error) {
|
||||
func CollectAllIfacesAddrs() (addrs []netip.Addr, err error) {
|
||||
var ifaceAddrs []net.Addr
|
||||
ifaceAddrs, err = netInterfaceAddrs()
|
||||
if err != nil {
|
||||
@@ -271,19 +272,41 @@ func CollectAllIfacesAddrs() (addrs []string, err error) {
|
||||
}
|
||||
|
||||
for _, addr := range ifaceAddrs {
|
||||
cidr := addr.String()
|
||||
var ip net.IP
|
||||
ip, _, err = net.ParseCIDR(cidr)
|
||||
var p netip.Prefix
|
||||
p, err = netip.ParsePrefix(addr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing cidr: %w", err)
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addrs = append(addrs, ip.String())
|
||||
addrs = append(addrs, p.Addr())
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// ParseAddrPort parses an [netip.AddrPort] from s, which should be either a
|
||||
// valid IP, optionally with port, or a valid URL with plain IP address. The
|
||||
// defaultPort is used if s doesn't contain port number.
|
||||
func ParseAddrPort(s string, defaultPort uint16) (ipp netip.AddrPort, err error) {
|
||||
u, err := url.Parse(s)
|
||||
if err == nil && u.Host != "" {
|
||||
s = u.Host
|
||||
}
|
||||
|
||||
ipp, err = netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
ip, parseErr := netip.ParseAddr(s)
|
||||
if parseErr != nil {
|
||||
return ipp, errors.Join(err, parseErr)
|
||||
}
|
||||
|
||||
return netip.AddrPortFrom(ip, defaultPort), nil
|
||||
}
|
||||
|
||||
return ipp, nil
|
||||
}
|
||||
|
||||
// BroadcastFromPref calculates the broadcast IP address for p.
|
||||
func BroadcastFromPref(p netip.Prefix) (bc netip.Addr) {
|
||||
bc = p.Addr().Unmap()
|
||||
|
||||
@@ -230,7 +230,7 @@ func TestCollectAllIfacesAddrs(t *testing.T) {
|
||||
name string
|
||||
wantErrMsg string
|
||||
addrs []net.Addr
|
||||
wantAddrs []string
|
||||
wantAddrs []netip.Addr
|
||||
}{{
|
||||
name: "success",
|
||||
wantErrMsg: ``,
|
||||
@@ -241,10 +241,13 @@ func TestCollectAllIfacesAddrs(t *testing.T) {
|
||||
IP: net.IP{4, 3, 2, 1},
|
||||
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
|
||||
}},
|
||||
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
|
||||
wantAddrs: []netip.Addr{
|
||||
netip.MustParseAddr("1.2.3.4"),
|
||||
netip.MustParseAddr("4.3.2.1"),
|
||||
},
|
||||
}, {
|
||||
name: "not_cidr",
|
||||
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
|
||||
wantErrMsg: `netip.ParsePrefix("1.2.3.4"): no '/'`,
|
||||
addrs: []net.Addr{&net.IPAddr{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
}},
|
||||
@@ -269,12 +272,11 @@ func TestCollectAllIfacesAddrs(t *testing.T) {
|
||||
|
||||
t.Run("internal_error", func(t *testing.T) {
|
||||
const errAddrs errors.Error = "can't get addresses"
|
||||
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
|
||||
|
||||
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
|
||||
|
||||
_, err := CollectAllIfacesAddrs()
|
||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||
assert.ErrorIs(t, err, errAddrs)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,16 @@ package aghnet_test
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -14,3 +20,76 @@ func TestMain(m *testing.M) {
|
||||
|
||||
// testdata is the filesystem containing data for testing the package.
|
||||
var testdata fs.FS = os.DirFS("./testdata")
|
||||
|
||||
func TestParseAddrPort(t *testing.T) {
|
||||
const defaultPort = 1
|
||||
|
||||
v4addr := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErrMsg string
|
||||
want netip.AddrPort
|
||||
}{{
|
||||
name: "success_ip",
|
||||
input: v4addr.String(),
|
||||
wantErrMsg: "",
|
||||
want: netip.AddrPortFrom(v4addr, defaultPort),
|
||||
}, {
|
||||
name: "success_ip_port",
|
||||
input: netutil.JoinHostPort(v4addr.String(), 5),
|
||||
wantErrMsg: "",
|
||||
want: netip.AddrPortFrom(v4addr, 5),
|
||||
}, {
|
||||
name: "success_url",
|
||||
input: (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: v4addr.String(),
|
||||
}).String(),
|
||||
wantErrMsg: "",
|
||||
want: netip.AddrPortFrom(v4addr, defaultPort),
|
||||
}, {
|
||||
name: "success_url_port",
|
||||
input: (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: netutil.JoinHostPort(v4addr.String(), 5),
|
||||
}).String(),
|
||||
wantErrMsg: "",
|
||||
want: netip.AddrPortFrom(v4addr, 5),
|
||||
}, {
|
||||
name: "error_invalid_ip",
|
||||
input: "256.256.256.256",
|
||||
wantErrMsg: `not an ip:port
|
||||
ParseAddr("256.256.256.256"): IPv4 field has value >255`,
|
||||
want: netip.AddrPort{},
|
||||
}, {
|
||||
name: "error_invalid_port",
|
||||
input: net.JoinHostPort(v4addr.String(), "-5"),
|
||||
wantErrMsg: `invalid port "-5" parsing "1.2.3.4:-5"
|
||||
ParseAddr("1.2.3.4:-5"): unexpected character (at ":-5")`,
|
||||
want: netip.AddrPort{},
|
||||
}, {
|
||||
name: "error_invalid_url",
|
||||
input: "tcp:://1.2.3.4",
|
||||
wantErrMsg: `invalid port "//1.2.3.4" parsing "tcp:://1.2.3.4"
|
||||
ParseAddr("tcp:://1.2.3.4"): each colon-separated field must have at least ` +
|
||||
`one digit (at "tcp:://1.2.3.4")`,
|
||||
want: netip.AddrPort{},
|
||||
}, {
|
||||
name: "empty",
|
||||
input: "",
|
||||
want: netip.AddrPort{},
|
||||
wantErrMsg: `not an ip:port
|
||||
ParseAddr(""): unable to parse IP`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ap, err := aghnet.ParseAddrPort(tc.input, defaultPort)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, ap)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
package aghnet
|
||||
|
||||
// DefaultRefreshIvl is the default period of time between refreshing cached
|
||||
// addresses.
|
||||
// const DefaultRefreshIvl = 5 * time.Minute
|
||||
|
||||
// HostGenFunc is the signature for functions generating fake hostnames. The
|
||||
// implementation must be safe for concurrent use.
|
||||
type HostGenFunc func() (host string)
|
||||
|
||||
// SystemResolvers helps to work with local resolvers' addresses provided by OS.
|
||||
type SystemResolvers interface {
|
||||
// Get returns the slice of local resolvers' addresses. It must be safe for
|
||||
// concurrent use.
|
||||
Get() (rs []string)
|
||||
// refresh refreshes the local resolvers' addresses cache. It must be safe
|
||||
// for concurrent use.
|
||||
refresh() (err error)
|
||||
}
|
||||
|
||||
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
|
||||
// defined by refreshIvl. It disables auto-refreshing if refreshIvl is 0. If
|
||||
// nil is passed for hostGenFunc, the default generator will be used.
|
||||
func NewSystemResolvers(
|
||||
hostGenFunc HostGenFunc,
|
||||
) (sr SystemResolvers, err error) {
|
||||
sr = newSystemResolvers(hostGenFunc)
|
||||
|
||||
// Fill cache.
|
||||
err = sr.refresh()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sr, nil
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// defaultHostGen is the default method of generating host for Refresh.
|
||||
func defaultHostGen() (host string) {
|
||||
// TODO(e.burkov): Use strings.Builder.
|
||||
return fmt.Sprintf("test%d.org", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// systemResolvers is a default implementation of SystemResolvers interface.
|
||||
type systemResolvers struct {
|
||||
// addrsLock protects addrs.
|
||||
addrsLock sync.RWMutex
|
||||
// addrs is the set that contains cached local resolvers' addresses.
|
||||
addrs *stringutil.Set
|
||||
|
||||
// resolver is used to fetch the resolvers' addresses.
|
||||
resolver *net.Resolver
|
||||
// hostGenFunc generates hosts to resolve.
|
||||
hostGenFunc HostGenFunc
|
||||
}
|
||||
|
||||
const (
|
||||
// errBadAddrPassed is returned when dialFunc can't parse an IP address.
|
||||
errBadAddrPassed errors.Error = "the passed string is not a valid IP address"
|
||||
|
||||
// errFakeDial is an error which dialFunc is expected to return.
|
||||
errFakeDial errors.Error = "this error signals the successful dialFunc work"
|
||||
|
||||
// errUnexpectedHostFormat is returned by validateDialedHost when the host has
|
||||
// more than one percent sign.
|
||||
errUnexpectedHostFormat errors.Error = "unexpected host format"
|
||||
)
|
||||
|
||||
// refresh implements the SystemResolvers interface for *systemResolvers.
|
||||
func (sr *systemResolvers) refresh() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
|
||||
|
||||
_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
|
||||
dnserr := &net.DNSError{}
|
||||
if errors.As(err, &dnserr) && dnserr.Err == errFakeDial.Error() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func newSystemResolvers(hostGenFunc HostGenFunc) (sr SystemResolvers) {
|
||||
if hostGenFunc == nil {
|
||||
hostGenFunc = defaultHostGen
|
||||
}
|
||||
s := &systemResolvers{
|
||||
resolver: &net.Resolver{
|
||||
PreferGo: true,
|
||||
},
|
||||
hostGenFunc: hostGenFunc,
|
||||
addrs: stringutil.NewSet(),
|
||||
}
|
||||
s.resolver.Dial = s.dialFunc
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// validateDialedHost validated the host used by resolvers in dialFunc.
|
||||
func validateDialedHost(host string) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()
|
||||
|
||||
parts := strings.Split(host, "%")
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
// host
|
||||
case 2:
|
||||
// Remove the zone and check the IP address part.
|
||||
host = parts[0]
|
||||
default:
|
||||
return errUnexpectedHostFormat
|
||||
}
|
||||
|
||||
if _, err = netutil.ParseIP(host); err != nil {
|
||||
return errBadAddrPassed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dockerEmbeddedDNS is the address of Docker's embedded DNS server.
|
||||
//
|
||||
// See
|
||||
// https://github.com/moby/moby/blob/v1.12.0/docs/userguide/networking/dockernetworks.md.
|
||||
const dockerEmbeddedDNS = "127.0.0.11"
|
||||
|
||||
// dialFunc gets the resolver's address and puts it into internal cache.
|
||||
func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
|
||||
// Just validate the passed address is a valid IP.
|
||||
var host string
|
||||
host, err = netutil.SplitHost(address)
|
||||
if err != nil {
|
||||
// TODO(e.burkov): Maybe use a structured errBadAddrPassed to
|
||||
// allow unwrapping of the real error.
|
||||
return nil, fmt.Errorf("%s: %w", err, errBadAddrPassed)
|
||||
}
|
||||
|
||||
// Exclude Docker's embedded DNS server, as it may cause recursion if
|
||||
// the container is set as the host system's default DNS server.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/3064.
|
||||
//
|
||||
// TODO(a.garipov): Perhaps only do this when we are in the container?
|
||||
// Maybe use an environment variable?
|
||||
if host == dockerEmbeddedDNS {
|
||||
return nil, errFakeDial
|
||||
}
|
||||
|
||||
err = validateDialedHost(host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating dialed host: %w", err)
|
||||
}
|
||||
|
||||
sr.addrsLock.Lock()
|
||||
defer sr.addrsLock.Unlock()
|
||||
|
||||
sr.addrs.Add(host)
|
||||
|
||||
return nil, errFakeDial
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) Get() (rs []string) {
|
||||
sr.addrsLock.RLock()
|
||||
defer sr.addrsLock.RUnlock()
|
||||
|
||||
return sr.addrs.Values()
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func createTestSystemResolversImpl(
|
||||
t *testing.T,
|
||||
hostGenFunc HostGenFunc,
|
||||
) (imp *systemResolvers) {
|
||||
t.Helper()
|
||||
|
||||
sr := createTestSystemResolvers(t, hostGenFunc)
|
||||
|
||||
return testutil.RequireTypeAssert[*systemResolvers](t, sr)
|
||||
}
|
||||
|
||||
func TestSystemResolvers_Refresh(t *testing.T) {
|
||||
t.Run("expected_error", func(t *testing.T) {
|
||||
sr := createTestSystemResolvers(t, nil)
|
||||
|
||||
assert.NoError(t, sr.refresh())
|
||||
})
|
||||
|
||||
t.Run("unexpected_error", func(t *testing.T) {
|
||||
_, err := NewSystemResolvers(func() string {
|
||||
return "127.0.0.1::123"
|
||||
})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSystemResolvers_DialFunc(t *testing.T) {
|
||||
imp := createTestSystemResolversImpl(t, nil)
|
||||
|
||||
testCases := []struct {
|
||||
want error
|
||||
name string
|
||||
address string
|
||||
}{{
|
||||
want: errFakeDial,
|
||||
name: "valid_ipv4",
|
||||
address: "127.0.0.1",
|
||||
}, {
|
||||
want: errFakeDial,
|
||||
name: "valid_ipv6_port",
|
||||
address: "[::1]:53",
|
||||
}, {
|
||||
want: errFakeDial,
|
||||
name: "valid_ipv6_zone_port",
|
||||
address: "[::1%lo0]:53",
|
||||
}, {
|
||||
want: errBadAddrPassed,
|
||||
name: "invalid_split_host",
|
||||
address: "127.0.0.1::123",
|
||||
}, {
|
||||
want: errUnexpectedHostFormat,
|
||||
name: "invalid_ipv6_zone_port",
|
||||
address: "[::1%%lo0]:53",
|
||||
}, {
|
||||
want: errBadAddrPassed,
|
||||
name: "invalid_parse_ip",
|
||||
address: "not-ip",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
conn, err := imp.dialFunc(context.Background(), "", tc.address)
|
||||
require.Nil(t, conn)
|
||||
|
||||
assert.ErrorIs(t, err, tc.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func createTestSystemResolvers(
|
||||
t *testing.T,
|
||||
hostGenFunc HostGenFunc,
|
||||
) (sr SystemResolvers) {
|
||||
t.Helper()
|
||||
|
||||
var err error
|
||||
sr, err = NewSystemResolvers(hostGenFunc)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sr)
|
||||
|
||||
return sr
|
||||
}
|
||||
|
||||
func TestSystemResolvers_Get(t *testing.T) {
|
||||
sr := createTestSystemResolvers(t, nil)
|
||||
|
||||
var rs []string
|
||||
require.NotPanics(t, func() {
|
||||
rs = sr.Get()
|
||||
})
|
||||
|
||||
assert.NotEmpty(t, rs)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Write tests for refreshWithTicker.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.
|
||||
@@ -1,163 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// systemResolvers implementation differs for Windows since Go's resolver
|
||||
// doesn't work there.
|
||||
//
|
||||
// See https://github.com/golang/go/issues/33097.
|
||||
type systemResolvers struct {
|
||||
// addrs is the slice of cached local resolvers' addresses.
|
||||
addrs []string
|
||||
addrsLock sync.RWMutex
|
||||
}
|
||||
|
||||
func newSystemResolvers(_ HostGenFunc) (sr SystemResolvers) {
|
||||
return &systemResolvers{}
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) Get() (rs []string) {
|
||||
sr.addrsLock.RLock()
|
||||
defer sr.addrsLock.RUnlock()
|
||||
|
||||
addrs := sr.addrs
|
||||
rs = make([]string, len(addrs))
|
||||
copy(rs, addrs)
|
||||
|
||||
return rs
|
||||
}
|
||||
|
||||
// writeExit writes "exit" to w and closes it. It is supposed to be run in
|
||||
// a goroutine.
|
||||
func writeExit(w io.WriteCloser) {
|
||||
defer log.OnPanic("systemResolvers: writeExit")
|
||||
|
||||
defer func() {
|
||||
derr := w.Close()
|
||||
if derr != nil {
|
||||
log.Error("systemResolvers: writeExit: closing: %s", derr)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err := io.WriteString(w, "exit")
|
||||
if err != nil {
|
||||
log.Error("systemResolvers: writeExit: writing: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// scanAddrs scans the DNS addresses from nslookup's output. The expected
|
||||
// output of nslookup looks like this:
|
||||
//
|
||||
// Default Server: 192-168-1-1.qualified.domain.ru
|
||||
// Address: 192.168.1.1
|
||||
func scanAddrs(s *bufio.Scanner) (addrs []string) {
|
||||
for s.Scan() {
|
||||
line := strings.TrimSpace(s.Text())
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) != 2 || fields[0] != "Address:" {
|
||||
continue
|
||||
}
|
||||
|
||||
// If the address contains port then it is separated with '#'.
|
||||
ipPort := strings.Split(fields[1], "#")
|
||||
if len(ipPort) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
addr := ipPort[0]
|
||||
if net.ParseIP(addr) == nil {
|
||||
log.Debug("systemResolvers: %q is not a valid ip", addr)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
|
||||
return addrs
|
||||
}
|
||||
|
||||
// getAddrs gets local resolvers' addresses from OS in a special Windows way.
|
||||
//
|
||||
// TODO(e.burkov): This whole function needs more detailed research on getting
|
||||
// local resolvers addresses on Windows. We execute the external command for
|
||||
// now that is not the most accurate way.
|
||||
func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
|
||||
var cmdPath string
|
||||
cmdPath, err = exec.LookPath("nslookup.exe")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("looking up cmd path: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command(cmdPath)
|
||||
|
||||
var stdin io.WriteCloser
|
||||
stdin, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting the command's stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
var stdout io.ReadCloser
|
||||
stdout, err = cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting the command's stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
go writeExit(stdin)
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start command executing: %w", err)
|
||||
}
|
||||
|
||||
s := bufio.NewScanner(stdout)
|
||||
addrs = scanAddrs(s)
|
||||
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("executing the command: %w", err)
|
||||
}
|
||||
|
||||
err = s.Err()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scanning output: %w", err)
|
||||
}
|
||||
|
||||
// Don't close StdoutPipe since Wait do it for us in ¿most? cases.
|
||||
//
|
||||
// See go doc os/exec.Cmd.StdoutPipe.
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) refresh() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
|
||||
|
||||
got, err := sr.getAddrs()
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't get addresses: %w", err)
|
||||
}
|
||||
if len(got) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sr.addrsLock.Lock()
|
||||
defer sr.addrsLock.Unlock()
|
||||
|
||||
sr.addrs = got
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
// TODO(e.burkov): Write tests for Windows implementation.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -72,11 +72,10 @@ func TestLargestLabeled(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("scanner_fail", func(t *testing.T) {
|
||||
lr, err := aghio.LimitReader(bytes.NewReader([]byte{1, 2, 3}), 0)
|
||||
require.NoError(t, err)
|
||||
lr := ioutil.LimitReader(bytes.NewReader([]byte{1, 2, 3}), 0)
|
||||
|
||||
target := &aghio.LimitReachedError{}
|
||||
_, _, err = parsePSOutput(lr, "", nil)
|
||||
target := &ioutil.LimitError{}
|
||||
_, _, err := parsePSOutput(lr, "", nil)
|
||||
require.ErrorAs(t, err, &target)
|
||||
|
||||
assert.EqualValues(t, 0, target.Limit)
|
||||
|
||||
@@ -4,10 +4,14 @@ package aghtest
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -51,3 +55,19 @@ func HostToIPs(host string) (ipv4, ipv6 netip.Addr) {
|
||||
|
||||
return netip.AddrFrom4([4]byte(hash[:4])), netip.AddrFrom16([16]byte(hash[4:20]))
|
||||
}
|
||||
|
||||
// StartHTTPServer is a helper that starts the HTTP server, which is configured
|
||||
// to return data on every request, and returns the client and server URL.
|
||||
func StartHTTPServer(t testing.TB, data []byte) (c *http.Client, u *url.URL) {
|
||||
t.Helper()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write(data)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
u, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
return srv.Client(), u
|
||||
}
|
||||
|
||||
@@ -57,6 +57,9 @@ type DHCPServer interface {
|
||||
// RemoveStaticLease - remove a static lease
|
||||
RemoveStaticLease(l *Lease) (err error)
|
||||
|
||||
// UpdateStaticLease updates IP, hostname of the lease.
|
||||
UpdateStaticLease(l *Lease) (err error)
|
||||
|
||||
// FindMACbyIP returns a MAC address by the IP address of its lease, if
|
||||
// there is one.
|
||||
FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr)
|
||||
|
||||
@@ -5,6 +5,7 @@ package dhcpd
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -290,12 +291,12 @@ func (s *server) handleDHCPSetConfigV6(
|
||||
func (s *server) createServers(conf *dhcpServerConfigJSON) (srv4, srv6 DHCPServer, err error) {
|
||||
srv4, v4Enabled, err := s.handleDHCPSetConfigV4(conf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("bad dhcpv4 configuration: %s", err)
|
||||
return nil, nil, fmt.Errorf("bad dhcpv4 configuration: %w", err)
|
||||
}
|
||||
|
||||
srv6, v6Enabled, err := s.handleDHCPSetConfigV6(conf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("bad dhcpv6 configuration: %s", err)
|
||||
return nil, nil, fmt.Errorf("bad dhcpv6 configuration: %w", err)
|
||||
}
|
||||
|
||||
if conf.Enabled == aghalg.NBTrue && !v4Enabled && !v6Enabled {
|
||||
@@ -424,7 +425,7 @@ func newNetInterfaceJSON(iface net.Interface) (out *netInterfaceJSON, err error)
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to get addresses for interface %s: %s",
|
||||
"failed to get addresses for interface %s: %w",
|
||||
iface.Name,
|
||||
err,
|
||||
)
|
||||
@@ -590,82 +591,78 @@ func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
// parseLease parses a lease from r. If there is no error returns DHCPServer
|
||||
// and *Lease. r must be non-nil.
|
||||
func (s *server) parseLease(r io.Reader) (srv DHCPServer, lease *Lease, err error) {
|
||||
l := &leaseStatic{}
|
||||
err := json.NewDecoder(r.Body).Decode(l)
|
||||
err = json.NewDecoder(r).Decode(l)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
return nil, nil, fmt.Errorf("decoding json: %w", err)
|
||||
}
|
||||
|
||||
if !l.IP.IsValid() {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
|
||||
|
||||
return
|
||||
return nil, nil, errors.Error("invalid ip")
|
||||
}
|
||||
|
||||
l.IP = l.IP.Unmap()
|
||||
|
||||
var srv DHCPServer
|
||||
if l.IP.Is4() {
|
||||
lease, err = l.toLease()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parsing: %w", err)
|
||||
}
|
||||
|
||||
if lease.IP.Is4() {
|
||||
srv = s.srv4
|
||||
} else {
|
||||
srv = s.srv6
|
||||
}
|
||||
|
||||
lease, err := l.toLease()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "parsing: %s", err)
|
||||
return srv, lease, nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = srv.AddStaticLease(lease)
|
||||
// handleDHCPAddStaticLease is the handler for the POST
|
||||
// /control/dhcp/add_static_lease HTTP API.
|
||||
func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
srv, lease, err := s.parseLease(r.Body)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err = srv.AddStaticLease(lease); err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleDHCPRemoveStaticLease is the handler for the POST
|
||||
// /control/dhcp/remove_static_lease HTTP API.
|
||||
func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
l := &leaseStatic{}
|
||||
err := json.NewDecoder(r.Body).Decode(l)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !l.IP.IsValid() {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
l.IP = l.IP.Unmap()
|
||||
|
||||
var srv DHCPServer
|
||||
if l.IP.Is4() {
|
||||
srv = s.srv4
|
||||
} else {
|
||||
srv = s.srv6
|
||||
}
|
||||
|
||||
lease, err := l.toLease()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "parsing: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = srv.RemoveStaticLease(lease)
|
||||
srv, lease, err := s.parseLease(r.Body)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err = srv.RemoveStaticLease(lease); err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleDHCPUpdateStaticLease is the handler for the POST
|
||||
// /control/dhcp/update_static_lease HTTP API.
|
||||
func (s *server) handleDHCPUpdateStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
srv, lease, err := s.parseLease(r.Body)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err = srv.UpdateStaticLease(lease); err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -729,6 +726,7 @@ func (s *server) registerHandlers() {
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.handleDHCPFindActiveServer)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.handleDHCPAddStaticLease)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.handleDHCPRemoveStaticLease)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/update_static_lease", s.handleDHCPUpdateStaticLease)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.handleReset)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.handleResetLeases)
|
||||
}
|
||||
|
||||
319
internal/dhcpd/http_unix_internal_test.go
Normal file
319
internal/dhcpd/http_unix_internal_test.go
Normal file
@@ -0,0 +1,319 @@
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// defaultResponse is a helper that returns the response with default
|
||||
// configuration.
|
||||
func defaultResponse() *dhcpStatusResponse {
|
||||
conf4 := defaultV4ServerConf()
|
||||
conf4.LeaseDuration = 86400
|
||||
|
||||
resp := &dhcpStatusResponse{
|
||||
V4: *conf4,
|
||||
V6: V6ServerConf{},
|
||||
Leases: []*leaseDynamic{},
|
||||
StaticLeases: []*leaseStatic{},
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// handleLease is the helper function that calls handler with provided static
|
||||
// lease as body and returns modified response recorder.
|
||||
func handleLease(t *testing.T, lease *leaseStatic, handler http.HandlerFunc) (w *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err := json.NewEncoder(b).Encode(lease)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler(w, r)
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
// checkStatus is a helper that asserts the response of
|
||||
// [*server.handleDHCPStatus].
|
||||
func checkStatus(t *testing.T, s *server, want *dhcpStatusResponse) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err := json.NewEncoder(b).Encode(&want)
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPStatus(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
assert.JSONEq(t, b.String(), w.Body.String())
|
||||
}
|
||||
|
||||
func TestServer_handleDHCPStatus(t *testing.T) {
|
||||
const (
|
||||
staticName = "static-client"
|
||||
staticMAC = "aa:aa:aa:aa:aa:aa"
|
||||
)
|
||||
|
||||
staticIP := netip.MustParseAddr("192.168.10.10")
|
||||
|
||||
staticLease := &leaseStatic{
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
Hostname: staticName,
|
||||
}
|
||||
|
||||
s, err := Create(&ServerConfig{
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
DataDir: t.TempDir(),
|
||||
ConfigModified: func() {},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ok := t.Run("status", func(t *testing.T) {
|
||||
resp := defaultResponse()
|
||||
|
||||
checkStatus(t, s, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("add_static_lease", func(t *testing.T) {
|
||||
w := handleLease(t, staticLease, s.handleDHCPAddStaticLease)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp := defaultResponse()
|
||||
resp.StaticLeases = []*leaseStatic{staticLease}
|
||||
|
||||
checkStatus(t, s, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("add_invalid_lease", func(t *testing.T) {
|
||||
w := handleLease(t, staticLease, s.handleDHCPAddStaticLease)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("remove_static_lease", func(t *testing.T) {
|
||||
w := handleLease(t, staticLease, s.handleDHCPRemoveStaticLease)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp := defaultResponse()
|
||||
|
||||
checkStatus(t, s, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("set_config", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
resp := defaultResponse()
|
||||
resp.Enabled = false
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPSetConfig(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
checkStatus(t, s, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestServer_HandleUpdateStaticLease(t *testing.T) {
|
||||
const (
|
||||
leaseV4Name = "static-client-v4"
|
||||
leaseV4MAC = "44:44:44:44:44:44"
|
||||
|
||||
leaseV6Name = "static-client-v6"
|
||||
leaseV6MAC = "66:66:66:66:66:66"
|
||||
)
|
||||
|
||||
leaseV4IP := netip.MustParseAddr("192.168.10.10")
|
||||
leaseV6IP := netip.MustParseAddr("2001::6")
|
||||
|
||||
const (
|
||||
leaseV4Pos = iota
|
||||
leaseV6Pos
|
||||
)
|
||||
|
||||
leases := []*leaseStatic{
|
||||
leaseV4Pos: {
|
||||
HWAddr: leaseV4MAC,
|
||||
IP: leaseV4IP,
|
||||
Hostname: leaseV4Name,
|
||||
},
|
||||
leaseV6Pos: {
|
||||
HWAddr: leaseV6MAC,
|
||||
IP: leaseV6IP,
|
||||
Hostname: leaseV6Name,
|
||||
},
|
||||
}
|
||||
|
||||
s, err := Create(&ServerConfig{
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
Conf6: V6ServerConf{},
|
||||
DataDir: t.TempDir(),
|
||||
ConfigModified: func() {},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, l := range leases {
|
||||
w := handleLease(t, l, s.handleDHCPAddStaticLease)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
pos int
|
||||
lease *leaseStatic
|
||||
}{{
|
||||
name: "update_v4_name",
|
||||
pos: leaseV4Pos,
|
||||
lease: &leaseStatic{
|
||||
HWAddr: leaseV4MAC,
|
||||
IP: leaseV4IP,
|
||||
Hostname: "updated-client-v4",
|
||||
},
|
||||
}, {
|
||||
name: "update_v4_ip",
|
||||
pos: leaseV4Pos,
|
||||
lease: &leaseStatic{
|
||||
HWAddr: leaseV4MAC,
|
||||
IP: netip.MustParseAddr("192.168.10.200"),
|
||||
Hostname: "updated-client-v4",
|
||||
},
|
||||
}, {
|
||||
name: "update_v6_name",
|
||||
pos: leaseV6Pos,
|
||||
lease: &leaseStatic{
|
||||
HWAddr: leaseV6MAC,
|
||||
IP: leaseV6IP,
|
||||
Hostname: "updated-client-v6",
|
||||
},
|
||||
}, {
|
||||
name: "update_v6_ip",
|
||||
pos: leaseV6Pos,
|
||||
lease: &leaseStatic{
|
||||
HWAddr: leaseV6MAC,
|
||||
IP: netip.MustParseAddr("2001::666"),
|
||||
Hostname: "updated-client-v6",
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
w := handleLease(t, tc.lease, s.handleDHCPUpdateStaticLease)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp := defaultResponse()
|
||||
leases[tc.pos] = tc.lease
|
||||
resp.StaticLeases = leases
|
||||
|
||||
checkStatus(t, s, resp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_HandleUpdateStaticLease_validation(t *testing.T) {
|
||||
const (
|
||||
leaseV4Name = "static-client-v4"
|
||||
leaseV4MAC = "44:44:44:44:44:44"
|
||||
|
||||
anotherV4Name = "another-client-v4"
|
||||
anotherV4MAC = "55:55:55:55:55:55"
|
||||
)
|
||||
|
||||
leaseV4IP := netip.MustParseAddr("192.168.10.10")
|
||||
anotherV4IP := netip.MustParseAddr("192.168.10.20")
|
||||
|
||||
leases := []*leaseStatic{{
|
||||
HWAddr: leaseV4MAC,
|
||||
IP: leaseV4IP,
|
||||
Hostname: leaseV4Name,
|
||||
}, {
|
||||
HWAddr: anotherV4MAC,
|
||||
IP: anotherV4IP,
|
||||
Hostname: anotherV4Name,
|
||||
}}
|
||||
|
||||
s, err := Create(&ServerConfig{
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
Conf6: V6ServerConf{},
|
||||
DataDir: t.TempDir(),
|
||||
ConfigModified: func() {},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, l := range leases {
|
||||
w := handleLease(t, l, s.handleDHCPAddStaticLease)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
lease *leaseStatic
|
||||
name string
|
||||
want string
|
||||
}{{
|
||||
name: "v4_unknown_mac",
|
||||
lease: &leaseStatic{
|
||||
HWAddr: "aa:aa:aa:aa:aa:aa",
|
||||
IP: leaseV4IP,
|
||||
Hostname: leaseV4Name,
|
||||
},
|
||||
want: "dhcpv4: updating static lease: can't find lease aa:aa:aa:aa:aa:aa\n",
|
||||
}, {
|
||||
name: "update_v4_same_ip",
|
||||
lease: &leaseStatic{
|
||||
HWAddr: leaseV4MAC,
|
||||
IP: anotherV4IP,
|
||||
Hostname: leaseV4Name,
|
||||
},
|
||||
want: "dhcpv4: updating static lease: ip address is not unique\n",
|
||||
}, {
|
||||
name: "update_v4_same_name",
|
||||
lease: &leaseStatic{
|
||||
HWAddr: leaseV4MAC,
|
||||
IP: leaseV4IP,
|
||||
Hostname: anotherV4Name,
|
||||
},
|
||||
want: "dhcpv4: updating static lease: hostname is not unique\n",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
w := handleLease(t, tc.lease, s.handleDHCPUpdateStaticLease)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Equal(t, tc.want, w.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,159 +0,0 @@
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_handleDHCPStatus(t *testing.T) {
|
||||
const (
|
||||
staticName = "static-client"
|
||||
staticMAC = "aa:aa:aa:aa:aa:aa"
|
||||
)
|
||||
|
||||
staticIP := netip.MustParseAddr("192.168.10.10")
|
||||
|
||||
staticLease := &leaseStatic{
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
Hostname: staticName,
|
||||
}
|
||||
|
||||
s, err := Create(&ServerConfig{
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
DataDir: t.TempDir(),
|
||||
ConfigModified: func() {},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// checkStatus is a helper that asserts the response of
|
||||
// [*server.handleDHCPStatus].
|
||||
checkStatus := func(t *testing.T, want *dhcpStatusResponse) {
|
||||
w := httptest.NewRecorder()
|
||||
var req *http.Request
|
||||
req, err = http.NewRequest(http.MethodGet, "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(&want)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPStatus(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
assert.JSONEq(t, b.String(), w.Body.String())
|
||||
}
|
||||
|
||||
// defaultResponse is a helper that returns the response with default
|
||||
// configuration.
|
||||
defaultResponse := func() *dhcpStatusResponse {
|
||||
conf4 := defaultV4ServerConf()
|
||||
conf4.LeaseDuration = 86400
|
||||
|
||||
resp := &dhcpStatusResponse{
|
||||
V4: *conf4,
|
||||
V6: V6ServerConf{},
|
||||
Leases: []*leaseDynamic{},
|
||||
StaticLeases: []*leaseStatic{},
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
ok := t.Run("status", func(t *testing.T) {
|
||||
resp := defaultResponse()
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("add_static_lease", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(staticLease)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPAddStaticLease(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp := defaultResponse()
|
||||
resp.StaticLeases = []*leaseStatic{staticLease}
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("add_invalid_lease", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
|
||||
err = json.NewEncoder(b).Encode(&leaseStatic{})
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPAddStaticLease(w, r)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("remove_static_lease", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(staticLease)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPRemoveStaticLease(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp := defaultResponse()
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
|
||||
ok = t.Run("set_config", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
resp := defaultResponse()
|
||||
resp.Enabled = false
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err = json.NewEncoder(b).Encode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", b)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.handleDHCPSetConfig(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
checkStatus(t, resp)
|
||||
})
|
||||
require.True(t, ok)
|
||||
}
|
||||
@@ -43,6 +43,7 @@ func (s *server) registerHandlers() {
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/update_static_lease", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.notImplemented)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ func (winServer) GetLeases(_ GetLeasesFlags) (leases []*Lease) { return nil }
|
||||
func (winServer) getLeasesRef() []*Lease { return nil }
|
||||
func (winServer) AddStaticLease(_ *Lease) (err error) { return nil }
|
||||
func (winServer) RemoveStaticLease(_ *Lease) (err error) { return nil }
|
||||
func (winServer) UpdateStaticLease(_ *Lease) (err error) { return nil }
|
||||
func (winServer) FindMACbyIP(_ netip.Addr) (mac net.HardwareAddr) { return nil }
|
||||
func (winServer) WriteDiskConfig4(_ *V4ServerConf) {}
|
||||
func (winServer) WriteDiskConfig6(_ *V6ServerConf) {}
|
||||
|
||||
@@ -148,6 +148,9 @@ func (s *v4Server) ResetLeases(leases []*Lease) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
s.leasedOffsets = newBitSet()
|
||||
s.hostsIndex = make(map[string]*Lease, len(leases))
|
||||
s.ipIndex = make(map[netip.Addr]*Lease, len(leases))
|
||||
@@ -182,16 +185,13 @@ func (s *v4Server) isBlocklisted(l *Lease) (ok bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
ok = true
|
||||
for _, b := range l.HWAddr {
|
||||
if b != 0 {
|
||||
ok = false
|
||||
|
||||
break
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return ok
|
||||
return true
|
||||
}
|
||||
|
||||
// GetLeases returns the list of current DHCP leases. It is safe for concurrent
|
||||
@@ -309,9 +309,15 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrDupHostname is returned by addLease when the added lease has a not empty
|
||||
// non-unique hostname.
|
||||
const ErrDupHostname = errors.Error("hostname is not unique")
|
||||
const (
|
||||
// ErrDupHostname is returned by addLease, validateStaticLease when the
|
||||
// modified lease has a not empty non-unique hostname.
|
||||
ErrDupHostname = errors.Error("hostname is not unique")
|
||||
|
||||
// ErrDupIP is returned by addLease, validateStaticLease when the modified
|
||||
// lease has a non-unique IP address.
|
||||
ErrDupIP = errors.Error("ip address is not unique")
|
||||
)
|
||||
|
||||
// addLease adds a dynamic or static lease.
|
||||
func (s *v4Server) addLease(l *Lease) (err error) {
|
||||
@@ -428,6 +434,81 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStaticLease updates IP, hostname of the static lease.
|
||||
func (s *v4Server) UpdateStaticLease(l *Lease) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = errors.Annotate(err, "dhcpv4: updating static lease: %w")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.conf.notify(LeaseChangedDBStore)
|
||||
s.conf.notify(LeaseChangedRemovedStatic)
|
||||
}()
|
||||
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
found := s.findLease(l.HWAddr)
|
||||
if found == nil {
|
||||
return fmt.Errorf("can't find lease %s", l.HWAddr)
|
||||
}
|
||||
|
||||
err = s.validateStaticLease(l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.rmLease(found)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing previous lease for %s (%s): %w", l.IP, l.HWAddr, err)
|
||||
}
|
||||
|
||||
err = s.addLease(l)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding updated static lease for %s (%s): %w", l.IP, l.HWAddr, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateStaticLease returns an error if the static lease is invalid.
|
||||
func (s *v4Server) validateStaticLease(l *Lease) (err error) {
|
||||
hostname, err := normalizeHostname(l.Hostname)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = netutil.ValidateHostname(hostname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating hostname: %w", err)
|
||||
}
|
||||
|
||||
dup, ok := s.hostsIndex[hostname]
|
||||
if ok && !bytes.Equal(dup.HWAddr, l.HWAddr) {
|
||||
return ErrDupHostname
|
||||
}
|
||||
|
||||
dup, ok = s.ipIndex[l.IP]
|
||||
if ok && !bytes.Equal(dup.HWAddr, l.HWAddr) {
|
||||
return ErrDupIP
|
||||
}
|
||||
|
||||
l.Hostname = hostname
|
||||
|
||||
if gwIP := s.conf.GatewayIP; gwIP == l.IP {
|
||||
return fmt.Errorf("can't assign the gateway IP %q to the lease", gwIP)
|
||||
}
|
||||
|
||||
if sn := s.conf.subnet; !sn.Contains(l.IP) {
|
||||
return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateStaticLease safe removes dynamic lease with the same properties and
|
||||
// then adds a static lease l.
|
||||
func (s *v4Server) updateStaticLease(l *Lease) (err error) {
|
||||
|
||||
@@ -90,6 +90,9 @@ func (s *v6Server) IPByHost(host string) (ip netip.Addr) {
|
||||
func (s *v6Server) ResetLeases(leases []*Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
s.leases = nil
|
||||
for _, l := range leases {
|
||||
ip := net.IP(l.IP.AsSlice())
|
||||
@@ -232,6 +235,37 @@ func (s *v6Server) AddStaticLease(l *Lease) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStaticLease updates IP, hostname of the static lease.
|
||||
func (s *v6Server) UpdateStaticLease(l *Lease) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = errors.Annotate(err, "dhcpv6: updating static lease: %w")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.conf.notify(LeaseChangedDBStore)
|
||||
s.conf.notify(LeaseChangedRemovedStatic)
|
||||
}()
|
||||
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
found := s.findLease(l.HWAddr)
|
||||
if found == nil {
|
||||
return fmt.Errorf("can't find lease %s", l.HWAddr)
|
||||
}
|
||||
|
||||
err = s.rmLease(found)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing previous lease for %s (%s): %w", l.IP, l.HWAddr, err)
|
||||
}
|
||||
|
||||
s.addLease(l)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveStaticLease removes a static lease. It is safe for concurrent use.
|
||||
func (s *v6Server) RemoveStaticLease(l *Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
@@ -283,16 +317,14 @@ func (s *v6Server) rmLease(lease *Lease) (err error) {
|
||||
return fmt.Errorf("lease not found")
|
||||
}
|
||||
|
||||
// Find lease by MAC
|
||||
func (s *v6Server) findLease(mac net.HardwareAddr) *Lease {
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
// Find lease by MAC.
|
||||
func (s *v6Server) findLease(mac net.HardwareAddr) (lease *Lease) {
|
||||
for i := range s.leases {
|
||||
if bytes.Equal(mac, s.leases[i].HWAddr) {
|
||||
return s.leases[i]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -474,7 +506,14 @@ func (s *v6Server) process(msg *dhcpv6.Message, req, resp dhcpv6.DHCPv6) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
lease := s.findLease(mac)
|
||||
var lease *Lease
|
||||
func() {
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
lease = s.findLease(mac)
|
||||
}()
|
||||
|
||||
if lease == nil {
|
||||
log.Debug("dhcpv6: no lease for: %s", mac)
|
||||
|
||||
|
||||
@@ -12,10 +12,12 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -392,6 +394,122 @@ func (s *Server) prepareIpsetListSettings() (err error) {
|
||||
return s.ipset.init(ipsets)
|
||||
}
|
||||
|
||||
// collectListenAddr adds addrPort to addrs. It also adds its port to
|
||||
// unspecPorts if its address is unspecified.
|
||||
func collectListenAddr(
|
||||
addrPort netip.AddrPort,
|
||||
addrs map[netip.AddrPort]unit,
|
||||
unspecPorts map[uint16]unit,
|
||||
) {
|
||||
if addrPort == (netip.AddrPort{}) {
|
||||
return
|
||||
}
|
||||
|
||||
addrs[addrPort] = unit{}
|
||||
if addrPort.Addr().IsUnspecified() {
|
||||
unspecPorts[addrPort.Port()] = unit{}
|
||||
}
|
||||
}
|
||||
|
||||
// collectDNSAddrs returns configured set of listening addresses. It also
|
||||
// returns a set of ports of each unspecified listening address.
|
||||
func (conf *ServerConfig) collectDNSAddrs() (
|
||||
addrs map[netip.AddrPort]unit,
|
||||
unspecPorts map[uint16]unit,
|
||||
) {
|
||||
// TODO(e.burkov): Perhaps, we shouldn't allocate as much memory, since the
|
||||
// TCP and UDP listening addresses are currently the same.
|
||||
addrs = make(map[netip.AddrPort]unit, len(conf.TCPListenAddrs)+len(conf.UDPListenAddrs))
|
||||
unspecPorts = map[uint16]unit{}
|
||||
|
||||
for _, laddr := range conf.TCPListenAddrs {
|
||||
collectListenAddr(laddr.AddrPort(), addrs, unspecPorts)
|
||||
}
|
||||
|
||||
for _, laddr := range conf.UDPListenAddrs {
|
||||
collectListenAddr(laddr.AddrPort(), addrs, unspecPorts)
|
||||
}
|
||||
|
||||
return addrs, unspecPorts
|
||||
}
|
||||
|
||||
// defaultPlainDNSPort is the default port for plain DNS.
|
||||
const defaultPlainDNSPort uint16 = 53
|
||||
|
||||
// addrPortMatcher is a function that matches an IP address with port.
|
||||
type addrPortMatcher func(addr netip.AddrPort) (ok bool)
|
||||
|
||||
// filterOut filters out all the upstreams that match um. It returns all the
|
||||
// closing errors joined.
|
||||
func (m addrPortMatcher) filterOut(upsConf *proxy.UpstreamConfig) (err error) {
|
||||
var errs []error
|
||||
delFunc := func(u upstream.Upstream) (ok bool) {
|
||||
// TODO(e.burkov): We should probably consider the protocol of u to
|
||||
// only filter out the listening addresses of the same protocol.
|
||||
addr, parseErr := aghnet.ParseAddrPort(u.Address(), defaultPlainDNSPort)
|
||||
if parseErr != nil || !m(addr) {
|
||||
// Don't filter out the upstream if it either cannot be parsed, or
|
||||
// does not match um.
|
||||
return false
|
||||
}
|
||||
|
||||
errs = append(errs, u.Close())
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
upsConf.Upstreams = slices.DeleteFunc(upsConf.Upstreams, delFunc)
|
||||
for d, ups := range upsConf.DomainReservedUpstreams {
|
||||
upsConf.DomainReservedUpstreams[d] = slices.DeleteFunc(ups, delFunc)
|
||||
}
|
||||
for d, ups := range upsConf.SpecifiedDomainUpstreams {
|
||||
upsConf.SpecifiedDomainUpstreams[d] = slices.DeleteFunc(ups, delFunc)
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// ourAddrsMatcher returns a matcher that matches all the configured listening
|
||||
// addresses.
|
||||
func (conf *ServerConfig) ourAddrsMatcher() (m addrPortMatcher, err error) {
|
||||
addrs, unspecPorts := conf.collectDNSAddrs()
|
||||
if len(addrs) == 0 {
|
||||
log.Debug("dnsforward: no listen addresses")
|
||||
|
||||
// Match no addresses.
|
||||
return func(_ netip.AddrPort) (ok bool) { return false }, nil
|
||||
}
|
||||
|
||||
if len(unspecPorts) == 0 {
|
||||
log.Debug("dnsforward: filtering out addresses %s", addrs)
|
||||
|
||||
m = func(a netip.AddrPort) (ok bool) {
|
||||
_, ok = addrs[a]
|
||||
|
||||
return ok
|
||||
}
|
||||
} else {
|
||||
var ifaceAddrs []netip.Addr
|
||||
ifaceAddrs, err = aghnet.CollectAllIfacesAddrs()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: filtering out addresses %s on ports %d", ifaceAddrs, unspecPorts)
|
||||
|
||||
m = func(a netip.AddrPort) (ok bool) {
|
||||
if _, ok = unspecPorts[a.Port()]; ok {
|
||||
return slices.Contains(ifaceAddrs, a.Addr())
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// prepareTLS - prepares TLS configuration for the DNS proxy
|
||||
func (s *Server) prepareTLS(proxyConfig *proxy.Config) (err error) {
|
||||
if len(s.conf.CertificateChainData) == 0 || len(s.conf.PrivateKeyData) == 0 {
|
||||
|
||||
@@ -25,8 +25,10 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/sysresolv"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// DefaultTimeout is the default upstream timeout
|
||||
@@ -72,6 +74,11 @@ type DHCP interface {
|
||||
Enabled() (ok bool)
|
||||
}
|
||||
|
||||
type SystemResolvers interface {
|
||||
// Addrs returns the list of system resolvers' addresses.
|
||||
Addrs() (addrs []netip.AddrPort)
|
||||
}
|
||||
|
||||
// Server is the main way to start a DNS server.
|
||||
//
|
||||
// Example:
|
||||
@@ -126,7 +133,7 @@ type Server struct {
|
||||
|
||||
// sysResolvers used to fetch system resolvers to use by default for private
|
||||
// PTR resolving.
|
||||
sysResolvers aghnet.SystemResolvers
|
||||
sysResolvers SystemResolvers
|
||||
|
||||
// recDetector is a cache for recursive requests. It is used to detect
|
||||
// and prevent recursive requests only for private upstreams.
|
||||
@@ -225,9 +232,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
anonymizer: p.Anonymizer,
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Enable the refresher after the actual implementation
|
||||
// passes the public testing.
|
||||
s.sysResolvers, err = aghnet.NewSystemResolvers(nil)
|
||||
s.sysResolvers, err = sysresolv.NewSystemResolvers(nil, defaultPlainDNSPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing system resolvers: %w", err)
|
||||
}
|
||||
@@ -439,73 +444,28 @@ func (s *Server) startLocked() error {
|
||||
// faster than ordinary upstreams.
|
||||
const defaultLocalTimeout = 1 * time.Second
|
||||
|
||||
// collectDNSIPAddrs returns IP addresses the server is listening on without
|
||||
// port numbers. For internal use only.
|
||||
func (s *Server) collectDNSIPAddrs() (addrs []string, err error) {
|
||||
addrs = make([]string, len(s.conf.TCPListenAddrs)+len(s.conf.UDPListenAddrs))
|
||||
var i int
|
||||
var ip net.IP
|
||||
for _, addr := range s.conf.TCPListenAddrs {
|
||||
if addr == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if ip = addr.IP; ip.IsUnspecified() {
|
||||
return aghnet.CollectAllIfacesAddrs()
|
||||
}
|
||||
|
||||
addrs[i] = ip.String()
|
||||
i++
|
||||
}
|
||||
for _, addr := range s.conf.UDPListenAddrs {
|
||||
if addr == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if ip = addr.IP; ip.IsUnspecified() {
|
||||
return aghnet.CollectAllIfacesAddrs()
|
||||
}
|
||||
|
||||
addrs[i] = ip.String()
|
||||
i++
|
||||
}
|
||||
|
||||
return addrs[:i], nil
|
||||
}
|
||||
|
||||
func (s *Server) filterOurDNSAddrs(addrs []string) (filtered []string, err error) {
|
||||
var ourAddrs []string
|
||||
ourAddrs, err = s.collectDNSIPAddrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ourAddrsSet := stringutil.NewSet(ourAddrs...)
|
||||
log.Debug("dnsforward: filtering out %s", ourAddrsSet.String())
|
||||
|
||||
// TODO(e.burkov): The approach of subtracting sets of strings is not
|
||||
// really applicable here since in case of listening on all network
|
||||
// interfaces we should check the whole interface's network to cut off
|
||||
// all the loopback addresses as well.
|
||||
return stringutil.FilterOut(addrs, ourAddrsSet.Has), nil
|
||||
}
|
||||
|
||||
// setupLocalResolvers initializes the resolvers for local addresses. For
|
||||
// internal use only.
|
||||
func (s *Server) setupLocalResolvers() (err error) {
|
||||
bootstraps := s.conf.BootstrapDNS
|
||||
resolvers := s.conf.LocalPTRResolvers
|
||||
|
||||
if len(resolvers) == 0 {
|
||||
resolvers = s.sysResolvers.Get()
|
||||
bootstraps = nil
|
||||
} else {
|
||||
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
|
||||
matcher, err := s.conf.ourAddrsMatcher()
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
resolvers, err = s.filterOurDNSAddrs(resolvers)
|
||||
if err != nil {
|
||||
return err
|
||||
bootstraps := s.conf.BootstrapDNS
|
||||
resolvers := s.conf.LocalPTRResolvers
|
||||
filterConfig := false
|
||||
|
||||
if len(resolvers) == 0 {
|
||||
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher)
|
||||
resolvers = make([]string, 0, len(sysResolvers))
|
||||
for _, r := range sysResolvers {
|
||||
resolvers = append(resolvers, r.String())
|
||||
}
|
||||
} else {
|
||||
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
|
||||
filterConfig = true
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
|
||||
@@ -514,13 +474,18 @@ func (s *Server) setupLocalResolvers() (err error) {
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing private upstreams: %w", err)
|
||||
}
|
||||
|
||||
if filterConfig {
|
||||
if err = matcher.filterOut(uc); err != nil {
|
||||
return fmt.Errorf("filtering private upstreams: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.localResolvers = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UpstreamConfig: uc,
|
||||
|
||||
@@ -65,6 +65,9 @@ type jsonDNSConfig struct {
|
||||
// UpstreamMode defines the way DNS requests are constructed.
|
||||
UpstreamMode *string `json:"upstream_mode"`
|
||||
|
||||
// BlockedResponseTTL is the TTL for blocked responses.
|
||||
BlockedResponseTTL *uint32 `json:"blocked_response_ttl"`
|
||||
|
||||
// CacheSize in bytes.
|
||||
CacheSize *uint32 `json:"cache_size"`
|
||||
|
||||
@@ -115,6 +118,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
bootstraps := stringutil.CloneSliceOrEmpty(s.conf.BootstrapDNS)
|
||||
fallbacks := stringutil.CloneSliceOrEmpty(s.conf.FallbackDNS)
|
||||
blockingMode, blockingIPv4, blockingIPv6 := s.dnsFilter.BlockingMode()
|
||||
blockedResponseTTL := s.dnsFilter.BlockedResponseTTL()
|
||||
ratelimit := s.conf.Ratelimit
|
||||
|
||||
customIP := s.conf.EDNSClientSubnet.CustomIP
|
||||
@@ -138,9 +142,9 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
upstreamMode = "parallel"
|
||||
}
|
||||
|
||||
defLocalPTRUps, err := s.filterOurDNSAddrs(s.sysResolvers.Get())
|
||||
defPTRUps, err := s.defaultLocalPTRUpstreams()
|
||||
if err != nil {
|
||||
log.Debug("getting dns configuration: %s", err)
|
||||
log.Error("dnsforward: %s", err)
|
||||
}
|
||||
|
||||
return &jsonDNSConfig{
|
||||
@@ -158,6 +162,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
EDNSCSUseCustom: &useCustom,
|
||||
DNSSECEnabled: &enableDNSSEC,
|
||||
DisableIPv6: &aaaaDisabled,
|
||||
BlockedResponseTTL: &blockedResponseTTL,
|
||||
CacheSize: &cacheSize,
|
||||
CacheMinTTL: &cacheMinTTL,
|
||||
CacheMaxTTL: &cacheMaxTTL,
|
||||
@@ -166,11 +171,29 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
ResolveClients: &resolveClients,
|
||||
UsePrivateRDNS: &usePrivateRDNS,
|
||||
LocalPTRUpstreams: &localPTRUpstreams,
|
||||
DefaultLocalPTRUpstreams: defLocalPTRUps,
|
||||
DefaultLocalPTRUpstreams: defPTRUps,
|
||||
DisabledUntil: protectionDisabledUntil,
|
||||
}
|
||||
}
|
||||
|
||||
// defaultLocalPTRUpstreams returns the list of default local PTR resolvers
|
||||
// filtered of AdGuard Home's own DNS server addresses. It may appear empty.
|
||||
func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) {
|
||||
matcher, err := s.conf.ourAddrsMatcher()
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher)
|
||||
ups = make([]string, 0, len(sysResolvers))
|
||||
for _, r := range sysResolvers {
|
||||
ups = append(ups, r.String())
|
||||
}
|
||||
|
||||
return ups, nil
|
||||
}
|
||||
|
||||
// handleGetConfig handles requests to the GET /control/dns_info endpoint.
|
||||
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
resp := s.getDNSConfig()
|
||||
@@ -204,7 +227,7 @@ func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||
return errors.Error("empty")
|
||||
}
|
||||
|
||||
if _, err = upstream.NewResolver(b, nil); err != nil {
|
||||
if _, err = upstream.NewUpstreamResolver(b, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -321,6 +344,10 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
s.dnsFilter.SetBlockingMode(*dc.BlockingMode, dc.BlockingIPv4, dc.BlockingIPv6)
|
||||
}
|
||||
|
||||
if dc.BlockedResponseTTL != nil {
|
||||
s.dnsFilter.SetBlockedResponseTTL(*dc.BlockedResponseTTL)
|
||||
}
|
||||
|
||||
if dc.ProtectionEnabled != nil {
|
||||
s.dnsFilter.SetProtectionEnabled(*dc.ProtectionEnabled)
|
||||
}
|
||||
|
||||
@@ -28,17 +28,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeSystemResolvers is a mock aghnet.SystemResolvers implementation for
|
||||
// tests.
|
||||
type fakeSystemResolvers struct {
|
||||
// SystemResolvers is embedded here simply to make *fakeSystemResolvers
|
||||
// an aghnet.SystemResolvers without actually implementing all methods.
|
||||
aghnet.SystemResolvers
|
||||
}
|
||||
// emptySysResolvers is an empty [SystemResolvers] implementation that always
|
||||
// returns nil.
|
||||
type emptySysResolvers struct{}
|
||||
|
||||
// Get implements the aghnet.SystemResolvers interface for *fakeSystemResolvers.
|
||||
// It always returns nil.
|
||||
func (fsr *fakeSystemResolvers) Get() (rs []string) {
|
||||
// Addrs implements the aghnet.SystemResolvers interface for emptySysResolvers.
|
||||
func (emptySysResolvers) Addrs() (addrs []netip.AddrPort) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -60,6 +55,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
filterConf := &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
BlockedResponseTTL: 10,
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingCacheSize: 1000,
|
||||
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
@@ -78,7 +74,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s.sysResolvers = &fakeSystemResolvers{}
|
||||
s.sysResolvers = &emptySysResolvers{}
|
||||
|
||||
require.NoError(t, s.Start())
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
@@ -137,6 +133,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
filterConf := &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
BlockedResponseTTL: 10,
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingCacheSize: 1000,
|
||||
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
@@ -154,7 +151,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s.sysResolvers = &fakeSystemResolvers{}
|
||||
s.sysResolvers = &emptySysResolvers{}
|
||||
|
||||
defaultConf := s.conf
|
||||
|
||||
@@ -229,6 +226,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
}, {
|
||||
name: "fallbacks",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "blocked_response_ttl",
|
||||
wantSet: "",
|
||||
}}
|
||||
|
||||
var data map[string]struct {
|
||||
@@ -480,7 +480,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
hostsListener := newLocalUpstreamListener(t, 0, goodHandler)
|
||||
hostsUps := (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: netutil.JoinHostPort(upstreamHost, int(hostsListener.Port())),
|
||||
Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()),
|
||||
}).String()
|
||||
|
||||
hc, err := aghnet.NewHostsContainer(
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/ipset"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
@@ -15,14 +15,14 @@ import (
|
||||
|
||||
// ipsetCtx is the ipset context. ipsetMgr can be nil.
|
||||
type ipsetCtx struct {
|
||||
ipsetMgr aghnet.IpsetManager
|
||||
ipsetMgr ipset.Manager
|
||||
}
|
||||
|
||||
// init initializes the ipset context. It is not safe for concurrent use.
|
||||
//
|
||||
// TODO(a.garipov): Rewrite into a simple constructor?
|
||||
func (c *ipsetCtx) init(ipsetConf []string) (err error) {
|
||||
c.ipsetMgr, err = aghnet.NewIpsetManager(ipsetConf)
|
||||
c.ipsetMgr, err = ipset.NewManager(ipsetConf)
|
||||
if errors.Is(err, os.ErrInvalid) || errors.Is(err, os.ErrPermission) {
|
||||
// ipset cannot currently be initialized if the server was installed
|
||||
// from Snap or when the user or the binary doesn't have the required
|
||||
|
||||
@@ -114,3 +114,74 @@ func TestIpsetCtx_process(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIpsetCtx_SkipIpsetProcessing(t *testing.T) {
|
||||
req4 := createTestMessage("example.com")
|
||||
resp4 := &dns.Msg{
|
||||
Answer: []dns.RR{&dns.A{
|
||||
A: net.IP{1, 2, 3, 4},
|
||||
}},
|
||||
}
|
||||
|
||||
m := &fakeIpsetMgr{}
|
||||
ictx := &ipsetCtx{
|
||||
ipsetMgr: m,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
dctx *dnsContext
|
||||
name string
|
||||
want bool
|
||||
}{{
|
||||
name: "basic",
|
||||
want: false,
|
||||
dctx: &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req4,
|
||||
Res: resp4,
|
||||
},
|
||||
|
||||
responseFromUpstream: true,
|
||||
},
|
||||
}, {
|
||||
name: "rewrite",
|
||||
want: true,
|
||||
dctx: &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req4,
|
||||
Res: resp4,
|
||||
},
|
||||
|
||||
responseFromUpstream: false,
|
||||
},
|
||||
}, {
|
||||
name: "empty_req",
|
||||
want: true,
|
||||
dctx: &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: nil,
|
||||
Res: resp4,
|
||||
},
|
||||
|
||||
responseFromUpstream: true,
|
||||
},
|
||||
}, {
|
||||
name: "empty_res",
|
||||
want: true,
|
||||
dctx: &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req4,
|
||||
Res: nil,
|
||||
},
|
||||
|
||||
responseFromUpstream: true,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ictx.skipIpsetProcessing(tc.dctx)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -55,6 +56,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -90,6 +92,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -62,6 +63,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -100,6 +102,7 @@
|
||||
"blocking_mode": "refused",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -138,6 +141,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -176,6 +180,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -214,6 +219,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": true,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -254,6 +260,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": true,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -294,6 +301,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -332,6 +340,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": true,
|
||||
"disable_ipv6": false,
|
||||
@@ -370,6 +379,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -408,6 +418,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -446,6 +457,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -486,6 +498,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -526,6 +539,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -565,6 +579,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -603,6 +618,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -643,6 +659,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -686,6 +703,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -724,6 +742,7 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
@@ -766,6 +785,46 @@
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
"upstream_mode": "",
|
||||
"cache_size": 0,
|
||||
"cache_ttl_min": 0,
|
||||
"cache_ttl_max": 0,
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"blocked_response_ttl": {
|
||||
"req": {
|
||||
"blocked_response_ttl": 11
|
||||
},
|
||||
"want": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:53",
|
||||
"8.8.4.4:53"
|
||||
],
|
||||
"upstream_dns_file": "",
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10",
|
||||
"149.112.112.10",
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 11,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
|
||||
@@ -69,8 +69,8 @@ func (s *Server) prepareUpstreamSettings() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareUpstreamConfig sets upstream configuration based on upstreams and
|
||||
// configuration of s.
|
||||
// prepareUpstreamConfig returns the upstream configuration based on upstreams
|
||||
// and configuration of s.
|
||||
func (s *Server) prepareUpstreamConfig(
|
||||
upstreams []string,
|
||||
defaultUpstreams []string,
|
||||
|
||||
@@ -504,7 +504,7 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, r.Close()) }()
|
||||
|
||||
bufPtr := d.bufPool.Get().(*[]byte)
|
||||
bufPtr := d.bufPool.Get()
|
||||
defer d.bufPool.Put(bufPtr)
|
||||
|
||||
p := rulelist.NewParser()
|
||||
@@ -607,7 +607,7 @@ func (d *DNSFilter) load(flt *FilterYAML) (err error) {
|
||||
|
||||
log.Debug("filtering: file %q, id %d, length %d", fileName, flt.ID, st.Size())
|
||||
|
||||
bufPtr := d.bufPool.Get().(*[]byte)
|
||||
bufPtr := d.bufPool.Get()
|
||||
defer d.bufPool.Put(bufPtr)
|
||||
|
||||
p := rulelist.NewParser()
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/mathutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/syncutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
@@ -232,7 +233,7 @@ type Checker interface {
|
||||
// DNSFilter matches hostnames and DNS requests against filtering rules.
|
||||
type DNSFilter struct {
|
||||
// bufPool is a pool of buffers used for filtering-rule list parsing.
|
||||
bufPool *sync.Pool
|
||||
bufPool *syncutil.Pool[[]byte]
|
||||
|
||||
rulesStorage *filterlist.RuleStorage
|
||||
filteringEngine *urlfilter.DNSEngine
|
||||
@@ -514,8 +515,19 @@ func (d *DNSFilter) BlockingMode() (mode BlockingMode, bIPv4, bIPv6 netip.Addr)
|
||||
return d.conf.BlockingMode, d.conf.BlockingIPv4, d.conf.BlockingIPv6
|
||||
}
|
||||
|
||||
// SetBlockedResponseTTL sets TTL for blocked responses.
|
||||
func (d *DNSFilter) SetBlockedResponseTTL(ttl uint32) {
|
||||
d.confMu.Lock()
|
||||
defer d.confMu.Unlock()
|
||||
|
||||
d.conf.BlockedResponseTTL = ttl
|
||||
}
|
||||
|
||||
// BlockedResponseTTL returns TTL for blocked responses.
|
||||
func (d *DNSFilter) BlockedResponseTTL() (ttl uint32) {
|
||||
d.confMu.Lock()
|
||||
defer d.confMu.Unlock()
|
||||
|
||||
return d.conf.BlockedResponseTTL
|
||||
}
|
||||
|
||||
@@ -1050,13 +1062,7 @@ func InitModule() {
|
||||
// be non-nil.
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
d = &DNSFilter{
|
||||
bufPool: &sync.Pool{
|
||||
New: func() (buf any) {
|
||||
bufVal := make([]byte, rulelist.DefaultRuleBufSize)
|
||||
|
||||
return &bufVal
|
||||
},
|
||||
},
|
||||
bufPool: syncutil.NewSlicePool[byte](rulelist.DefaultRuleBufSize),
|
||||
refreshLock: &sync.Mutex{},
|
||||
safeBrowsingChecker: c.SafeBrowsingChecker,
|
||||
parentalControlChecker: c.ParentalControlChecker,
|
||||
|
||||
@@ -118,23 +118,22 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
|
||||
// 2. wildcard > exact;
|
||||
// 3. lower level wildcard > higher level wildcard;
|
||||
func (rw *LegacyRewrite) Compare(b *LegacyRewrite) (res int) {
|
||||
if rw.Type == dns.TypeCNAME && b.Type != dns.TypeCNAME {
|
||||
return -1
|
||||
} else if rw.Type != dns.TypeCNAME && b.Type == dns.TypeCNAME {
|
||||
if rw.Type == dns.TypeCNAME {
|
||||
if b.Type != dns.TypeCNAME {
|
||||
return -1
|
||||
}
|
||||
} else if b.Type == dns.TypeCNAME {
|
||||
return 1
|
||||
}
|
||||
|
||||
aIsWld, bIsWld := isWildcard(rw.Domain), isWildcard(b.Domain)
|
||||
if aIsWld == bIsWld {
|
||||
if aIsWld, bIsWld := isWildcard(rw.Domain), isWildcard(b.Domain); aIsWld == bIsWld {
|
||||
// Both are either wildcards or both aren't.
|
||||
return len(rw.Domain) - len(b.Domain)
|
||||
}
|
||||
|
||||
if aIsWld {
|
||||
return len(b.Domain) - len(rw.Domain)
|
||||
} else if aIsWld {
|
||||
return 1
|
||||
} else {
|
||||
return -1
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
// prepareRewrites normalizes and validates all legacy DNS rewrites.
|
||||
|
||||
@@ -80,6 +80,12 @@ func TestRewrites(t *testing.T) {
|
||||
}, {
|
||||
Domain: "*.issue4016.com",
|
||||
Answer: "sub.issue4016.com",
|
||||
}, {
|
||||
Domain: "*.sub.issue6226.com",
|
||||
Answer: addr2v4.String(),
|
||||
}, {
|
||||
Domain: "*.issue6226.com",
|
||||
Answer: addr1v4.String(),
|
||||
}}
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
@@ -182,6 +188,20 @@ func TestRewrites(t *testing.T) {
|
||||
wantIPs: nil,
|
||||
wantReason: NotFilteredNotFound,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue6226",
|
||||
host: "www.issue6226.com",
|
||||
wantCName: "",
|
||||
wantIPs: []netip.Addr{addr1v4},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue6226_sub",
|
||||
host: "www.sub.issue6226.com",
|
||||
wantCName: "",
|
||||
wantIPs: []netip.Addr{addr2v4},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -340,7 +340,7 @@ var blockedServices = []blockedService{{
|
||||
}, {
|
||||
ID: "bilibili",
|
||||
Name: "Bilibili",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 48 48\"><path fill=\"currentColor\" d=\"M36.5,12h-7.086l3.793-3.793c0.391-0.391,0.391-1.023,0-1.414s-1.023-0.391-1.414,0L26.586,12 h-5.172l-5.207-5.207c-0.391-0.391-1.023-0.391-1.414,0s-0.391,1.023,0,1.414L18.586,12H12.5C9.467,12,7,14.467,7,17.5v15 c0,3.033,2.467,5.5,5.5,5.5h2c0,0.829,0.671,1.5,1.5,1.5s1.5-0.671,1.5-1.5h14c0,0.829,0.671,1.5,1.5,1.5s1.5-0.671,1.5-1.5h2 c3.033,0,5.5-2.467,5.5-5.5v-15C42,14.467,39.533,12,36.5,12z M39,32.5c0,1.378-1.122,2.5-2.5,2.5h-24c-1.378,0-2.5-1.122-2.5-2.5 v-15c0-1.378,1.122-2.5,2.5-2.5h24c1.378,0,2.5,1.122,2.5,2.5V32.5z\"/><rect width=\"2.75\" height=\"7.075\" x=\"30.625\" y=\"18.463\" fill=\"currentColor\" transform=\"rotate(-71.567 32.001 22)\"/><rect width=\"7.075\" height=\"2.75\" x=\"14.463\" y=\"20.625\" fill=\"currentColor\" transform=\"rotate(-18.432 17.998 21.997)\"/><path fill=\"currentColor\" d=\"M28.033,27.526c-0.189,0.593-0.644,0.896-1.326,0.896c-0.076-0.013-0.139-0.013-0.24-0.025 c-0.013,0-0.05-0.013-0.063,0c-0.341-0.05-0.745-0.177-1.061-0.467c-0.366-0.265-0.808-0.745-0.947-1.477 c0,0-0.29,1.174-0.896,1.49c-0.076,0.05-0.164,0.114-0.253,0.164l-0.038,0.025c-0.303,0.164-0.682,0.265-1.086,0.278 c-0.568-0.051-0.947-0.328-1.136-0.821l-0.063-0.164l-1.427,0.656l0.05,0.139c0.467,1.124,1.465,1.768,2.74,1.768 c0.922,0,1.667-0.303,2.209-0.909c0.556,0.606,1.288,0.909,2.209,0.909c1.856,0,2.55-1.288,2.765-1.843l0.051-0.126l-1.427-0.657 L28.033,27.526z\"/></svg>"),
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 48 48\" fill=\"currentColor\"><path d=\"M36.5 12h-7.09l3.8-3.8a1 1 0 1 0-1.42-1.4L26.6 12H21.4l-5.2-5.2a1 1 0 1 0-1.42 1.4l3.8 3.8H12.5A5.5 5.5 0 0 0 7 17.5v15a5.5 5.5 0 0 0 5.5 5.5h2a1.5 1.5 0 1 0 3 0h14a1.5 1.5 0 1 0 3 0h2a5.5 5.5 0 0 0 5.5-5.5v-15a5.5 5.5 0 0 0-5.5-5.5ZM39 32.5a2.5 2.5 0 0 1-2.5 2.5h-24a2.5 2.5 0 0 1-2.5-2.5v-15a2.5 2.5 0 0 1 2.5-2.5h24a2.5 2.5 0 0 1 2.5 2.5v15Z\"/><path d=\"m29.08 19.58-.87 2.6 6.71 2.24.87-2.6-6.71-2.24Zm-8.16 0-6.7 2.23.86 2.61 6.71-2.23-.87-2.61Zm7.11 7.95c-.19.59-.64.9-1.32.9l-.24-.03c-.02 0-.05-.02-.07 0a1.99 1.99 0 0 1-1.06-.47 2.37 2.37 0 0 1-.94-1.48s-.3 1.18-.9 1.5l-.25.16-.04.02a2.47 2.47 0 0 1-1.09.28c-.56-.05-.94-.33-1.13-.82l-.07-.17-1.42.66.05.14a2.82 2.82 0 0 0 2.74 1.77c.92 0 1.66-.3 2.2-.91.56.6 1.3.9 2.22.9a2.82 2.82 0 0 0 2.76-1.84l.05-.12-1.43-.66-.06.17Z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"|upos-hz-mirrorakam.akamaized.net^",
|
||||
"||acg.tv^",
|
||||
@@ -1709,7 +1709,6 @@ var blockedServices = []blockedService{{
|
||||
"||mastodon.world^",
|
||||
"||mastodon.zaclys.com^",
|
||||
"||mastodonapp.uk^",
|
||||
"||mastodonners.nl^",
|
||||
"||mastodont.cat^",
|
||||
"||mastodontech.de^",
|
||||
"||mastodontti.fi^",
|
||||
@@ -1741,6 +1740,7 @@ var blockedServices = []blockedService{{
|
||||
"||social.anoxinon.de^",
|
||||
"||social.cologne^",
|
||||
"||social.dev-wiki.de^",
|
||||
"||social.linux.pizza^",
|
||||
"||social.politicaconciencia.org^",
|
||||
"||social.vivaldi.net^",
|
||||
"||stranger.social^",
|
||||
@@ -1985,7 +1985,7 @@ var blockedServices = []blockedService{{
|
||||
}, {
|
||||
ID: "qq",
|
||||
Name: "QQ",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 32 32\"><g fill=\"none\" fillRule=\"evenodd\"><path d=\"M0 0h32v32H0z\" /><g fill=\"currentColor\" fillRule=\"nonzero\"><path d=\"M11.25 32C8.342 32 6 30.74 6 29.242c0-1.497 2.342-2.757 5.25-2.757s5.25 1.26 5.25 2.757S14.158 32 11.25 32zM27 29.242c0-1.497-2.342-2.757-5.25-2.757s-5.25 1.26-5.25 2.757S18.842 32 21.75 32 27 30.74 27 29.242zM14.885 7.182c0 .63-.323 1.182-.808 1.182-.485 0-.808-.552-.808-1.182 0-.63.323-1.182.808-1.182.485 0 .808.552.808 1.182zM18.923 6c-.485 0-.808.552-.808 1.182 0 .63.323-.394.808-.394.485 0 .808 1.024.808.394S19.408 6 18.923 6z\" /><path d=\"M6.653 12.638s4.685 2.465 9.926 2.465c5.242 0 9.927-2.465 9.927-2.465.112-.09.217-.161.316-.212-.002-1.088-.078-2.026-.078-2.808C26.744 4.292 22.138 0 16.5 0S6.176 4.292 6.176 9.618v2.78c.146.042.3.113.477.24zm12.626-8.664c1.112 0 1.986 1.272 1.986 2.782s-.874 2.782-1.986 2.782c-1.111 0-1.985-1.271-1.985-2.782 0-1.51.874-2.782 1.985-2.782zm-5.558 0c1.111 0 1.985 1.272 1.985 2.782s-.874 2.782-1.985 2.782c-1.112 0-1.986-1.271-1.986-2.782 0-1.51.874-2.782 1.986-2.782zm2.779 6.624c2.912 0 5.294.464 5.294.994s-2.382 1.656-5.294 1.656c-2.912 0-5.294-1.126-5.294-1.656s2.382-.994 5.294-.994zm11.374 5.182c-.058.038-.108.076-.177.117-.159.08-5.241 3.18-11.038 3.18-1.43 0-2.7-.239-3.97-.477-.239 1.67-.239 3.259-.239 3.974 0 1.272-1.032 1.193-2.303 1.272-1.27 0-2.223.16-2.303-1.033 0-.16-.08-2.782.397-5.564-1.588-.716-2.62-1.272-2.7-1.352a3.293 3.293 0 01-.335-.216C4.012 17.55 3 19.598 3 21.223c0 3.815 1.112 3.418 1.112 3.418.476 0 1.27-.795 1.985-1.67C7.765 27.662 11.735 31 16.5 31c4.765 0 8.735-3.338 10.403-8.028.715.874 1.509 1.669 1.985 1.669 0 0 1.112.397 1.112-3.418 0-1.588-.968-3.631-2.126-5.443z\" /></g></g></svg>"),
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 32 32\" fill=\"currentColor\"><path d=\"M11.25 32C8.35 32 6 30.74 6 29.24c0-1.5 2.34-2.75 5.25-2.75s5.25 1.26 5.25 2.75S14.16 32 11.25 32ZM27 29.24c0-1.5-2.34-2.75-5.25-2.75s-5.25 1.26-5.25 2.75S18.84 32 21.75 32 27 30.74 27 29.24ZM14.88 7.18c0 .63-.32 1.18-.8 1.18-.49 0-.81-.55-.81-1.18 0-.63.32-1.18.8-1.18.5 0 .81.55.81 1.18ZM18.93 6c-.48 0-.8.55-.8 1.18 0 .63.32-.4.8-.4.49 0 .81 1.03.81.4S19.41 6 18.93 6Z\"/><path d=\"M6.65 12.64s4.69 2.46 9.93 2.46c5.24 0 9.93-2.46 9.93-2.46.1-.1.21-.16.31-.21 0-1.1-.08-2.03-.08-2.81C26.74 4.29 22.14 0 16.5 0S6.18 4.3 6.18 9.62v2.78c.14.04.3.11.47.24Zm12.63-8.67c1.11 0 1.98 1.28 1.98 2.79s-.87 2.78-1.98 2.78c-1.11 0-1.99-1.27-1.99-2.78 0-1.51.88-2.79 1.99-2.79Zm-5.56 0c1.11 0 1.99 1.28 1.99 2.79s-.88 2.78-1.99 2.78c-1.11 0-1.99-1.27-1.99-2.78 0-1.51.88-2.79 2-2.79Zm2.78 6.63c2.91 0 5.3.46 5.3 1s-2.39 1.65-5.3 1.65c-2.91 0-5.3-1.13-5.3-1.66s2.39-1 5.3-1Zm11.37 5.18-.17.12c-.16.08-5.24 3.18-11.04 3.18-1.43 0-2.7-.24-3.97-.48-.24 1.67-.24 3.26-.24 3.97 0 1.28-1.03 1.2-2.3 1.28-1.27 0-2.23.16-2.3-1.04 0-.16-.09-2.78.4-5.56a23.87 23.87 0 0 1-2.7-1.35 3.3 3.3 0 0 1-.34-.22C4 17.55 3 19.6 3 21.22c0 3.82 1.11 3.42 1.11 3.42.48 0 1.27-.8 1.99-1.67C7.77 27.67 11.73 31 16.5 31c4.76 0 8.73-3.34 10.4-8.03.72.88 1.51 1.67 1.99 1.67 0 0 1.11.4 1.11-3.42 0-1.58-.97-3.63-2.13-5.44Z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||qq-video.cdn-go.cn^",
|
||||
"||qq.com^$denyallow=wx.qq.com|weixin.qq.com",
|
||||
@@ -2437,7 +2437,7 @@ var blockedServices = []blockedService{{
|
||||
}, {
|
||||
ID: "xboxlive",
|
||||
Name: "Xbox Live",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 84 84\" xml:space=\"preserve\"><g transform=\"translate(-59.355 -42.513)\"><g transform=\"matrix(.26458 0 0 .26458 -145.88 -71.438)\" fill=\"currentColor\"><path d=\"m936.09 559.48s0.4 0 0 0c48.4 36.8 130.4 127.2 105.6 152.8-28.4 24.8-65.2 39.6-105.6 39.6s-77.6-14.8-105.6-39.6c-25.2-25.6 57.2-116 105.2-152.4 0-0.4 0.4-0.4 0.4-0.4zm83.6-105.2c-24.4-14.8-51.2-23.6-83.6-23.6s-59.2 8.8-83.6 23.6c-0.4 0-0.4 0.4-0.4 0.8s0.4 0.4 0.8 0.4c31.2-6.8 78.4 20 82.8 22.8h0.8c4.4-2.8 51.6-29.6 82.8-22.8 0.4 0 0.8 0 0.8-0.4s0-0.8-0.4-0.8zm-196 22.4c-0.4 0-0.4 0.4-0.8 0.4-29.2 29.2-47.2 69.6-47.2 114 0 36.4 12.4 70.4 32.8 97.2 0 0.4 0.4 0.4 0.8 0.4s0.4-0.4 0-0.8c-12.4-38 50.4-129.6 82.8-168l0.4-0.4c0-0.4 0-0.4-0.4-0.4-49.2-48.8-65.6-43.6-68.4-42.4zm156.4 42-0.4 0.4s0 0.4 0.4 0.4c32.4 38.4 94.8 130 82.8 168v0.8c0.4 0 0.8 0 0.8-0.4 20.4-26.8 32.8-60.8 32.8-97.2 0-44.4-18-84.8-47.6-114-0.4-0.4-0.4-0.4-0.8-0.4-2.4-0.8-18.8-6-68 42.4z\"/></g></g></svg>"),
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 84 84\" fill=\"currentColor\"><path d=\"M42.44 34.08c12.8 9.73 34.5 33.65 27.94 40.42a42.24 42.24 0 0 1-27.94 10.48A42.03 42.03 0 0 1 14.5 74.5c-6.67-6.77 15.13-30.69 27.83-40.32 0-.1.1-.1.1-.1ZM64.56 6.24A41.32 41.32 0 0 0 42.43 0a41.32 41.32 0 0 0-22.11 6.24c-.1 0-.1.1-.1.21s.1.11.2.11c8.26-1.8 20.75 5.3 21.91 6.03h.21c1.17-.74 13.65-7.83 21.9-6.03.12 0 .22 0 .22-.1s0-.22-.1-.22ZM12.7 12.17c-.1 0-.1.1-.21.1a42.56 42.56 0 0 0-3.81 55.88c0 .11.1.11.2.11s.11-.1 0-.21C5.62 57.99 22.23 33.75 30.8 23.6l.11-.1c0-.11 0-.11-.1-.11-13.02-12.91-17.36-11.54-18.1-11.22Zm41.38 11.11-.1.1s0 .11.1.11c8.57 10.16 25.08 34.4 21.9 44.45v.22c.11 0 .22 0 .22-.11a42.53 42.53 0 0 0 8.67-25.72 42.21 42.21 0 0 0-12.59-30.16c-.1-.1-.1-.1-.21-.1-.64-.22-4.97-1.6-18 11.21Z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||gamepass.com^",
|
||||
"||xbox-global.ifs.windows.com^",
|
||||
|
||||
@@ -644,24 +644,27 @@ func optionalAuthHandler(handler http.Handler) http.Handler {
|
||||
return &authHandler{handler}
|
||||
}
|
||||
|
||||
// UserAdd - add new user
|
||||
func (a *Auth) UserAdd(u *webUser, password string) {
|
||||
// Add adds a new user with the given password.
|
||||
func (a *Auth) Add(u *webUser, password string) (err error) {
|
||||
if len(password) == 0 {
|
||||
return
|
||||
return errors.Error("empty password")
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
log.Error("bcrypt.GenerateFromPassword: %s", err)
|
||||
return
|
||||
return fmt.Errorf("generating hash: %w", err)
|
||||
}
|
||||
|
||||
u.PasswordHash = string(hash)
|
||||
|
||||
a.lock.Lock()
|
||||
a.users = append(a.users, *u)
|
||||
a.lock.Unlock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
log.Debug("auth: added user: %s", u.Name)
|
||||
a.users = append(a.users, *u)
|
||||
|
||||
log.Debug("auth: added user with login %q", u.Name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findUser returns a user if there is one.
|
||||
|
||||
@@ -47,7 +47,8 @@ func TestAuth(t *testing.T) {
|
||||
s := session{}
|
||||
|
||||
user := webUser{Name: "name"}
|
||||
a.UserAdd(&user, "password")
|
||||
err := a.Add(&user, "password")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
||||
a.RemoveSession("notfound")
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/josharian/native"
|
||||
)
|
||||
@@ -83,12 +83,7 @@ func glGetTokenDate(file string) uint32 {
|
||||
}
|
||||
}()
|
||||
|
||||
fileReader, err := aghio.LimitReader(f, MaxFileSize)
|
||||
if err != nil {
|
||||
log.Error("creating limited reader: %s", err)
|
||||
|
||||
return 0
|
||||
}
|
||||
fileReader := ioutil.LimitReader(f, MaxFileSize)
|
||||
|
||||
var dateToken uint32
|
||||
|
||||
|
||||
@@ -182,7 +182,7 @@ type httpPprofConfig struct {
|
||||
// not absolutely necessary.
|
||||
type dnsConfig struct {
|
||||
BindHosts []netip.Addr `yaml:"bind_hosts"`
|
||||
Port int `yaml:"port"`
|
||||
Port uint16 `yaml:"port"`
|
||||
|
||||
// AnonymizeClientIP defines if clients' IP addresses should be anonymized
|
||||
// in query log and statistics.
|
||||
@@ -232,13 +232,13 @@ type tlsConfigSettings struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status
|
||||
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
|
||||
ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
|
||||
PortHTTPS int `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
|
||||
PortDNSOverTLS int `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DoT will be disabled
|
||||
PortDNSOverQUIC int `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"` // DNS-over-QUIC port. If 0, DoQ will be disabled
|
||||
PortHTTPS uint16 `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
|
||||
PortDNSOverTLS uint16 `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DoT will be disabled
|
||||
PortDNSOverQUIC uint16 `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"` // DNS-over-QUIC port. If 0, DoQ will be disabled
|
||||
|
||||
// PortDNSCrypt is the port for DNSCrypt requests. If it's zero,
|
||||
// DNSCrypt is disabled.
|
||||
PortDNSCrypt int `yaml:"port_dnscrypt" json:"port_dnscrypt"`
|
||||
PortDNSCrypt uint16 `yaml:"port_dnscrypt" json:"port_dnscrypt"`
|
||||
// DNSCryptConfigFile is the path to the DNSCrypt config file. Must be
|
||||
// set if PortDNSCrypt is not zero.
|
||||
//
|
||||
@@ -262,7 +262,7 @@ type queryLogConfig struct {
|
||||
|
||||
// MemSize is the number of entries kept in memory before they are flushed
|
||||
// to disk.
|
||||
MemSize uint32 `yaml:"size_memory"`
|
||||
MemSize int `yaml:"size_memory"`
|
||||
|
||||
// Enabled defines if the query log is enabled.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
@@ -554,10 +554,10 @@ func validateConfig() (err error) {
|
||||
}
|
||||
|
||||
// udpPort is the port number for UDP protocol.
|
||||
type udpPort int
|
||||
type udpPort uint16
|
||||
|
||||
// tcpPort is the port number for TCP protocol.
|
||||
type tcpPort int
|
||||
type tcpPort uint16
|
||||
|
||||
// addPorts is a helper for ports validation that skips zero ports.
|
||||
func addPorts[T tcpPort | udpPort](uc aghalg.UniqChecker[T], ports ...T) {
|
||||
|
||||
@@ -24,11 +24,9 @@ import (
|
||||
// addresses to a slice of strings.
|
||||
func appendDNSAddrs(dst []string, addrs ...netip.Addr) (res []string) {
|
||||
for _, addr := range addrs {
|
||||
var hostport string
|
||||
if config.DNS.Port != defaultPortDNS {
|
||||
hostport = netip.AddrPortFrom(addr, uint16(config.DNS.Port)).String()
|
||||
} else {
|
||||
hostport = addr.String()
|
||||
hostport := addr.String()
|
||||
if p := config.DNS.Port; p != defaultPortDNS {
|
||||
hostport = netutil.JoinHostPort(hostport, p)
|
||||
}
|
||||
|
||||
dst = append(dst, hostport)
|
||||
@@ -102,7 +100,7 @@ type statusResponse struct {
|
||||
Version string `json:"version"`
|
||||
Language string `json:"language"`
|
||||
DNSAddrs []string `json:"dns_addresses"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
DNSPort uint16 `json:"dns_port"`
|
||||
HTTPPort uint16 `json:"http_port"`
|
||||
|
||||
// ProtectionDisabledDuration is the duration of the protection pause in
|
||||
@@ -340,7 +338,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
var (
|
||||
forceHTTPS bool
|
||||
serveHTTP3 bool
|
||||
portHTTPS int
|
||||
portHTTPS uint16
|
||||
)
|
||||
func() {
|
||||
config.RLock()
|
||||
@@ -394,7 +392,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
|
||||
// httpsURL returns a copy of u for redirection to the HTTPS version, taking the
|
||||
// hostname and the HTTPS port into account.
|
||||
func httpsURL(u *url.URL, host string, portHTTPS int) (redirectURL *url.URL) {
|
||||
func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL) {
|
||||
hostPort := host
|
||||
if portHTTPS != defaultPortHTTPS {
|
||||
hostPort = netutil.JoinHostPort(host, portHTTPS)
|
||||
|
||||
@@ -43,8 +43,8 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
|
||||
data := getAddrsResponse{
|
||||
Version: version.Version(),
|
||||
|
||||
WebPort: defaultPortHTTP,
|
||||
DNSPort: defaultPortDNS,
|
||||
WebPort: int(defaultPortHTTP),
|
||||
DNSPort: int(defaultPortDNS),
|
||||
}
|
||||
|
||||
ifaces, err := aghnet.GetValidNetInterfacesForWeb()
|
||||
@@ -64,7 +64,7 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
type checkConfReqEnt struct {
|
||||
IP netip.Addr `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Port uint16 `json:"port"`
|
||||
Autofix bool `json:"autofix"`
|
||||
}
|
||||
|
||||
@@ -97,7 +97,7 @@ func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
// TODO(a.garipov): Declare all port variables anywhere as uint16.
|
||||
reqPort := uint16(req.Web.Port)
|
||||
reqPort := req.Web.Port
|
||||
port := tcpPort(reqPort)
|
||||
addPorts(tcpPorts, port)
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
@@ -128,7 +128,7 @@ func (req *checkConfReq) validateDNS(
|
||||
) (canAutofix bool, err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := uint16(req.DNS.Port)
|
||||
port := req.DNS.Port
|
||||
switch port {
|
||||
case 0:
|
||||
return false, nil
|
||||
@@ -142,13 +142,13 @@ func (req *checkConfReq) validateDNS(
|
||||
return false, err
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, port))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, port))
|
||||
if !aghnet.IsAddrInUse(err) {
|
||||
return false, err
|
||||
}
|
||||
@@ -160,7 +160,7 @@ func (req *checkConfReq) validateDNS(
|
||||
log.Error("disabling DNSStubListener: %s", err)
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, port))
|
||||
canAutofix = false
|
||||
}
|
||||
|
||||
@@ -305,7 +305,7 @@ func disableDNSStubListener() error {
|
||||
|
||||
type applyConfigReqEnt struct {
|
||||
IP netip.Addr `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Port uint16 `json:"port"`
|
||||
}
|
||||
|
||||
type applyConfigReq struct {
|
||||
@@ -395,14 +395,14 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(req.DNS.Port)))
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, req.DNS.Port))
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, uint16(req.DNS.Port)))
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, req.DNS.Port))
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -413,10 +413,22 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
copyInstallSettings(curConfig, config)
|
||||
|
||||
Context.firstRun = false
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port))
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
|
||||
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
|
||||
config.DNS.Port = req.DNS.Port
|
||||
|
||||
u := &webUser{
|
||||
Name: req.Username,
|
||||
}
|
||||
err = Context.auth.Add(u, req.Password)
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(e.burkov): StartMods() should be put in a separate goroutine at the
|
||||
// moment we'll allow setting up TLS in the initial configuration or the
|
||||
// configuration itself will use HTTPS protocol, because the underlying
|
||||
@@ -430,11 +442,6 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
u := &webUser{
|
||||
Name: req.Username,
|
||||
}
|
||||
Context.auth.UserAdd(u, req.Password)
|
||||
|
||||
err = config.write()
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
@@ -445,8 +452,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
web.conf.firstRun = false
|
||||
web.conf.BindHost = req.Web.IP
|
||||
web.conf.BindPort = req.Web.Port
|
||||
web.conf.BindAddr = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
|
||||
|
||||
registerControlHandlers(web)
|
||||
|
||||
@@ -487,9 +493,9 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
|
||||
}
|
||||
|
||||
addrPort := config.HTTPConfig.Address
|
||||
restartHTTP = addrPort.Addr() != req.Web.IP || int(addrPort.Port()) != req.Web.Port
|
||||
restartHTTP = addrPort.Addr() != req.Web.IP || addrPort.Port() != req.Web.Port
|
||||
if restartHTTP {
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port)))
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, req.Web.Port))
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf(
|
||||
"checking address %s:%d: %w",
|
||||
|
||||
@@ -27,11 +27,11 @@ import (
|
||||
|
||||
// Default listening ports.
|
||||
const (
|
||||
defaultPortDNS = 53
|
||||
defaultPortHTTP = 80
|
||||
defaultPortHTTPS = 443
|
||||
defaultPortQUIC = 853
|
||||
defaultPortTLS = 853
|
||||
defaultPortDNS uint16 = 53
|
||||
defaultPortHTTP uint16 = 80
|
||||
defaultPortHTTPS uint16 = 443
|
||||
defaultPortQUIC uint16 = 853
|
||||
defaultPortTLS uint16 = 853
|
||||
)
|
||||
|
||||
// Called by other modules when configuration is changed
|
||||
@@ -196,27 +196,27 @@ func isRunning() bool {
|
||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||
}
|
||||
|
||||
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
|
||||
func ipsToTCPAddrs(ips []netip.Addr, port uint16) (tcpAddrs []*net.TCPAddr) {
|
||||
if ips == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tcpAddrs = make([]*net.TCPAddr, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
tcpAddrs = append(tcpAddrs, net.TCPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16(port))))
|
||||
tcpAddrs = append(tcpAddrs, net.TCPAddrFromAddrPort(netip.AddrPortFrom(ip, port)))
|
||||
}
|
||||
|
||||
return tcpAddrs
|
||||
}
|
||||
|
||||
func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
|
||||
func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) {
|
||||
if ips == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
udpAddrs = make([]*net.UDPAddr, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
udpAddrs = append(udpAddrs, net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16(port))))
|
||||
udpAddrs = append(udpAddrs, net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, port)))
|
||||
}
|
||||
|
||||
return udpAddrs
|
||||
@@ -346,8 +346,8 @@ func getDNSEncryption() (de dnsEncryption) {
|
||||
hostname := tlsConf.ServerName
|
||||
if tlsConf.PortHTTPS != 0 {
|
||||
addr := hostname
|
||||
if tlsConf.PortHTTPS != defaultPortHTTPS {
|
||||
addr = netutil.JoinHostPort(addr, tlsConf.PortHTTPS)
|
||||
if p := tlsConf.PortHTTPS; p != defaultPortHTTPS {
|
||||
addr = netutil.JoinHostPort(addr, p)
|
||||
}
|
||||
|
||||
de.https = (&url.URL{
|
||||
@@ -357,17 +357,17 @@ func getDNSEncryption() (de dnsEncryption) {
|
||||
}).String()
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
if p := tlsConf.PortDNSOverTLS; p != 0 {
|
||||
de.tls = (&url.URL{
|
||||
Scheme: "tls",
|
||||
Host: netutil.JoinHostPort(hostname, tlsConf.PortDNSOverTLS),
|
||||
Host: netutil.JoinHostPort(hostname, p),
|
||||
}).String()
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverQUIC != 0 {
|
||||
if p := tlsConf.PortDNSOverQUIC; p != 0 {
|
||||
de.quic = (&url.URL{
|
||||
Scheme: "quic",
|
||||
Host: netutil.JoinHostPort(hostname, tlsConf.PortDNSOverQUIC),
|
||||
Host: netutil.JoinHostPort(hostname, p),
|
||||
}).String()
|
||||
}
|
||||
}
|
||||
@@ -494,7 +494,9 @@ func closeDNSServer() {
|
||||
Context.dnsServer = nil
|
||||
}
|
||||
|
||||
Context.filters.Close()
|
||||
if Context.filters != nil {
|
||||
Context.filters.Close()
|
||||
}
|
||||
|
||||
if Context.stats != nil {
|
||||
err := Context.stats.Close()
|
||||
|
||||
@@ -9,8 +9,10 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
@@ -327,7 +329,7 @@ func setupBindOpts(opts options) (err error) {
|
||||
if opts.bindPort != 0 {
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(
|
||||
config.HTTPConfig.Address.Addr(),
|
||||
uint16(opts.bindPort),
|
||||
opts.bindPort,
|
||||
)
|
||||
|
||||
err = checkPorts()
|
||||
@@ -495,8 +497,7 @@ func initWeb(opts options, clientBuildFS fs.FS, upd *updater.Updater) (web *webA
|
||||
|
||||
clientFS: clientFS,
|
||||
|
||||
BindHost: config.HTTPConfig.Address.Addr(),
|
||||
BindPort: int(config.HTTPConfig.Address.Port()),
|
||||
BindAddr: config.HTTPConfig.Address,
|
||||
|
||||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHdrTimeout,
|
||||
@@ -559,16 +560,28 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
err = setupOpts(opts)
|
||||
fatalOnError(err)
|
||||
|
||||
execPath, err := os.Executable()
|
||||
fatalOnError(errors.Annotate(err, "getting executable path: %w"))
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
// TODO(a.garipov): Make configurable.
|
||||
Host: "static.adtidy.org",
|
||||
Path: path.Join("adguardhome", version.Channel(), "version.json"),
|
||||
}
|
||||
|
||||
upd := updater.NewUpdater(&updater.Config{
|
||||
Client: config.Filtering.HTTPClient,
|
||||
Version: version.Version(),
|
||||
Channel: version.Channel(),
|
||||
GOARCH: runtime.GOARCH,
|
||||
GOOS: runtime.GOOS,
|
||||
GOARM: version.GOARM(),
|
||||
GOMIPS: version.GOMIPS(),
|
||||
WorkDir: Context.workDir,
|
||||
ConfName: config.getConfigFilename(),
|
||||
Client: config.Filtering.HTTPClient,
|
||||
Version: version.Version(),
|
||||
Channel: version.Channel(),
|
||||
GOARCH: runtime.GOARCH,
|
||||
GOOS: runtime.GOOS,
|
||||
GOARM: version.GOARM(),
|
||||
GOMIPS: version.GOMIPS(),
|
||||
WorkDir: Context.workDir,
|
||||
ConfName: config.getConfigFilename(),
|
||||
ExecPath: execPath,
|
||||
VersionCheckURL: u.String(),
|
||||
})
|
||||
|
||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||
@@ -839,7 +852,7 @@ func loadCmdLineOpts() (opts options) {
|
||||
// example:
|
||||
//
|
||||
// go to http://127.0.0.1:80
|
||||
func printWebAddrs(proto, addr string, port int) {
|
||||
func printWebAddrs(proto, addr string, port uint16) {
|
||||
log.Printf("go to %s://%s", proto, netutil.JoinHostPort(addr, port))
|
||||
}
|
||||
|
||||
@@ -851,7 +864,7 @@ func printHTTPAddresses(proto string) {
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
}
|
||||
|
||||
port := int(config.HTTPConfig.Address.Port())
|
||||
port := config.HTTPConfig.Address.Port()
|
||||
if proto == aghhttp.SchemeHTTPS {
|
||||
port = tlsConf.PortHTTPS
|
||||
}
|
||||
|
||||
@@ -4,9 +4,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
)
|
||||
|
||||
// middlerware is a wrapper function signature.
|
||||
@@ -23,12 +21,14 @@ func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Ha
|
||||
return wrapped
|
||||
}
|
||||
|
||||
// defaultReqBodySzLim is the default maximum request body size.
|
||||
const defaultReqBodySzLim = 64 * 1024
|
||||
const (
|
||||
// defaultReqBodySzLim is the default maximum request body size.
|
||||
defaultReqBodySzLim = 64 * 1024
|
||||
|
||||
// largerReqBodySzLim is the maximum request body size for APIs expecting larger
|
||||
// requests.
|
||||
const largerReqBodySzLim = 4 * 1024 * 1024
|
||||
// largerReqBodySzLim is the maximum request body size for APIs expecting
|
||||
// larger requests.
|
||||
largerReqBodySzLim = 4 * 1024 * 1024
|
||||
)
|
||||
|
||||
// expectsLargerRequests shows if this request should use a larger body size
|
||||
// limit. These are exceptions for poorly designed current APIs as well as APIs
|
||||
@@ -52,20 +52,12 @@ func expectsLargerRequests(r *http.Request) (ok bool) {
|
||||
// method limited.
|
||||
func limitRequestBody(h http.Handler) (limited http.Handler) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
|
||||
var szLim int64 = defaultReqBodySzLim
|
||||
var szLim uint64 = defaultReqBodySzLim
|
||||
if expectsLargerRequests(r) {
|
||||
szLim = largerReqBodySzLim
|
||||
}
|
||||
|
||||
var reader io.Reader
|
||||
reader, err = aghio.LimitReader(r.Body, szLim)
|
||||
if err != nil {
|
||||
log.Error("limitRequestBody: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
reader := ioutil.LimitReader(r.Body, szLim)
|
||||
|
||||
// HTTP handlers aren't supposed to call r.Body.Close(), so just
|
||||
// replace the body in a clone.
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLimitRequestBody(t *testing.T) {
|
||||
errReqLimitReached := &aghio.LimitReachedError{
|
||||
errReqLimitReached := &ioutil.LimitError{
|
||||
Limit: defaultReqBodySzLim,
|
||||
}
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ type options struct {
|
||||
// bindPort is the port on which to serve the HTTP UI.
|
||||
//
|
||||
// Deprecated: Use bindAddr.
|
||||
bindPort int
|
||||
bindPort uint16
|
||||
|
||||
// bindAddr is the address to serve the web UI on.
|
||||
bindAddr netip.AddrPort
|
||||
@@ -160,15 +160,11 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
shortName: "h",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (options, error) {
|
||||
var err error
|
||||
var p int
|
||||
minPort, maxPort := 0, 1<<16-1
|
||||
if p, err = strconv.Atoi(v); err != nil {
|
||||
err = fmt.Errorf("port %q is not a number", v)
|
||||
} else if p < minPort || p > maxPort {
|
||||
err = fmt.Errorf("port %d not in range %d - %d", p, minPort, maxPort)
|
||||
p, err := strconv.ParseUint(v, 10, 16)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("parsing port: %w", err)
|
||||
} else {
|
||||
o.bindPort = p
|
||||
o.bindPort = uint16(p)
|
||||
}
|
||||
|
||||
return o, err
|
||||
@@ -180,7 +176,7 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
return "", false
|
||||
}
|
||||
|
||||
return strconv.Itoa(o.bindPort), true
|
||||
return strconv.Itoa(int(o.bindPort)), true
|
||||
},
|
||||
description: "Deprecated. Port to serve HTTP pages on. Use --web-addr.",
|
||||
longName: "port",
|
||||
|
||||
@@ -67,11 +67,11 @@ func TestParseBindHost(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseBindPort(t *testing.T) {
|
||||
assert.Equal(t, 0, testParseOK(t).bindPort, "empty is port 0")
|
||||
assert.Equal(t, 65535, testParseOK(t, "-p", "65535").bindPort, "-p is port")
|
||||
assert.Equal(t, uint16(0), testParseOK(t).bindPort, "empty is port 0")
|
||||
assert.Equal(t, uint16(65535), testParseOK(t, "-p", "65535").bindPort, "-p is port")
|
||||
testParseParamMissing(t, "-p")
|
||||
|
||||
assert.Equal(t, 65535, testParseOK(t, "--port", "65535").bindPort, "--port is port")
|
||||
assert.Equal(t, uint16(65535), testParseOK(t, "--port", "65535").bindPort, "--port is port")
|
||||
testParseParamMissing(t, "--port")
|
||||
|
||||
testParseErr(t, "not an int", "-p", "x")
|
||||
|
||||
@@ -39,8 +39,8 @@ type webConfig struct {
|
||||
|
||||
clientFS fs.FS
|
||||
|
||||
BindHost netip.Addr
|
||||
BindPort int
|
||||
// BindAddr is the binding address with port for plain HTTP web interface.
|
||||
BindAddr netip.AddrPort
|
||||
|
||||
// ReadTimeout is an option to pass to http.Server for setting an
|
||||
// appropriate field.
|
||||
@@ -125,12 +125,12 @@ func newWebAPI(conf *webConfig) (w *webAPI) {
|
||||
// available, unless the HTTPS server isn't active.
|
||||
//
|
||||
// TODO(a.garipov): Adapt for HTTP/3.
|
||||
func webCheckPortAvailable(port int) (ok bool) {
|
||||
func webCheckPortAvailable(port uint16) (ok bool) {
|
||||
if Context.web.httpsServer.server != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), uint16(port))
|
||||
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), port)
|
||||
|
||||
return aghnet.CheckPort("tcp", addrPort) == nil
|
||||
}
|
||||
@@ -185,10 +185,9 @@ func (web *webAPI) start() {
|
||||
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{})
|
||||
|
||||
// Create a new instance, because the Web is not usable after Shutdown.
|
||||
hostStr := web.conf.BindHost.String()
|
||||
web.httpServer = &http.Server{
|
||||
ErrorLog: log.StdLog("web: plain", log.DEBUG),
|
||||
Addr: netutil.JoinHostPort(hostStr, web.conf.BindPort),
|
||||
Addr: web.conf.BindAddr.String(),
|
||||
Handler: hdlr,
|
||||
ReadTimeout: web.conf.ReadTimeout,
|
||||
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
|
||||
@@ -249,7 +248,7 @@ func (web *webAPI) tlsServerLoop() {
|
||||
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
|
||||
var portHTTPS int
|
||||
var portHTTPS uint16
|
||||
func() {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
@@ -257,7 +256,7 @@ func (web *webAPI) tlsServerLoop() {
|
||||
portHTTPS = config.TLS.PortHTTPS
|
||||
}()
|
||||
|
||||
addr := netutil.JoinHostPort(web.conf.BindHost.String(), portHTTPS)
|
||||
addr := netip.AddrPortFrom(web.conf.BindAddr.Addr(), portHTTPS).String()
|
||||
web.httpsServer.server = &http.Server{
|
||||
ErrorLog: log.StdLog("web: https", log.DEBUG),
|
||||
Addr: addr,
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
package aghnet
|
||||
// Package ipset provides ipset functionality.
|
||||
package ipset
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// IpsetManager is the ipset manager interface.
|
||||
// Manager is the ipset manager interface.
|
||||
//
|
||||
// TODO(a.garipov): Perhaps generalize this into some kind of a NetFilter type,
|
||||
// since ipset is exclusive to Linux?
|
||||
type IpsetManager interface {
|
||||
type Manager interface {
|
||||
Add(host string, ip4s, ip6s []net.IP) (n int, err error)
|
||||
Close() (err error)
|
||||
}
|
||||
|
||||
// NewIpsetManager returns a new ipset. IPv4 addresses are added to an ipset
|
||||
// with an ipv4 family; IPv6 addresses, to an ipv6 ipset. ipset must exist.
|
||||
// NewManager returns a new ipset manager. IPv4 addresses are added to an
|
||||
// ipset with an ipv4 family; IPv6 addresses, to an ipv6 ipset. ipset must
|
||||
// exist.
|
||||
//
|
||||
// The syntax of the ipsetConf is:
|
||||
//
|
||||
@@ -22,10 +24,10 @@ type IpsetManager interface {
|
||||
//
|
||||
// If ipsetConf is empty, msg and err are nil. The error is of type
|
||||
// *aghos.UnsupportedError if the OS is not supported.
|
||||
func NewIpsetManager(ipsetConf []string) (mgr IpsetManager, err error) {
|
||||
func NewManager(ipsetConf []string) (mgr Manager, err error) {
|
||||
if len(ipsetConf) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return newIpsetMgr(ipsetConf)
|
||||
return newManager(ipsetConf)
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build linux
|
||||
|
||||
package aghnet
|
||||
package ipset
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -31,9 +31,9 @@ import (
|
||||
// 6. Run "sudo ipset list example_set". The Members field should contain the
|
||||
// resolved IP addresses.
|
||||
|
||||
// newIpsetMgr returns a new Linux ipset manager.
|
||||
func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) {
|
||||
return newIpsetMgrWithDialer(ipsetConf, defaultDial)
|
||||
// newManager returns a new Linux ipset manager.
|
||||
func newManager(ipsetConf []string) (set Manager, err error) {
|
||||
return newManagerWithDialer(ipsetConf, defaultDial)
|
||||
}
|
||||
|
||||
// defaultDial is the default netfilter dialing function.
|
||||
@@ -53,50 +53,31 @@ type ipsetConn interface {
|
||||
Header(name string) (p *ipset.HeaderPolicy, err error)
|
||||
}
|
||||
|
||||
// ipsetDialer creates an ipsetConn.
|
||||
type ipsetDialer func(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error)
|
||||
// dialer creates an ipsetConn.
|
||||
type dialer func(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error)
|
||||
|
||||
// ipsetProps contains one Linux Netfilter ipset properties.
|
||||
type ipsetProps struct {
|
||||
// props contains one Linux Netfilter ipset properties.
|
||||
type props struct {
|
||||
name string
|
||||
family netfilter.ProtoFamily
|
||||
}
|
||||
|
||||
// unit is a convenient alias for struct{}.
|
||||
type unit = struct{}
|
||||
// manager is the Linux Netfilter ipset manager.
|
||||
type manager struct {
|
||||
nameToIpset map[string]props
|
||||
domainToIpsets map[string][]props
|
||||
|
||||
// ipsInIpset is the type of a set of IP-address-to-ipset mappings.
|
||||
type ipsInIpset map[ipInIpsetEntry]unit
|
||||
|
||||
// ipInIpsetEntry is the type for entries in an ipsInIpset set.
|
||||
type ipInIpsetEntry struct {
|
||||
ipsetName string
|
||||
ipArr [net.IPv6len]byte
|
||||
}
|
||||
|
||||
// ipsetMgr is the Linux Netfilter ipset manager.
|
||||
type ipsetMgr struct {
|
||||
nameToIpset map[string]ipsetProps
|
||||
domainToIpsets map[string][]ipsetProps
|
||||
|
||||
dial ipsetDialer
|
||||
dial dialer
|
||||
|
||||
// mu protects all properties below.
|
||||
mu *sync.Mutex
|
||||
|
||||
// TODO(a.garipov): Currently, the ipset list is static, and we don't
|
||||
// read the IPs already in sets, so we can assume that all incoming IPs
|
||||
// are either added to all corresponding ipsets or not. When that stops
|
||||
// being the case, for example if we add dynamic reconfiguration of
|
||||
// ipsets, this map will need to become a per-ipset-name one.
|
||||
addedIPs ipsInIpset
|
||||
|
||||
ipv4Conn ipsetConn
|
||||
ipv6Conn ipsetConn
|
||||
}
|
||||
|
||||
// dialNetfilter establishes connections to Linux's netfilter module.
|
||||
func (m *ipsetMgr) dialNetfilter(conf *netlink.Config) (err error) {
|
||||
func (m *manager) dialNetfilter(conf *netlink.Config) (err error) {
|
||||
// The kernel API does not actually require two sockets but package
|
||||
// github.com/digineo/go-ipset does.
|
||||
//
|
||||
@@ -145,7 +126,7 @@ func parseIpsetConfig(confStr string) (hosts, ipsetNames []string, err error) {
|
||||
}
|
||||
|
||||
// ipsetProps returns the properties of an ipset with the given name.
|
||||
func (m *ipsetMgr) ipsetProps(name string) (set ipsetProps, err error) {
|
||||
func (m *manager) ipsetProps(name string) (set props, err error) {
|
||||
// The family doesn't seem to matter when we use a header query, so
|
||||
// query only the IPv4 one.
|
||||
//
|
||||
@@ -165,14 +146,14 @@ func (m *ipsetMgr) ipsetProps(name string) (set ipsetProps, err error) {
|
||||
return set, fmt.Errorf("unexpected ipset family %d", family)
|
||||
}
|
||||
|
||||
return ipsetProps{
|
||||
return props{
|
||||
name: name,
|
||||
family: family,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ipsets returns currently known ipsets.
|
||||
func (m *ipsetMgr) ipsets(names []string) (sets []ipsetProps, err error) {
|
||||
func (m *manager) ipsets(names []string) (sets []props, err error) {
|
||||
for _, name := range names {
|
||||
set, ok := m.nameToIpset[name]
|
||||
if ok {
|
||||
@@ -193,20 +174,18 @@ func (m *ipsetMgr) ipsets(names []string) (sets []ipsetProps, err error) {
|
||||
return sets, nil
|
||||
}
|
||||
|
||||
// newIpsetMgrWithDialer returns a new Linux ipset manager using the provided
|
||||
// newManagerWithDialer returns a new Linux ipset manager using the provided
|
||||
// dialer.
|
||||
func newIpsetMgrWithDialer(ipsetConf []string, dial ipsetDialer) (mgr IpsetManager, err error) {
|
||||
func newManagerWithDialer(ipsetConf []string, dial dialer) (mgr Manager, err error) {
|
||||
defer func() { err = errors.Annotate(err, "ipset: %w") }()
|
||||
|
||||
m := &ipsetMgr{
|
||||
m := &manager{
|
||||
mu: &sync.Mutex{},
|
||||
|
||||
nameToIpset: make(map[string]ipsetProps),
|
||||
domainToIpsets: make(map[string][]ipsetProps),
|
||||
nameToIpset: make(map[string]props),
|
||||
domainToIpsets: make(map[string][]props),
|
||||
|
||||
dial: dial,
|
||||
|
||||
addedIPs: make(ipsInIpset),
|
||||
}
|
||||
|
||||
err = m.dialNetfilter(&netlink.Config{})
|
||||
@@ -229,7 +208,7 @@ func newIpsetMgrWithDialer(ipsetConf []string, dial ipsetDialer) (mgr IpsetManag
|
||||
return nil, fmt.Errorf("config line at idx %d: %w", i, err)
|
||||
}
|
||||
|
||||
var ipsets []ipsetProps
|
||||
var ipsets []props
|
||||
ipsets, err = m.ipsets(ipsetNames)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
@@ -249,7 +228,7 @@ func newIpsetMgrWithDialer(ipsetConf []string, dial ipsetDialer) (mgr IpsetManag
|
||||
|
||||
// lookupHost find the ipsets for the host, taking subdomain wildcards into
|
||||
// account.
|
||||
func (m *ipsetMgr) lookupHost(host string) (sets []ipsetProps) {
|
||||
func (m *manager) lookupHost(host string) (sets []props) {
|
||||
// Search for matching ipset hosts starting with most specific domain.
|
||||
// We could use a trie here but the simple, inefficient solution isn't
|
||||
// that expensive: ~10 ns for TLD + SLD vs. ~140 ns for 10 subdomains on
|
||||
@@ -274,25 +253,14 @@ func (m *ipsetMgr) lookupHost(host string) (sets []ipsetProps) {
|
||||
|
||||
// addIPs adds the IP addresses for the host to the ipset. set must be same
|
||||
// family as set's family.
|
||||
func (m *ipsetMgr) addIPs(host string, set ipsetProps, ips []net.IP) (n int, err error) {
|
||||
func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error) {
|
||||
if len(ips) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var entries []*ipset.Entry
|
||||
var newAddedEntries []ipInIpsetEntry
|
||||
for _, ip := range ips {
|
||||
e := ipInIpsetEntry{
|
||||
ipsetName: set.name,
|
||||
}
|
||||
copy(e.ipArr[:], ip.To16())
|
||||
|
||||
if _, added := m.addedIPs[e]; added {
|
||||
continue
|
||||
}
|
||||
|
||||
entries = append(entries, ipset.NewEntry(ipset.EntryIP(ip)))
|
||||
newAddedEntries = append(newAddedEntries, e)
|
||||
}
|
||||
|
||||
n = len(entries)
|
||||
@@ -315,21 +283,15 @@ func (m *ipsetMgr) addIPs(host string, set ipsetProps, ips []net.IP) (n int, err
|
||||
return 0, fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err)
|
||||
}
|
||||
|
||||
// Only add these to the cache once we're sure that all of them were
|
||||
// actually sent to the ipset.
|
||||
for _, e := range newAddedEntries {
|
||||
m.addedIPs[e] = unit{}
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// addToSets adds the IP addresses to the corresponding ipset.
|
||||
func (m *ipsetMgr) addToSets(
|
||||
func (m *manager) addToSets(
|
||||
host string,
|
||||
ip4s []net.IP,
|
||||
ip6s []net.IP,
|
||||
sets []ipsetProps,
|
||||
sets []props,
|
||||
) (n int, err error) {
|
||||
for _, set := range sets {
|
||||
var nn int
|
||||
@@ -356,8 +318,8 @@ func (m *ipsetMgr) addToSets(
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Add implements the IpsetManager interface for *ipsetMgr
|
||||
func (m *ipsetMgr) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
|
||||
// Add implements the [Manager] interface for *manager.
|
||||
func (m *manager) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -371,8 +333,8 @@ func (m *ipsetMgr) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
|
||||
return m.addToSets(host, ip4s, ip6s, sets)
|
||||
}
|
||||
|
||||
// Close implements the IpsetManager interface for *ipsetMgr.
|
||||
func (m *ipsetMgr) Close() (err error) {
|
||||
// Close implements the [Manager] interface for *manager.
|
||||
func (m *manager) Close() (err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build linux
|
||||
|
||||
package aghnet
|
||||
package ipset
|
||||
|
||||
import (
|
||||
"net"
|
||||
@@ -15,16 +15,16 @@ import (
|
||||
"github.com/ti-mo/netfilter"
|
||||
)
|
||||
|
||||
// fakeIpsetConn is a fake ipsetConn for tests.
|
||||
type fakeIpsetConn struct {
|
||||
// fakeConn is a fake ipsetConn for tests.
|
||||
type fakeConn struct {
|
||||
ipv4Header *ipset.HeaderPolicy
|
||||
ipv4Entries *[]*ipset.Entry
|
||||
ipv6Header *ipset.HeaderPolicy
|
||||
ipv6Entries *[]*ipset.Entry
|
||||
}
|
||||
|
||||
// Add implements the ipsetConn interface for *fakeIpsetConn.
|
||||
func (c *fakeIpsetConn) Add(name string, entries ...*ipset.Entry) (err error) {
|
||||
// Add implements the [ipsetConn] interface for *fakeConn.
|
||||
func (c *fakeConn) Add(name string, entries ...*ipset.Entry) (err error) {
|
||||
if strings.Contains(name, "ipv4") {
|
||||
*c.ipv4Entries = append(*c.ipv4Entries, entries...)
|
||||
|
||||
@@ -38,13 +38,13 @@ func (c *fakeIpsetConn) Add(name string, entries ...*ipset.Entry) (err error) {
|
||||
return errors.Error("test: ipset not found")
|
||||
}
|
||||
|
||||
// Close implements the ipsetConn interface for *fakeIpsetConn.
|
||||
func (c *fakeIpsetConn) Close() (err error) {
|
||||
// Close implements the [ipsetConn] interface for *fakeConn.
|
||||
func (c *fakeConn) Close() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Header implements the ipsetConn interface for *fakeIpsetConn.
|
||||
func (c *fakeIpsetConn) Header(name string) (p *ipset.HeaderPolicy, err error) {
|
||||
// Header implements the [ipsetConn] interface for *fakeConn.
|
||||
func (c *fakeConn) Header(name string) (p *ipset.HeaderPolicy, err error) {
|
||||
if strings.Contains(name, "ipv4") {
|
||||
return c.ipv4Header, nil
|
||||
} else if strings.Contains(name, "ipv6") {
|
||||
@@ -54,7 +54,7 @@ func (c *fakeIpsetConn) Header(name string) (p *ipset.HeaderPolicy, err error) {
|
||||
return nil, errors.Error("test: ipset not found")
|
||||
}
|
||||
|
||||
func TestIpsetMgr_Add(t *testing.T) {
|
||||
func TestManager_Add(t *testing.T) {
|
||||
ipsetConf := []string{
|
||||
"example.com,example.net/ipv4set",
|
||||
"example.org,example.biz/ipv6set",
|
||||
@@ -67,7 +67,7 @@ func TestIpsetMgr_Add(t *testing.T) {
|
||||
pf netfilter.ProtoFamily,
|
||||
conf *netlink.Config,
|
||||
) (conn ipsetConn, err error) {
|
||||
return &fakeIpsetConn{
|
||||
return &fakeConn{
|
||||
ipv4Header: &ipset.HeaderPolicy{
|
||||
Family: ipset.NewUInt8Box(uint8(netfilter.ProtoIPv4)),
|
||||
},
|
||||
@@ -79,7 +79,7 @@ func TestIpsetMgr_Add(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
m, err := newIpsetMgrWithDialer(ipsetConf, fakeDial)
|
||||
m, err := newManagerWithDialer(ipsetConf, fakeDial)
|
||||
require.NoError(t, err)
|
||||
|
||||
ip4 := net.IP{1, 2, 3, 4}
|
||||
@@ -114,21 +114,22 @@ func TestIpsetMgr_Add(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
var ipsetPropsSink []ipsetProps
|
||||
// ipsetPropsSink is the typed sink for benchmark results.
|
||||
var ipsetPropsSink []props
|
||||
|
||||
func BenchmarkIpsetMgr_lookupHost(b *testing.B) {
|
||||
propsLong := []ipsetProps{{
|
||||
func BenchmarkManager_LookupHost(b *testing.B) {
|
||||
propsLong := []props{{
|
||||
name: "example.com",
|
||||
family: netfilter.ProtoIPv4,
|
||||
}}
|
||||
|
||||
propsShort := []ipsetProps{{
|
||||
propsShort := []props{{
|
||||
name: "example.net",
|
||||
family: netfilter.ProtoIPv4,
|
||||
}}
|
||||
|
||||
m := &ipsetMgr{
|
||||
domainToIpsets: map[string][]ipsetProps{
|
||||
m := &manager{
|
||||
domainToIpsets: map[string][]props{
|
||||
"": propsLong,
|
||||
"example.net": propsShort,
|
||||
},
|
||||
@@ -1,11 +1,11 @@
|
||||
//go:build !linux
|
||||
|
||||
package aghnet
|
||||
package ipset
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
)
|
||||
|
||||
func newIpsetMgr(_ []string) (mgr IpsetManager, err error) {
|
||||
func newManager(_ []string) (mgr Manager, err error) {
|
||||
return nil, aghos.Unsupported("ipset")
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -29,12 +30,12 @@ type queryLog struct {
|
||||
|
||||
findClient func(ids []string) (c *Client, err error)
|
||||
|
||||
// logFile is the path to the log file.
|
||||
logFile string
|
||||
|
||||
// buffer contains recent log entries. The entries in this buffer must not
|
||||
// be modified.
|
||||
buffer []*logEntry
|
||||
buffer *aghalg.RingBuffer[*logEntry]
|
||||
|
||||
// logFile is the path to the log file.
|
||||
logFile string
|
||||
|
||||
// bufferLock protects buffer.
|
||||
bufferLock sync.RWMutex
|
||||
@@ -195,7 +196,7 @@ func newLogEntry(params *AddParams) (entry *logEntry) {
|
||||
// Add implements the [QueryLog] interface for *queryLog.
|
||||
func (l *queryLog) Add(params *AddParams) {
|
||||
var isEnabled, fileIsEnabled bool
|
||||
var memSize uint32
|
||||
var memSize int
|
||||
func() {
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
@@ -204,7 +205,7 @@ func (l *queryLog) Add(params *AddParams) {
|
||||
memSize = l.conf.MemSize
|
||||
}()
|
||||
|
||||
if !isEnabled {
|
||||
if !isEnabled || memSize == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -221,36 +222,18 @@ func (l *queryLog) Add(params *AddParams) {
|
||||
|
||||
entry := newLogEntry(params)
|
||||
|
||||
needFlush := false
|
||||
func() {
|
||||
l.bufferLock.Lock()
|
||||
defer l.bufferLock.Unlock()
|
||||
l.bufferLock.Lock()
|
||||
defer l.bufferLock.Unlock()
|
||||
|
||||
l.buffer = append(l.buffer, entry)
|
||||
l.buffer.Append(entry)
|
||||
|
||||
if !fileIsEnabled {
|
||||
if len(l.buffer) > int(memSize) {
|
||||
// Writing to file is disabled, so just remove the oldest entry
|
||||
// from the slices.
|
||||
//
|
||||
// TODO(a.garipov): This should be replaced by a proper ring
|
||||
// buffer, but it's currently difficult to do that.
|
||||
l.buffer[0] = nil
|
||||
l.buffer = l.buffer[1:]
|
||||
}
|
||||
} else if !l.flushPending {
|
||||
needFlush = len(l.buffer) >= int(memSize)
|
||||
if needFlush {
|
||||
l.flushPending = true
|
||||
}
|
||||
}
|
||||
}()
|
||||
if !l.flushPending && fileIsEnabled && l.buffer.Len() >= memSize {
|
||||
l.flushPending = true
|
||||
|
||||
if needFlush {
|
||||
go func() {
|
||||
flushErr := l.flushLogBuffer()
|
||||
if flushErr != nil {
|
||||
log.Error("querylog: flushing after adding: %s", err)
|
||||
log.Error("querylog: flushing after adding: %s", flushErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -62,7 +63,7 @@ type Config struct {
|
||||
|
||||
// MemSize is the number of entries kept in a memory buffer before they are
|
||||
// flushed to disk.
|
||||
MemSize uint32
|
||||
MemSize int
|
||||
|
||||
// Enabled tells if the query log is enabled.
|
||||
Enabled bool
|
||||
@@ -142,9 +143,15 @@ func newQueryLog(conf Config) (l *queryLog, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if conf.MemSize < 0 {
|
||||
return nil, errors.Error("memory size must be greater or equal to zero")
|
||||
}
|
||||
|
||||
l = &queryLog{
|
||||
findClient: findClient,
|
||||
|
||||
buffer: aghalg.NewRingBuffer[*logEntry](conf.MemSize),
|
||||
|
||||
conf: &Config{},
|
||||
confMu: &sync.RWMutex{},
|
||||
logFile: filepath.Join(conf.BaseDir, queryLogFileName),
|
||||
|
||||
@@ -14,68 +14,75 @@ import (
|
||||
// flushLogBuffer flushes the current buffer to file and resets the current
|
||||
// buffer.
|
||||
func (l *queryLog) flushLogBuffer() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "flushing log buffer: %w") }()
|
||||
l.fileFlushLock.Lock()
|
||||
defer l.fileFlushLock.Unlock()
|
||||
|
||||
var flushBuffer []*logEntry
|
||||
func() {
|
||||
l.bufferLock.Lock()
|
||||
defer l.bufferLock.Unlock()
|
||||
b, err := l.encodeEntries()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
flushBuffer = l.buffer
|
||||
l.buffer = nil
|
||||
l.flushPending = false
|
||||
}()
|
||||
|
||||
err = l.flushToFile(flushBuffer)
|
||||
|
||||
return errors.Annotate(err, "writing to file: %w")
|
||||
return l.flushToFile(b)
|
||||
}
|
||||
|
||||
// flushToFile saves the specified log entries to the query log file
|
||||
func (l *queryLog) flushToFile(buffer []*logEntry) (err error) {
|
||||
if len(buffer) == 0 {
|
||||
log.Debug("querylog: nothing to write to a file")
|
||||
// encodeEntries returns JSON encoded log entries, logs estimated time, clears
|
||||
// the log buffer.
|
||||
func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
|
||||
l.bufferLock.Lock()
|
||||
defer l.bufferLock.Unlock()
|
||||
|
||||
return nil
|
||||
bufLen := l.buffer.Len()
|
||||
if bufLen == 0 {
|
||||
return nil, errors.Error("nothing to write to a file")
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
var b bytes.Buffer
|
||||
e := json.NewEncoder(&b)
|
||||
for _, entry := range buffer {
|
||||
err = e.Encode(entry)
|
||||
if err != nil {
|
||||
log.Error("Failed to marshal entry: %s", err)
|
||||
b = &bytes.Buffer{}
|
||||
e := json.NewEncoder(b)
|
||||
|
||||
return err
|
||||
}
|
||||
l.buffer.Range(func(entry *logEntry) (cont bool) {
|
||||
err = e.Encode(entry)
|
||||
|
||||
return err == nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer)))
|
||||
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", bufLen, elapsed, b.Len()/1024, float64(b.Len())/float64(bufLen), elapsed/time.Duration(bufLen))
|
||||
|
||||
var zb bytes.Buffer
|
||||
filename := l.logFile
|
||||
zb = b
|
||||
l.buffer.Clear()
|
||||
l.flushPending = false
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// flushToFile saves the encoded log entries to the query log file.
|
||||
func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) {
|
||||
l.fileWriteLock.Lock()
|
||||
defer l.fileWriteLock.Unlock()
|
||||
|
||||
filename := l.logFile
|
||||
|
||||
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
log.Error("failed to create file \"%s\": %s", filename, err)
|
||||
return err
|
||||
return fmt.Errorf("creating file %q: %w", filename, err)
|
||||
}
|
||||
|
||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||
|
||||
n, err := f.Write(zb.Bytes())
|
||||
n, err := f.Write(b.Bytes())
|
||||
if err != nil {
|
||||
log.Error("Couldn't write to file: %s", err)
|
||||
return err
|
||||
return fmt.Errorf("writing to file %q: %w", filename, err)
|
||||
}
|
||||
|
||||
log.Debug("querylog: ok \"%s\": %v bytes written", filename, n)
|
||||
log.Debug("querylog: ok %q: %v bytes written", filename, n)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -51,13 +51,12 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
|
||||
l.bufferLock.Lock()
|
||||
defer l.bufferLock.Unlock()
|
||||
|
||||
// Go through the buffer in the reverse order, from newer to older.
|
||||
var err error
|
||||
for i := len(l.buffer) - 1; i >= 0; i-- {
|
||||
l.buffer.ReverseRange(func(entry *logEntry) (cont bool) {
|
||||
// A shallow clone is enough, since the only thing that this loop
|
||||
// modifies is the client field.
|
||||
e := l.buffer[i].shallowClone()
|
||||
e := entry.shallowClone()
|
||||
|
||||
var err error
|
||||
e.client, err = l.client(e.ClientID, e.IP.String(), cache)
|
||||
if err != nil {
|
||||
msg := "querylog: enriching memory record at time %s" +
|
||||
@@ -70,9 +69,11 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
|
||||
if params.match(e) {
|
||||
entries = append(entries, e)
|
||||
}
|
||||
}
|
||||
|
||||
return entries, len(l.buffer)
|
||||
return true
|
||||
})
|
||||
|
||||
return entries, l.buffer.Len()
|
||||
}
|
||||
|
||||
// search - searches log entries in the query log using specified parameters
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
@@ -51,11 +51,7 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||
|
||||
var r io.Reader
|
||||
r, err = aghio.LimitReader(resp.Body, MaxResponseSize)
|
||||
if err != nil {
|
||||
return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err)
|
||||
}
|
||||
r := ioutil.LimitReader(resp.Body, MaxResponseSize)
|
||||
|
||||
// This use of ReadAll is safe, because we just limited the appropriate
|
||||
// ReadCloser.
|
||||
|
||||
148
internal/updater/check_test.go
Normal file
148
internal/updater/check_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package updater_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUpdater_VersionInfo(t *testing.T) {
|
||||
const jsonData = `{
|
||||
"version": "v0.103.0-beta.2",
|
||||
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
|
||||
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
|
||||
"selfupdate_min_version": "v0.0",
|
||||
"download_windows_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_windows_amd64.zip",
|
||||
"download_windows_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_windows_386.zip",
|
||||
"download_darwin_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_darwin_amd64.zip",
|
||||
"download_darwin_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_darwin_386.zip",
|
||||
"download_linux_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz",
|
||||
"download_linux_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_386.tar.gz",
|
||||
"download_linux_arm": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv5": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz",
|
||||
"download_linux_armv6": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
|
||||
"download_linux_arm64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz",
|
||||
"download_linux_mips": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz",
|
||||
"download_linux_mipsle": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz",
|
||||
"download_linux_mips64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz",
|
||||
"download_linux_mips64le": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz",
|
||||
"download_freebsd_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz",
|
||||
"download_freebsd_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz",
|
||||
"download_freebsd_arm": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv5": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz",
|
||||
"download_freebsd_armv6": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz",
|
||||
"download_freebsd_arm64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz"
|
||||
}`
|
||||
|
||||
counter := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
counter++
|
||||
_, _ = w.Write([]byte(jsonData))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
fakeURL, err := url.JoinPath(srv.URL, "adguardhome", version.ChannelBeta, "version.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
Client: srv.Client(),
|
||||
Version: "v0.103.0-beta.1",
|
||||
Channel: version.ChannelBeta,
|
||||
GOARCH: "arm",
|
||||
GOOS: "linux",
|
||||
VersionCheckURL: fakeURL,
|
||||
})
|
||||
|
||||
info, err := u.VersionInfo(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, counter, 1)
|
||||
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
|
||||
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
|
||||
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
|
||||
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
|
||||
|
||||
t.Run("cache_check", func(t *testing.T) {
|
||||
_, err = u.VersionInfo(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, counter, 1)
|
||||
})
|
||||
|
||||
t.Run("force_check", func(t *testing.T) {
|
||||
_, err = u.VersionInfo(true)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, counter, 2)
|
||||
})
|
||||
|
||||
t.Run("api_fail", func(t *testing.T) {
|
||||
srv.Close()
|
||||
|
||||
_, err = u.VersionInfo(true)
|
||||
var urlErr *url.Error
|
||||
assert.ErrorAs(t, err, &urlErr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdater_VersionInfo_others(t *testing.T) {
|
||||
const jsonData = `{
|
||||
"version": "v0.103.0-beta.2",
|
||||
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
|
||||
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
|
||||
"selfupdate_min_version": "v0.0",
|
||||
"download_linux_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
|
||||
"download_linux_mips_softfloat": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz"
|
||||
}`
|
||||
|
||||
fakeClient, fakeURL := aghtest.StartHTTPServer(t, []byte(jsonData))
|
||||
fakeURL = fakeURL.JoinPath("adguardhome", version.ChannelBeta, "version.json")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
arch string
|
||||
arm string
|
||||
mips string
|
||||
}{{
|
||||
name: "ARM",
|
||||
arch: "arm",
|
||||
arm: "7",
|
||||
mips: "",
|
||||
}, {
|
||||
name: "MIPS",
|
||||
arch: "mips",
|
||||
mips: "softfloat",
|
||||
arm: "",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
Client: fakeClient,
|
||||
Version: "v0.103.0-beta.1",
|
||||
Channel: version.ChannelBeta,
|
||||
GOOS: "linux",
|
||||
GOARCH: tc.arch,
|
||||
GOARM: tc.arm,
|
||||
GOMIPS: tc.mips,
|
||||
VersionCheckURL: fakeURL.String(),
|
||||
})
|
||||
|
||||
info, err := u.VersionInfo(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
|
||||
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
|
||||
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
|
||||
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
|
||||
}
|
||||
}
|
||||
BIN
internal/updater/testdata/AdGuardHome_unix.tar.gz
vendored
Normal file
BIN
internal/updater/testdata/AdGuardHome_unix.tar.gz
vendored
Normal file
Binary file not shown.
@@ -8,18 +8,16 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@@ -36,6 +34,7 @@ type Updater struct {
|
||||
|
||||
workDir string
|
||||
confName string
|
||||
execPath string
|
||||
versionCheckURL string
|
||||
|
||||
// mu protects all fields below.
|
||||
@@ -74,18 +73,19 @@ type Config struct {
|
||||
// ConfName is the name of the current configuration file. Typically,
|
||||
// "AdGuardHome.yaml".
|
||||
ConfName string
|
||||
|
||||
// WorkDir is the working directory that is used for temporary files.
|
||||
WorkDir string
|
||||
|
||||
// ExecPath is path to the executable file.
|
||||
ExecPath string
|
||||
|
||||
// VersionCheckURL is url to the latest version announcement.
|
||||
VersionCheckURL string
|
||||
}
|
||||
|
||||
// NewUpdater creates a new Updater.
|
||||
func NewUpdater(conf *Config) *Updater {
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
// TODO(a.garipov): Make configurable.
|
||||
Host: "static.adtidy.org",
|
||||
Path: path.Join("adguardhome", conf.Channel, "version.json"),
|
||||
}
|
||||
return &Updater{
|
||||
client: conf.Client,
|
||||
|
||||
@@ -98,7 +98,8 @@ func NewUpdater(conf *Config) *Updater {
|
||||
|
||||
confName: conf.ConfName,
|
||||
workDir: conf.WorkDir,
|
||||
versionCheckURL: u.String(),
|
||||
execPath: conf.ExecPath,
|
||||
versionCheckURL: conf.VersionCheckURL,
|
||||
|
||||
mu: &sync.RWMutex{},
|
||||
}
|
||||
@@ -119,12 +120,7 @@ func (u *Updater) Update(firstRun bool) (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting executable path: %w", err)
|
||||
}
|
||||
|
||||
err = u.prepare(execPath)
|
||||
err = u.prepare()
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing: %w", err)
|
||||
}
|
||||
@@ -178,7 +174,7 @@ func (u *Updater) VersionCheckURL() (vcu string) {
|
||||
}
|
||||
|
||||
// prepare fills all necessary fields in Updater object.
|
||||
func (u *Updater) prepare(exePath string) (err error) {
|
||||
func (u *Updater) prepare() (err error) {
|
||||
u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion))
|
||||
|
||||
_, pkgNameOnly := filepath.Split(u.packageURL)
|
||||
@@ -194,7 +190,7 @@ func (u *Updater) prepare(exePath string) (err error) {
|
||||
updateExeName = "AdGuardHome.exe"
|
||||
}
|
||||
|
||||
u.backupExeName = filepath.Join(u.backupDir, filepath.Base(exePath))
|
||||
u.backupExeName = filepath.Join(u.backupDir, filepath.Base(u.execPath))
|
||||
u.updateExeName = filepath.Join(u.updateDir, updateExeName)
|
||||
|
||||
log.Debug(
|
||||
@@ -204,7 +200,7 @@ func (u *Updater) prepare(exePath string) (err error) {
|
||||
u.packageURL,
|
||||
)
|
||||
|
||||
u.currentExeName = exePath
|
||||
u.currentExeName = u.execPath
|
||||
_, err = os.Stat(u.currentExeName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking %q: %w", u.currentExeName, err)
|
||||
@@ -332,11 +328,7 @@ func (u *Updater) downloadPackageFile() (err error) {
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||
|
||||
var r io.Reader
|
||||
r, err = aghio.LimitReader(resp.Body, MaxPackageFileSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http request failed: %w", err)
|
||||
}
|
||||
r := ioutil.LimitReader(resp.Body, MaxPackageFileSize)
|
||||
|
||||
log.Debug("updater: reading http body")
|
||||
// This use of ReadAll is now safe, because we limited body's Reader.
|
||||
|
||||
107
internal/updater/updater_internal_test.go
Normal file
107
internal/updater/updater_internal_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUpdater_internal(t *testing.T) {
|
||||
wd := t.TempDir()
|
||||
|
||||
exePathUnix := filepath.Join(wd, "AdGuardHome.exe")
|
||||
exePathWindows := filepath.Join(wd, "AdGuardHome")
|
||||
yamlPath := filepath.Join(wd, "AdGuardHome.yaml")
|
||||
readmePath := filepath.Join(wd, "README.md")
|
||||
licensePath := filepath.Join(wd, "LICENSE.txt")
|
||||
|
||||
require.NoError(t, os.WriteFile(exePathUnix, []byte("AdGuardHome.exe"), 0o755))
|
||||
require.NoError(t, os.WriteFile(exePathWindows, []byte("AdGuardHome"), 0o755))
|
||||
require.NoError(t, os.WriteFile(yamlPath, []byte("AdGuardHome.yaml"), 0o644))
|
||||
require.NoError(t, os.WriteFile(readmePath, []byte("README.md"), 0o644))
|
||||
require.NoError(t, os.WriteFile(licensePath, []byte("LICENSE.txt"), 0o644))
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
exeName string
|
||||
os string
|
||||
archiveName string
|
||||
}{{
|
||||
name: "unix",
|
||||
os: "linux",
|
||||
exeName: "AdGuardHome",
|
||||
archiveName: "AdGuardHome.tar.gz",
|
||||
}, {
|
||||
name: "windows",
|
||||
os: "windows",
|
||||
exeName: "AdGuardHome.exe",
|
||||
archiveName: "AdGuardHome.zip",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
exePath := filepath.Join(wd, tc.exeName)
|
||||
|
||||
// start server for returning package file
|
||||
pkgData, err := os.ReadFile(filepath.Join("testdata", tc.archiveName))
|
||||
require.NoError(t, err)
|
||||
|
||||
fakeClient, fakeURL := aghtest.StartHTTPServer(t, pkgData)
|
||||
fakeURL = fakeURL.JoinPath(tc.archiveName)
|
||||
|
||||
u := NewUpdater(&Config{
|
||||
Client: fakeClient,
|
||||
GOOS: tc.os,
|
||||
Version: "v0.103.0",
|
||||
ExecPath: exePath,
|
||||
WorkDir: wd,
|
||||
ConfName: yamlPath,
|
||||
})
|
||||
|
||||
u.newVersion = "v0.103.1"
|
||||
u.packageURL = fakeURL.String()
|
||||
|
||||
require.NoError(t, u.prepare())
|
||||
require.NoError(t, u.downloadPackageFile())
|
||||
require.NoError(t, u.unpack())
|
||||
require.NoError(t, u.backup(false))
|
||||
require.NoError(t, u.replace())
|
||||
|
||||
u.clean()
|
||||
|
||||
// check backup files
|
||||
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
|
||||
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", tc.exeName))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.exeName, string(d))
|
||||
|
||||
// check updated files
|
||||
d, err = os.ReadFile(exePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "1", string(d))
|
||||
|
||||
d, err = os.ReadFile(readmePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "2", string(d))
|
||||
|
||||
d, err = os.ReadFile(licensePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "3", string(d))
|
||||
|
||||
d, err = os.ReadFile(yamlPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
}
|
||||
}
|
||||
@@ -1,105 +1,38 @@
|
||||
package updater
|
||||
package updater_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Rewrite these tests.
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func startHTTPServer(data string) (l net.Listener, portStr string) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte(data))
|
||||
})
|
||||
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go func() { _ = http.Serve(listener, mux) }()
|
||||
return listener, strconv.FormatUint(uint64(listener.Addr().(*net.TCPAddr).Port), 10)
|
||||
}
|
||||
|
||||
func TestUpdateGetVersion(t *testing.T) {
|
||||
func TestUpdater_Update(t *testing.T) {
|
||||
const jsonData = `{
|
||||
"version": "v0.103.0-beta.2",
|
||||
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
|
||||
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
|
||||
"selfupdate_min_version": "v0.0",
|
||||
"download_windows_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_windows_amd64.zip",
|
||||
"download_windows_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_windows_386.zip",
|
||||
"download_darwin_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_darwin_amd64.zip",
|
||||
"download_darwin_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_darwin_386.zip",
|
||||
"download_linux_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz",
|
||||
"download_linux_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_386.tar.gz",
|
||||
"download_linux_arm": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv5": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz",
|
||||
"download_linux_armv6": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
|
||||
"download_linux_arm64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz",
|
||||
"download_linux_mips": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz",
|
||||
"download_linux_mipsle": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz",
|
||||
"download_linux_mips64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz",
|
||||
"download_linux_mips64le": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz",
|
||||
"download_freebsd_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz",
|
||||
"download_freebsd_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz",
|
||||
"download_freebsd_arm": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv5": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz",
|
||||
"download_freebsd_armv6": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz",
|
||||
"download_freebsd_arm64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz"
|
||||
"download_linux_amd64": "%s"
|
||||
}`
|
||||
|
||||
l, lport := startHTTPServer(jsonData)
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
const packagePath = "/AdGuardHome.tar.gz"
|
||||
|
||||
u := NewUpdater(&Config{
|
||||
Client: &http.Client{},
|
||||
Version: "v0.103.0-beta.1",
|
||||
Channel: version.ChannelBeta,
|
||||
GOARCH: "arm",
|
||||
GOOS: "linux",
|
||||
})
|
||||
|
||||
fakeURL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort("127.0.0.1", lport),
|
||||
Path: path.Join("adguardhome", version.ChannelBeta, "version.json"),
|
||||
}
|
||||
u.versionCheckURL = fakeURL.String()
|
||||
|
||||
info, err := u.VersionInfo(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
|
||||
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
|
||||
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
|
||||
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
|
||||
|
||||
// check cached
|
||||
_, err = u.VersionInfo(false)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
wd := t.TempDir()
|
||||
|
||||
exePath := filepath.Join(wd, "AdGuardHome")
|
||||
@@ -112,55 +45,61 @@ func TestUpdate(t *testing.T) {
|
||||
require.NoError(t, os.WriteFile(readmePath, []byte("README.md"), 0o644))
|
||||
require.NoError(t, os.WriteFile(licensePath, []byte("LICENSE.txt"), 0o644))
|
||||
|
||||
// start server for returning package file
|
||||
pkgData, err := os.ReadFile("testdata/AdGuardHome.tar.gz")
|
||||
pkgData, err := os.ReadFile("testdata/AdGuardHome_unix.tar.gz")
|
||||
require.NoError(t, err)
|
||||
|
||||
l, lport := startHTTPServer(string(pkgData))
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
u := NewUpdater(&Config{
|
||||
Client: &http.Client{},
|
||||
Version: "v0.103.0",
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(packagePath, func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write(pkgData)
|
||||
})
|
||||
|
||||
fakeURL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort("127.0.0.1", lport),
|
||||
Path: "AdGuardHome.tar.gz",
|
||||
}
|
||||
versionPath := path.Join("/adguardhome", version.ChannelBeta, "version.json")
|
||||
mux.HandleFunc(versionPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
var u string
|
||||
u, err = url.JoinPath("http://", r.Host, packagePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
u.workDir = wd
|
||||
u.confName = yamlPath
|
||||
u.newVersion = "v0.103.1"
|
||||
u.packageURL = fakeURL.String()
|
||||
_, _ = fmt.Fprintf(w, jsonData, u)
|
||||
})
|
||||
|
||||
require.NoError(t, u.prepare(exePath))
|
||||
require.NoError(t, u.downloadPackageFile())
|
||||
require.NoError(t, u.unpack())
|
||||
// require.NoError(t, u.check())
|
||||
require.NoError(t, u.backup(false))
|
||||
require.NoError(t, u.replace())
|
||||
srv := httptest.NewServer(mux)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
u.clean()
|
||||
versionCheckURL, err := url.JoinPath(srv.URL, versionPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
Client: srv.Client(),
|
||||
GOARCH: "amd64",
|
||||
GOOS: "linux",
|
||||
Version: "v0.103.0",
|
||||
ConfName: yamlPath,
|
||||
WorkDir: wd,
|
||||
ExecPath: exePath,
|
||||
VersionCheckURL: versionCheckURL,
|
||||
})
|
||||
|
||||
_, err = u.VersionInfo(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = u.Update(true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// check backup files
|
||||
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
|
||||
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "LICENSE.txt"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
assert.Equal(t, "LICENSE.txt", string(d))
|
||||
|
||||
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome"))
|
||||
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "README.md"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome", string(d))
|
||||
assert.Equal(t, "README.md", string(d))
|
||||
|
||||
// check updated files
|
||||
d, err = os.ReadFile(exePath)
|
||||
_, err = os.Stat(exePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "1", string(d))
|
||||
|
||||
d, err = os.ReadFile(readmePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -175,157 +114,22 @@ func TestUpdate(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
}
|
||||
|
||||
func TestUpdateWindows(t *testing.T) {
|
||||
wd := t.TempDir()
|
||||
t.Run("config_check", func(t *testing.T) {
|
||||
// TODO(s.chzhen): Test on Windows also.
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping config check test on windows")
|
||||
}
|
||||
|
||||
exePath := filepath.Join(wd, "AdGuardHome.exe")
|
||||
yamlPath := filepath.Join(wd, "AdGuardHome.yaml")
|
||||
readmePath := filepath.Join(wd, "README.md")
|
||||
licensePath := filepath.Join(wd, "LICENSE.txt")
|
||||
|
||||
require.NoError(t, os.WriteFile(exePath, []byte("AdGuardHome.exe"), 0o755))
|
||||
require.NoError(t, os.WriteFile(yamlPath, []byte("AdGuardHome.yaml"), 0o644))
|
||||
require.NoError(t, os.WriteFile(readmePath, []byte("README.md"), 0o644))
|
||||
require.NoError(t, os.WriteFile(licensePath, []byte("LICENSE.txt"), 0o644))
|
||||
|
||||
// start server for returning package file
|
||||
pkgData, err := os.ReadFile("testdata/AdGuardHome.zip")
|
||||
require.NoError(t, err)
|
||||
|
||||
l, lport := startHTTPServer(string(pkgData))
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
u := NewUpdater(&Config{
|
||||
Client: &http.Client{},
|
||||
GOOS: "windows",
|
||||
Version: "v0.103.0",
|
||||
err = u.Update(false)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
fakeURL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort("127.0.0.1", lport),
|
||||
Path: "AdGuardHome.zip",
|
||||
}
|
||||
t.Run("api_fail", func(t *testing.T) {
|
||||
srv.Close()
|
||||
|
||||
u.workDir = wd
|
||||
u.confName = yamlPath
|
||||
u.newVersion = "v0.103.1"
|
||||
u.packageURL = fakeURL.String()
|
||||
|
||||
require.NoError(t, u.prepare(exePath))
|
||||
require.NoError(t, u.downloadPackageFile())
|
||||
require.NoError(t, u.unpack())
|
||||
// assert.Nil(t, u.check())
|
||||
require.NoError(t, u.backup(false))
|
||||
require.NoError(t, u.replace())
|
||||
|
||||
u.clean()
|
||||
|
||||
// check backup files
|
||||
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
|
||||
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.exe"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome.exe", string(d))
|
||||
|
||||
// check updated files
|
||||
d, err = os.ReadFile(exePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "1", string(d))
|
||||
|
||||
d, err = os.ReadFile(readmePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "2", string(d))
|
||||
|
||||
d, err = os.ReadFile(licensePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "3", string(d))
|
||||
|
||||
d, err = os.ReadFile(yamlPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
}
|
||||
|
||||
func TestUpdater_VersionInto_ARM(t *testing.T) {
|
||||
const jsonData = `{
|
||||
"version": "v0.103.0-beta.2",
|
||||
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
|
||||
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
|
||||
"selfupdate_min_version": "v0.0",
|
||||
"download_linux_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz"
|
||||
}`
|
||||
|
||||
l, lport := startHTTPServer(jsonData)
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
u := NewUpdater(&Config{
|
||||
Client: &http.Client{},
|
||||
Version: "v0.103.0-beta.1",
|
||||
Channel: version.ChannelBeta,
|
||||
GOARCH: "arm",
|
||||
GOOS: "linux",
|
||||
GOARM: "7",
|
||||
err = u.Update(true)
|
||||
var urlErr *url.Error
|
||||
assert.ErrorAs(t, err, &urlErr)
|
||||
})
|
||||
|
||||
fakeURL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort("127.0.0.1", lport),
|
||||
Path: path.Join("adguardhome", version.ChannelBeta, "version.json"),
|
||||
}
|
||||
u.versionCheckURL = fakeURL.String()
|
||||
|
||||
info, err := u.VersionInfo(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
|
||||
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
|
||||
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
|
||||
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
|
||||
}
|
||||
|
||||
func TestUpdater_VersionInto_MIPS(t *testing.T) {
|
||||
const jsonData = `{
|
||||
"version": "v0.103.0-beta.2",
|
||||
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
|
||||
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
|
||||
"selfupdate_min_version": "v0.0",
|
||||
"download_linux_mips_softfloat": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz"
|
||||
}`
|
||||
|
||||
l, lport := startHTTPServer(jsonData)
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
u := NewUpdater(&Config{
|
||||
Client: &http.Client{},
|
||||
Version: "v0.103.0-beta.1",
|
||||
Channel: version.ChannelBeta,
|
||||
GOARCH: "mips",
|
||||
GOOS: "linux",
|
||||
GOMIPS: "softfloat",
|
||||
})
|
||||
|
||||
fakeURL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort("127.0.0.1", lport),
|
||||
Path: path.Join("adguardhome", version.ChannelBeta, "version.json"),
|
||||
}
|
||||
u.versionCheckURL = fakeURL.String()
|
||||
|
||||
info, err := u.VersionInfo(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
|
||||
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
|
||||
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
|
||||
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
|
||||
}
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
@@ -62,7 +62,7 @@ type Config struct {
|
||||
CacheTTL time.Duration
|
||||
|
||||
// MaxConnReadSize is an upper limit in bytes for reading from net.Conn.
|
||||
MaxConnReadSize int64
|
||||
MaxConnReadSize uint64
|
||||
|
||||
// MaxRedirects is the maximum redirects count.
|
||||
MaxRedirects int
|
||||
@@ -102,7 +102,7 @@ type Default struct {
|
||||
cacheTTL time.Duration
|
||||
|
||||
// maxConnReadSize is an upper limit in bytes for reading from net.Conn.
|
||||
maxConnReadSize int64
|
||||
maxConnReadSize uint64
|
||||
|
||||
// maxRedirects is the maximum redirects count.
|
||||
maxRedirects int
|
||||
@@ -208,11 +208,7 @@ func (w *Default) query(ctx context.Context, target, serverAddr string) (data []
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
|
||||
|
||||
r, err := aghio.LimitReader(conn, w.maxConnReadSize)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
r := ioutil.LimitReader(conn, w.maxConnReadSize)
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(w.timeout))
|
||||
_, err = io.WriteString(conn, target+"\r\n")
|
||||
|
||||
Reference in New Issue
Block a user