all: sync with master; upd chlog

This commit is contained in:
Ainar Garipov
2023-10-11 17:31:41 +03:00
parent 258eecc55b
commit 760d466b38
139 changed files with 39736 additions and 18364 deletions

View File

@@ -0,0 +1,102 @@
package aghalg
import (
"github.com/AdguardTeam/golibs/errors"
)
// RingBuffer is the implementation of ring buffer data structure.
type RingBuffer[T any] struct {
buf []T
cur int
full bool
}
// NewRingBuffer initializes the new instance of ring buffer. size must be
// greater or equal to zero.
func NewRingBuffer[T any](size int) (rb *RingBuffer[T]) {
if size < 0 {
panic(errors.Error("ring buffer: size must be greater or equal to zero"))
}
return &RingBuffer[T]{
buf: make([]T, size),
}
}
// Append appends an element to the buffer.
func (rb *RingBuffer[T]) Append(e T) {
if len(rb.buf) == 0 {
return
}
rb.buf[rb.cur] = e
rb.cur = (rb.cur + 1) % cap(rb.buf)
if rb.cur == 0 {
rb.full = true
}
}
// Range calls cb for each element of the buffer. If cb returns false it stops.
func (rb *RingBuffer[T]) Range(cb func(T) (cont bool)) {
before, after := rb.splitCur()
for _, e := range before {
if !cb(e) {
return
}
}
for _, e := range after {
if !cb(e) {
return
}
}
}
// ReverseRange calls cb for each element of the buffer in reverse order. If
// cb returns false it stops.
func (rb *RingBuffer[T]) ReverseRange(cb func(T) (cont bool)) {
before, after := rb.splitCur()
for i := len(after) - 1; i >= 0; i-- {
if !cb(after[i]) {
return
}
}
for i := len(before) - 1; i >= 0; i-- {
if !cb(before[i]) {
return
}
}
}
// splitCur splits the buffer in two, before and after current position in
// chronological order. If buffer is not full, after is nil.
func (rb *RingBuffer[T]) splitCur() (before, after []T) {
if len(rb.buf) == 0 {
return nil, nil
}
cur := rb.cur
if !rb.full {
return rb.buf[:cur], nil
}
return rb.buf[cur:], rb.buf[:cur]
}
// Len returns a length of the buffer.
func (rb *RingBuffer[T]) Len() (l int) {
if !rb.full {
return rb.cur
}
return cap(rb.buf)
}
// Clear clears the buffer.
func (rb *RingBuffer[T]) Clear() {
rb.full = false
rb.cur = 0
}

View File

@@ -0,0 +1,173 @@
package aghalg_test
import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slices"
)
// elements is a helper function that returns n elements of the buffer.
func elements(b *aghalg.RingBuffer[int], n int, reverse bool) (es []int) {
fn := b.Range
if reverse {
fn = b.ReverseRange
}
i := 0
fn(func(e int) (cont bool) {
if i >= n {
return false
}
es = append(es, e)
i++
return true
})
return es
}
func TestNewRingBuffer(t *testing.T) {
t.Run("success_and_clear", func(t *testing.T) {
b := aghalg.NewRingBuffer[int](5)
for i := 0; i < 10; i++ {
b.Append(i)
}
assert.Equal(t, []int{5, 6, 7, 8, 9}, elements(b, b.Len(), false))
b.Clear()
assert.Zero(t, b.Len())
})
t.Run("negative_size", func(t *testing.T) {
assert.PanicsWithError(t, "ring buffer: size must be greater or equal to zero", func() {
aghalg.NewRingBuffer[int](-5)
})
})
t.Run("zero", func(t *testing.T) {
b := aghalg.NewRingBuffer[int](0)
for i := 0; i < 10; i++ {
b.Append(i)
assert.Equal(t, 0, b.Len())
assert.Empty(t, elements(b, b.Len(), false))
assert.Empty(t, elements(b, b.Len(), true))
}
})
t.Run("single", func(t *testing.T) {
b := aghalg.NewRingBuffer[int](1)
for i := 0; i < 10; i++ {
b.Append(i)
assert.Equal(t, 1, b.Len())
assert.Equal(t, []int{i}, elements(b, b.Len(), false))
assert.Equal(t, []int{i}, elements(b, b.Len(), true))
}
})
}
func TestRingBuffer_Range(t *testing.T) {
const size = 5
b := aghalg.NewRingBuffer[int](size)
testCases := []struct {
name string
want []int
count int
length int
}{{
name: "three",
count: 3,
length: 3,
want: []int{0, 1, 2},
}, {
name: "ten",
count: 10,
length: size,
want: []int{5, 6, 7, 8, 9},
}, {
name: "hundred",
count: 100,
length: size,
want: []int{95, 96, 97, 98, 99},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for i := 0; i < tc.count; i++ {
b.Append(i)
}
bufLen := b.Len()
assert.Equal(t, tc.length, bufLen)
want := tc.want
assert.Equal(t, want, elements(b, bufLen, false))
assert.Equal(t, want[:len(want)-1], elements(b, bufLen-1, false))
assert.Equal(t, want[:len(want)/2], elements(b, bufLen/2, false))
want = want[:cap(want)]
slices.Reverse(want)
assert.Equal(t, want, elements(b, bufLen, true))
assert.Equal(t, want[:len(want)-1], elements(b, bufLen-1, true))
assert.Equal(t, want[:len(want)/2], elements(b, bufLen/2, true))
})
}
}
func TestRingBuffer_Range_increment(t *testing.T) {
const size = 5
b := aghalg.NewRingBuffer[int](size)
testCases := []struct {
name string
want []int
}{{
name: "one",
want: []int{0},
}, {
name: "two",
want: []int{0, 1},
}, {
name: "three",
want: []int{0, 1, 2},
}, {
name: "four",
want: []int{0, 1, 2, 3},
}, {
name: "five",
want: []int{0, 1, 2, 3, 4},
}, {
name: "six",
want: []int{1, 2, 3, 4, 5},
}, {
name: "seven",
want: []int{2, 3, 4, 5, 6},
}, {
name: "eight",
want: []int{3, 4, 5, 6, 7},
}, {
name: "nine",
want: []int{4, 5, 6, 7, 8},
}, {
name: "ten",
want: []int{5, 6, 7, 8, 9},
}}
for i, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
b.Append(i)
assert.Equal(t, tc.want, elements(b, b.Len(), false))
slices.Reverse(tc.want)
assert.Equal(t, tc.want, elements(b, b.Len(), true))
})
}
}

View File

@@ -1,60 +0,0 @@
// Package aghio contains extensions for io package's types and methods
package aghio
import (
"fmt"
"io"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mathutil"
)
// LimitReachedError records the limit and the operation that caused it.
type LimitReachedError struct {
Limit int64
}
// Error implements the [error] interface for *LimitReachedError.
//
// TODO(a.garipov): Think about error string format.
func (lre *LimitReachedError) Error() string {
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
}
// limitedReader is a wrapper for [io.Reader] limiting the input and dealing
// with errors package.
type limitedReader struct {
r io.Reader
limit int64
n int64
}
// Read implements the [io.Reader] interface.
func (lr *limitedReader) Read(p []byte) (n int, err error) {
if lr.n == 0 {
return 0, &LimitReachedError{
Limit: lr.limit,
}
}
p = p[:mathutil.Min(lr.n, int64(len(p)))]
n, err = lr.r.Read(p)
lr.n -= int64(n)
return n, err
}
// LimitReader wraps Reader to make it's Reader stop with ErrLimitReached after
// n bytes read.
func LimitReader(r io.Reader, n int64) (limited io.Reader, err error) {
if n < 0 {
return nil, errors.Error("limit must be non-negative")
}
return &limitedReader{
r: r,
limit: n,
n: n,
}, nil
}

View File

@@ -1,96 +0,0 @@
package aghio_test
import (
"io"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLimitReader(t *testing.T) {
testCases := []struct {
wantErrMsg string
name string
n int64
}{{
wantErrMsg: "",
name: "positive",
n: 1,
}, {
wantErrMsg: "",
name: "zero",
n: 0,
}, {
wantErrMsg: "limit must be non-negative",
name: "negative",
n: -1,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := aghio.LimitReader(nil, tc.n)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
func TestLimitedReader_Read(t *testing.T) {
testCases := []struct {
err error
name string
rStr string
limit int64
want int
}{{
err: nil,
name: "perfectly_match",
rStr: "abc",
limit: 3,
want: 3,
}, {
err: io.EOF,
name: "eof",
rStr: "",
limit: 3,
want: 0,
}, {
err: &aghio.LimitReachedError{
Limit: 0,
},
name: "limit_reached",
rStr: "abc",
limit: 0,
want: 0,
}, {
err: nil,
name: "truncated",
rStr: "abc",
limit: 2,
want: 2,
}}
for _, tc := range testCases {
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
lreader, err := aghio.LimitReader(readCloser, tc.limit)
require.NoError(t, err)
require.NotNil(t, lreader)
t.Run(tc.name, func(t *testing.T) {
buf := make([]byte, tc.limit+1)
n, rerr := lreader.Read(buf)
require.Equal(t, rerr, tc.err)
assert.Equal(t, tc.want, n)
})
}
}
func TestLimitedReader_LimitReachedError(t *testing.T) {
testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &aghio.LimitReachedError{
Limit: 0,
})
}

View File

@@ -12,12 +12,12 @@ import (
// listenPacketReusable announces on the local network address additionally
// configuring the socket to have a reusable binding.
func listenPacketReusable(ifaceName, network, address string) (c net.PacketConn, err error) {
var port int
var port uint16
_, port, err = netutil.SplitHostPort(address)
if err != nil {
return nil, err
}
// TODO(e.burkov): Inspect nclient4.NewRawUDPConn and implement here.
return nclient4.NewRawUDPConn(ifaceName, port)
return nclient4.NewRawUDPConn(ifaceName, int(port))
}

View File

@@ -9,6 +9,7 @@ import (
"io"
"net"
"net/netip"
"net/url"
"syscall"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
@@ -263,7 +264,7 @@ func IsAddrInUse(err error) (ok bool) {
// CollectAllIfacesAddrs returns the slice of all network interfaces IP
// addresses without port number.
func CollectAllIfacesAddrs() (addrs []string, err error) {
func CollectAllIfacesAddrs() (addrs []netip.Addr, err error) {
var ifaceAddrs []net.Addr
ifaceAddrs, err = netInterfaceAddrs()
if err != nil {
@@ -271,19 +272,41 @@ func CollectAllIfacesAddrs() (addrs []string, err error) {
}
for _, addr := range ifaceAddrs {
cidr := addr.String()
var ip net.IP
ip, _, err = net.ParseCIDR(cidr)
var p netip.Prefix
p, err = netip.ParsePrefix(addr.String())
if err != nil {
return nil, fmt.Errorf("parsing cidr: %w", err)
// Don't wrap the error since it's informative enough as is.
return nil, err
}
addrs = append(addrs, ip.String())
addrs = append(addrs, p.Addr())
}
return addrs, nil
}
// ParseAddrPort parses an [netip.AddrPort] from s, which should be either a
// valid IP, optionally with port, or a valid URL with plain IP address. The
// defaultPort is used if s doesn't contain port number.
func ParseAddrPort(s string, defaultPort uint16) (ipp netip.AddrPort, err error) {
u, err := url.Parse(s)
if err == nil && u.Host != "" {
s = u.Host
}
ipp, err = netip.ParseAddrPort(s)
if err != nil {
ip, parseErr := netip.ParseAddr(s)
if parseErr != nil {
return ipp, errors.Join(err, parseErr)
}
return netip.AddrPortFrom(ip, defaultPort), nil
}
return ipp, nil
}
// BroadcastFromPref calculates the broadcast IP address for p.
func BroadcastFromPref(p netip.Prefix) (bc netip.Addr) {
bc = p.Addr().Unmap()

View File

@@ -230,7 +230,7 @@ func TestCollectAllIfacesAddrs(t *testing.T) {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
wantAddrs []netip.Addr
}{{
name: "success",
wantErrMsg: ``,
@@ -241,10 +241,13 @@ func TestCollectAllIfacesAddrs(t *testing.T) {
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
wantAddrs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("4.3.2.1"),
},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
wantErrMsg: `netip.ParsePrefix("1.2.3.4"): no '/'`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
@@ -269,12 +272,11 @@ func TestCollectAllIfacesAddrs(t *testing.T) {
t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
assert.ErrorIs(t, err, errAddrs)
})
}

View File

@@ -2,10 +2,16 @@ package aghnet_test
import (
"io/fs"
"net"
"net/netip"
"net/url"
"os"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
)
func TestMain(m *testing.M) {
@@ -14,3 +20,76 @@ func TestMain(m *testing.M) {
// testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata")
func TestParseAddrPort(t *testing.T) {
const defaultPort = 1
v4addr := netip.MustParseAddr("1.2.3.4")
testCases := []struct {
name string
input string
wantErrMsg string
want netip.AddrPort
}{{
name: "success_ip",
input: v4addr.String(),
wantErrMsg: "",
want: netip.AddrPortFrom(v4addr, defaultPort),
}, {
name: "success_ip_port",
input: netutil.JoinHostPort(v4addr.String(), 5),
wantErrMsg: "",
want: netip.AddrPortFrom(v4addr, 5),
}, {
name: "success_url",
input: (&url.URL{
Scheme: "tcp",
Host: v4addr.String(),
}).String(),
wantErrMsg: "",
want: netip.AddrPortFrom(v4addr, defaultPort),
}, {
name: "success_url_port",
input: (&url.URL{
Scheme: "tcp",
Host: netutil.JoinHostPort(v4addr.String(), 5),
}).String(),
wantErrMsg: "",
want: netip.AddrPortFrom(v4addr, 5),
}, {
name: "error_invalid_ip",
input: "256.256.256.256",
wantErrMsg: `not an ip:port
ParseAddr("256.256.256.256"): IPv4 field has value >255`,
want: netip.AddrPort{},
}, {
name: "error_invalid_port",
input: net.JoinHostPort(v4addr.String(), "-5"),
wantErrMsg: `invalid port "-5" parsing "1.2.3.4:-5"
ParseAddr("1.2.3.4:-5"): unexpected character (at ":-5")`,
want: netip.AddrPort{},
}, {
name: "error_invalid_url",
input: "tcp:://1.2.3.4",
wantErrMsg: `invalid port "//1.2.3.4" parsing "tcp:://1.2.3.4"
ParseAddr("tcp:://1.2.3.4"): each colon-separated field must have at least ` +
`one digit (at "tcp:://1.2.3.4")`,
want: netip.AddrPort{},
}, {
name: "empty",
input: "",
want: netip.AddrPort{},
wantErrMsg: `not an ip:port
ParseAddr(""): unable to parse IP`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ap, err := aghnet.ParseAddrPort(tc.input, defaultPort)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, ap)
})
}
}

View File

@@ -1,36 +0,0 @@
package aghnet
// DefaultRefreshIvl is the default period of time between refreshing cached
// addresses.
// const DefaultRefreshIvl = 5 * time.Minute
// HostGenFunc is the signature for functions generating fake hostnames. The
// implementation must be safe for concurrent use.
type HostGenFunc func() (host string)
// SystemResolvers helps to work with local resolvers' addresses provided by OS.
type SystemResolvers interface {
// Get returns the slice of local resolvers' addresses. It must be safe for
// concurrent use.
Get() (rs []string)
// refresh refreshes the local resolvers' addresses cache. It must be safe
// for concurrent use.
refresh() (err error)
}
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
// defined by refreshIvl. It disables auto-refreshing if refreshIvl is 0. If
// nil is passed for hostGenFunc, the default generator will be used.
func NewSystemResolvers(
hostGenFunc HostGenFunc,
) (sr SystemResolvers, err error) {
sr = newSystemResolvers(hostGenFunc)
// Fill cache.
err = sr.refresh()
if err != nil {
return nil, err
}
return sr, nil
}

View File

@@ -1,146 +0,0 @@
//go:build !windows
package aghnet
import (
"context"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
)
// defaultHostGen is the default method of generating host for Refresh.
func defaultHostGen() (host string) {
// TODO(e.burkov): Use strings.Builder.
return fmt.Sprintf("test%d.org", time.Now().UnixNano())
}
// systemResolvers is a default implementation of SystemResolvers interface.
type systemResolvers struct {
// addrsLock protects addrs.
addrsLock sync.RWMutex
// addrs is the set that contains cached local resolvers' addresses.
addrs *stringutil.Set
// resolver is used to fetch the resolvers' addresses.
resolver *net.Resolver
// hostGenFunc generates hosts to resolve.
hostGenFunc HostGenFunc
}
const (
// errBadAddrPassed is returned when dialFunc can't parse an IP address.
errBadAddrPassed errors.Error = "the passed string is not a valid IP address"
// errFakeDial is an error which dialFunc is expected to return.
errFakeDial errors.Error = "this error signals the successful dialFunc work"
// errUnexpectedHostFormat is returned by validateDialedHost when the host has
// more than one percent sign.
errUnexpectedHostFormat errors.Error = "unexpected host format"
)
// refresh implements the SystemResolvers interface for *systemResolvers.
func (sr *systemResolvers) refresh() (err error) {
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
dnserr := &net.DNSError{}
if errors.As(err, &dnserr) && dnserr.Err == errFakeDial.Error() {
return nil
}
return err
}
func newSystemResolvers(hostGenFunc HostGenFunc) (sr SystemResolvers) {
if hostGenFunc == nil {
hostGenFunc = defaultHostGen
}
s := &systemResolvers{
resolver: &net.Resolver{
PreferGo: true,
},
hostGenFunc: hostGenFunc,
addrs: stringutil.NewSet(),
}
s.resolver.Dial = s.dialFunc
return s
}
// validateDialedHost validated the host used by resolvers in dialFunc.
func validateDialedHost(host string) (err error) {
defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()
parts := strings.Split(host, "%")
switch len(parts) {
case 1:
// host
case 2:
// Remove the zone and check the IP address part.
host = parts[0]
default:
return errUnexpectedHostFormat
}
if _, err = netutil.ParseIP(host); err != nil {
return errBadAddrPassed
}
return nil
}
// dockerEmbeddedDNS is the address of Docker's embedded DNS server.
//
// See
// https://github.com/moby/moby/blob/v1.12.0/docs/userguide/networking/dockernetworks.md.
const dockerEmbeddedDNS = "127.0.0.11"
// dialFunc gets the resolver's address and puts it into internal cache.
func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
// Just validate the passed address is a valid IP.
var host string
host, err = netutil.SplitHost(address)
if err != nil {
// TODO(e.burkov): Maybe use a structured errBadAddrPassed to
// allow unwrapping of the real error.
return nil, fmt.Errorf("%s: %w", err, errBadAddrPassed)
}
// Exclude Docker's embedded DNS server, as it may cause recursion if
// the container is set as the host system's default DNS server.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/3064.
//
// TODO(a.garipov): Perhaps only do this when we are in the container?
// Maybe use an environment variable?
if host == dockerEmbeddedDNS {
return nil, errFakeDial
}
err = validateDialedHost(host)
if err != nil {
return nil, fmt.Errorf("validating dialed host: %w", err)
}
sr.addrsLock.Lock()
defer sr.addrsLock.Unlock()
sr.addrs.Add(host)
return nil, errFakeDial
}
func (sr *systemResolvers) Get() (rs []string) {
sr.addrsLock.RLock()
defer sr.addrsLock.RUnlock()
return sr.addrs.Values()
}

View File

@@ -1,81 +0,0 @@
//go:build !windows
package aghnet
import (
"context"
"testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func createTestSystemResolversImpl(
t *testing.T,
hostGenFunc HostGenFunc,
) (imp *systemResolvers) {
t.Helper()
sr := createTestSystemResolvers(t, hostGenFunc)
return testutil.RequireTypeAssert[*systemResolvers](t, sr)
}
func TestSystemResolvers_Refresh(t *testing.T) {
t.Run("expected_error", func(t *testing.T) {
sr := createTestSystemResolvers(t, nil)
assert.NoError(t, sr.refresh())
})
t.Run("unexpected_error", func(t *testing.T) {
_, err := NewSystemResolvers(func() string {
return "127.0.0.1::123"
})
assert.Error(t, err)
})
}
func TestSystemResolvers_DialFunc(t *testing.T) {
imp := createTestSystemResolversImpl(t, nil)
testCases := []struct {
want error
name string
address string
}{{
want: errFakeDial,
name: "valid_ipv4",
address: "127.0.0.1",
}, {
want: errFakeDial,
name: "valid_ipv6_port",
address: "[::1]:53",
}, {
want: errFakeDial,
name: "valid_ipv6_zone_port",
address: "[::1%lo0]:53",
}, {
want: errBadAddrPassed,
name: "invalid_split_host",
address: "127.0.0.1::123",
}, {
want: errUnexpectedHostFormat,
name: "invalid_ipv6_zone_port",
address: "[::1%%lo0]:53",
}, {
want: errBadAddrPassed,
name: "invalid_parse_ip",
address: "not-ip",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conn, err := imp.dialFunc(context.Background(), "", tc.address)
require.Nil(t, conn)
assert.ErrorIs(t, err, tc.want)
})
}
}

View File

@@ -1,37 +0,0 @@
package aghnet
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func createTestSystemResolvers(
t *testing.T,
hostGenFunc HostGenFunc,
) (sr SystemResolvers) {
t.Helper()
var err error
sr, err = NewSystemResolvers(hostGenFunc)
require.NoError(t, err)
require.NotNil(t, sr)
return sr
}
func TestSystemResolvers_Get(t *testing.T) {
sr := createTestSystemResolvers(t, nil)
var rs []string
require.NotPanics(t, func() {
rs = sr.Get()
})
assert.NotEmpty(t, rs)
}
// TODO(e.burkov): Write tests for refreshWithTicker.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.

View File

@@ -1,163 +0,0 @@
//go:build windows
package aghnet
import (
"bufio"
"fmt"
"io"
"net"
"os/exec"
"strings"
"sync"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// systemResolvers implementation differs for Windows since Go's resolver
// doesn't work there.
//
// See https://github.com/golang/go/issues/33097.
type systemResolvers struct {
// addrs is the slice of cached local resolvers' addresses.
addrs []string
addrsLock sync.RWMutex
}
func newSystemResolvers(_ HostGenFunc) (sr SystemResolvers) {
return &systemResolvers{}
}
func (sr *systemResolvers) Get() (rs []string) {
sr.addrsLock.RLock()
defer sr.addrsLock.RUnlock()
addrs := sr.addrs
rs = make([]string, len(addrs))
copy(rs, addrs)
return rs
}
// writeExit writes "exit" to w and closes it. It is supposed to be run in
// a goroutine.
func writeExit(w io.WriteCloser) {
defer log.OnPanic("systemResolvers: writeExit")
defer func() {
derr := w.Close()
if derr != nil {
log.Error("systemResolvers: writeExit: closing: %s", derr)
}
}()
_, err := io.WriteString(w, "exit")
if err != nil {
log.Error("systemResolvers: writeExit: writing: %s", err)
}
}
// scanAddrs scans the DNS addresses from nslookup's output. The expected
// output of nslookup looks like this:
//
// Default Server: 192-168-1-1.qualified.domain.ru
// Address: 192.168.1.1
func scanAddrs(s *bufio.Scanner) (addrs []string) {
for s.Scan() {
line := strings.TrimSpace(s.Text())
fields := strings.Fields(line)
if len(fields) != 2 || fields[0] != "Address:" {
continue
}
// If the address contains port then it is separated with '#'.
ipPort := strings.Split(fields[1], "#")
if len(ipPort) == 0 {
continue
}
addr := ipPort[0]
if net.ParseIP(addr) == nil {
log.Debug("systemResolvers: %q is not a valid ip", addr)
continue
}
addrs = append(addrs, addr)
}
return addrs
}
// getAddrs gets local resolvers' addresses from OS in a special Windows way.
//
// TODO(e.burkov): This whole function needs more detailed research on getting
// local resolvers addresses on Windows. We execute the external command for
// now that is not the most accurate way.
func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
var cmdPath string
cmdPath, err = exec.LookPath("nslookup.exe")
if err != nil {
return nil, fmt.Errorf("looking up cmd path: %w", err)
}
cmd := exec.Command(cmdPath)
var stdin io.WriteCloser
stdin, err = cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("getting the command's stdin pipe: %w", err)
}
var stdout io.ReadCloser
stdout, err = cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("getting the command's stdout pipe: %w", err)
}
go writeExit(stdin)
err = cmd.Start()
if err != nil {
return nil, fmt.Errorf("start command executing: %w", err)
}
s := bufio.NewScanner(stdout)
addrs = scanAddrs(s)
err = cmd.Wait()
if err != nil {
return nil, fmt.Errorf("executing the command: %w", err)
}
err = s.Err()
if err != nil {
return nil, fmt.Errorf("scanning output: %w", err)
}
// Don't close StdoutPipe since Wait do it for us in ¿most? cases.
//
// See go doc os/exec.Cmd.StdoutPipe.
return addrs, nil
}
func (sr *systemResolvers) refresh() (err error) {
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
got, err := sr.getAddrs()
if err != nil {
return fmt.Errorf("can't get addresses: %w", err)
}
if len(got) == 0 {
return nil
}
sr.addrsLock.Lock()
defer sr.addrsLock.Unlock()
sr.addrs = got
return nil
}

View File

@@ -1,7 +0,0 @@
//go:build windows
package aghnet
// TODO(e.burkov): Write tests for Windows implementation.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.

View File

@@ -4,7 +4,7 @@ import (
"bytes"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/ioutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -72,11 +72,10 @@ func TestLargestLabeled(t *testing.T) {
}
t.Run("scanner_fail", func(t *testing.T) {
lr, err := aghio.LimitReader(bytes.NewReader([]byte{1, 2, 3}), 0)
require.NoError(t, err)
lr := ioutil.LimitReader(bytes.NewReader([]byte{1, 2, 3}), 0)
target := &aghio.LimitReachedError{}
_, _, err = parsePSOutput(lr, "", nil)
target := &ioutil.LimitError{}
_, _, err := parsePSOutput(lr, "", nil)
require.ErrorAs(t, err, &target)
assert.EqualValues(t, 0, target.Limit)

View File

@@ -4,10 +4,14 @@ package aghtest
import (
"crypto/sha256"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"testing"
"github.com/AdguardTeam/golibs/log"
"github.com/stretchr/testify/require"
)
const (
@@ -51,3 +55,19 @@ func HostToIPs(host string) (ipv4, ipv6 netip.Addr) {
return netip.AddrFrom4([4]byte(hash[:4])), netip.AddrFrom16([16]byte(hash[4:20]))
}
// StartHTTPServer is a helper that starts the HTTP server, which is configured
// to return data on every request, and returns the client and server URL.
func StartHTTPServer(t testing.TB, data []byte) (c *http.Client, u *url.URL) {
t.Helper()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write(data)
}))
t.Cleanup(srv.Close)
u, err := url.Parse(srv.URL)
require.NoError(t, err)
return srv.Client(), u
}

View File

@@ -57,6 +57,9 @@ type DHCPServer interface {
// RemoveStaticLease - remove a static lease
RemoveStaticLease(l *Lease) (err error)
// UpdateStaticLease updates IP, hostname of the lease.
UpdateStaticLease(l *Lease) (err error)
// FindMACbyIP returns a MAC address by the IP address of its lease, if
// there is one.
FindMACbyIP(ip netip.Addr) (mac net.HardwareAddr)

View File

@@ -5,6 +5,7 @@ package dhcpd
import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/netip"
@@ -290,12 +291,12 @@ func (s *server) handleDHCPSetConfigV6(
func (s *server) createServers(conf *dhcpServerConfigJSON) (srv4, srv6 DHCPServer, err error) {
srv4, v4Enabled, err := s.handleDHCPSetConfigV4(conf)
if err != nil {
return nil, nil, fmt.Errorf("bad dhcpv4 configuration: %s", err)
return nil, nil, fmt.Errorf("bad dhcpv4 configuration: %w", err)
}
srv6, v6Enabled, err := s.handleDHCPSetConfigV6(conf)
if err != nil {
return nil, nil, fmt.Errorf("bad dhcpv6 configuration: %s", err)
return nil, nil, fmt.Errorf("bad dhcpv6 configuration: %w", err)
}
if conf.Enabled == aghalg.NBTrue && !v4Enabled && !v6Enabled {
@@ -424,7 +425,7 @@ func newNetInterfaceJSON(iface net.Interface) (out *netInterfaceJSON, err error)
addrs, err := iface.Addrs()
if err != nil {
return nil, fmt.Errorf(
"failed to get addresses for interface %s: %s",
"failed to get addresses for interface %s: %w",
iface.Name,
err,
)
@@ -590,82 +591,78 @@ func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) {
}
}
func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
// parseLease parses a lease from r. If there is no error returns DHCPServer
// and *Lease. r must be non-nil.
func (s *server) parseLease(r io.Reader) (srv DHCPServer, lease *Lease, err error) {
l := &leaseStatic{}
err := json.NewDecoder(r.Body).Decode(l)
err = json.NewDecoder(r).Decode(l)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
return nil, nil, fmt.Errorf("decoding json: %w", err)
}
if !l.IP.IsValid() {
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
return
return nil, nil, errors.Error("invalid ip")
}
l.IP = l.IP.Unmap()
var srv DHCPServer
if l.IP.Is4() {
lease, err = l.toLease()
if err != nil {
return nil, nil, fmt.Errorf("parsing: %w", err)
}
if lease.IP.Is4() {
srv = s.srv4
} else {
srv = s.srv6
}
lease, err := l.toLease()
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "parsing: %s", err)
return srv, lease, nil
}
return
}
err = srv.AddStaticLease(lease)
// handleDHCPAddStaticLease is the handler for the POST
// /control/dhcp/add_static_lease HTTP API.
func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
srv, lease, err := s.parseLease(r.Body)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
if err = srv.AddStaticLease(lease); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
}
}
// handleDHCPRemoveStaticLease is the handler for the POST
// /control/dhcp/remove_static_lease HTTP API.
func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
l := &leaseStatic{}
err := json.NewDecoder(r.Body).Decode(l)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
if !l.IP.IsValid() {
aghhttp.Error(r, w, http.StatusBadRequest, "invalid IP")
return
}
l.IP = l.IP.Unmap()
var srv DHCPServer
if l.IP.Is4() {
srv = s.srv4
} else {
srv = s.srv6
}
lease, err := l.toLease()
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "parsing: %s", err)
return
}
err = srv.RemoveStaticLease(lease)
srv, lease, err := s.parseLease(r.Body)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
if err = srv.RemoveStaticLease(lease); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
}
}
// handleDHCPUpdateStaticLease is the handler for the POST
// /control/dhcp/update_static_lease HTTP API.
func (s *server) handleDHCPUpdateStaticLease(w http.ResponseWriter, r *http.Request) {
srv, lease, err := s.parseLease(r.Body)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
if err = srv.UpdateStaticLease(lease); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
}
}
func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
@@ -729,6 +726,7 @@ func (s *server) registerHandlers() {
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.handleDHCPFindActiveServer)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.handleDHCPAddStaticLease)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.handleDHCPRemoveStaticLease)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/update_static_lease", s.handleDHCPUpdateStaticLease)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.handleReset)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.handleResetLeases)
}

View File

@@ -0,0 +1,319 @@
//go:build darwin || freebsd || linux || openbsd
package dhcpd
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// defaultResponse is a helper that returns the response with default
// configuration.
func defaultResponse() *dhcpStatusResponse {
conf4 := defaultV4ServerConf()
conf4.LeaseDuration = 86400
resp := &dhcpStatusResponse{
V4: *conf4,
V6: V6ServerConf{},
Leases: []*leaseDynamic{},
StaticLeases: []*leaseStatic{},
Enabled: true,
}
return resp
}
// handleLease is the helper function that calls handler with provided static
// lease as body and returns modified response recorder.
func handleLease(t *testing.T, lease *leaseStatic, handler http.HandlerFunc) (w *httptest.ResponseRecorder) {
t.Helper()
w = httptest.NewRecorder()
b := &bytes.Buffer{}
err := json.NewEncoder(b).Encode(lease)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", b)
require.NoError(t, err)
handler(w, r)
return w
}
// checkStatus is a helper that asserts the response of
// [*server.handleDHCPStatus].
func checkStatus(t *testing.T, s *server, want *dhcpStatusResponse) {
w := httptest.NewRecorder()
b := &bytes.Buffer{}
err := json.NewEncoder(b).Encode(&want)
require.NoError(t, err)
r, err := http.NewRequest(http.MethodPost, "", b)
require.NoError(t, err)
s.handleDHCPStatus(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, b.String(), w.Body.String())
}
func TestServer_handleDHCPStatus(t *testing.T) {
const (
staticName = "static-client"
staticMAC = "aa:aa:aa:aa:aa:aa"
)
staticIP := netip.MustParseAddr("192.168.10.10")
staticLease := &leaseStatic{
HWAddr: staticMAC,
IP: staticIP,
Hostname: staticName,
}
s, err := Create(&ServerConfig{
Enabled: true,
Conf4: *defaultV4ServerConf(),
DataDir: t.TempDir(),
ConfigModified: func() {},
})
require.NoError(t, err)
ok := t.Run("status", func(t *testing.T) {
resp := defaultResponse()
checkStatus(t, s, resp)
})
require.True(t, ok)
ok = t.Run("add_static_lease", func(t *testing.T) {
w := handleLease(t, staticLease, s.handleDHCPAddStaticLease)
assert.Equal(t, http.StatusOK, w.Code)
resp := defaultResponse()
resp.StaticLeases = []*leaseStatic{staticLease}
checkStatus(t, s, resp)
})
require.True(t, ok)
ok = t.Run("add_invalid_lease", func(t *testing.T) {
w := handleLease(t, staticLease, s.handleDHCPAddStaticLease)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
require.True(t, ok)
ok = t.Run("remove_static_lease", func(t *testing.T) {
w := handleLease(t, staticLease, s.handleDHCPRemoveStaticLease)
assert.Equal(t, http.StatusOK, w.Code)
resp := defaultResponse()
checkStatus(t, s, resp)
})
require.True(t, ok)
ok = t.Run("set_config", func(t *testing.T) {
w := httptest.NewRecorder()
resp := defaultResponse()
resp.Enabled = false
b := &bytes.Buffer{}
err = json.NewEncoder(b).Encode(&resp)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", b)
require.NoError(t, err)
s.handleDHCPSetConfig(w, r)
assert.Equal(t, http.StatusOK, w.Code)
checkStatus(t, s, resp)
})
require.True(t, ok)
}
func TestServer_HandleUpdateStaticLease(t *testing.T) {
const (
leaseV4Name = "static-client-v4"
leaseV4MAC = "44:44:44:44:44:44"
leaseV6Name = "static-client-v6"
leaseV6MAC = "66:66:66:66:66:66"
)
leaseV4IP := netip.MustParseAddr("192.168.10.10")
leaseV6IP := netip.MustParseAddr("2001::6")
const (
leaseV4Pos = iota
leaseV6Pos
)
leases := []*leaseStatic{
leaseV4Pos: {
HWAddr: leaseV4MAC,
IP: leaseV4IP,
Hostname: leaseV4Name,
},
leaseV6Pos: {
HWAddr: leaseV6MAC,
IP: leaseV6IP,
Hostname: leaseV6Name,
},
}
s, err := Create(&ServerConfig{
Enabled: true,
Conf4: *defaultV4ServerConf(),
Conf6: V6ServerConf{},
DataDir: t.TempDir(),
ConfigModified: func() {},
})
require.NoError(t, err)
for _, l := range leases {
w := handleLease(t, l, s.handleDHCPAddStaticLease)
assert.Equal(t, http.StatusOK, w.Code)
}
testCases := []struct {
name string
pos int
lease *leaseStatic
}{{
name: "update_v4_name",
pos: leaseV4Pos,
lease: &leaseStatic{
HWAddr: leaseV4MAC,
IP: leaseV4IP,
Hostname: "updated-client-v4",
},
}, {
name: "update_v4_ip",
pos: leaseV4Pos,
lease: &leaseStatic{
HWAddr: leaseV4MAC,
IP: netip.MustParseAddr("192.168.10.200"),
Hostname: "updated-client-v4",
},
}, {
name: "update_v6_name",
pos: leaseV6Pos,
lease: &leaseStatic{
HWAddr: leaseV6MAC,
IP: leaseV6IP,
Hostname: "updated-client-v6",
},
}, {
name: "update_v6_ip",
pos: leaseV6Pos,
lease: &leaseStatic{
HWAddr: leaseV6MAC,
IP: netip.MustParseAddr("2001::666"),
Hostname: "updated-client-v6",
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
w := handleLease(t, tc.lease, s.handleDHCPUpdateStaticLease)
assert.Equal(t, http.StatusOK, w.Code)
resp := defaultResponse()
leases[tc.pos] = tc.lease
resp.StaticLeases = leases
checkStatus(t, s, resp)
})
}
}
func TestServer_HandleUpdateStaticLease_validation(t *testing.T) {
const (
leaseV4Name = "static-client-v4"
leaseV4MAC = "44:44:44:44:44:44"
anotherV4Name = "another-client-v4"
anotherV4MAC = "55:55:55:55:55:55"
)
leaseV4IP := netip.MustParseAddr("192.168.10.10")
anotherV4IP := netip.MustParseAddr("192.168.10.20")
leases := []*leaseStatic{{
HWAddr: leaseV4MAC,
IP: leaseV4IP,
Hostname: leaseV4Name,
}, {
HWAddr: anotherV4MAC,
IP: anotherV4IP,
Hostname: anotherV4Name,
}}
s, err := Create(&ServerConfig{
Enabled: true,
Conf4: *defaultV4ServerConf(),
Conf6: V6ServerConf{},
DataDir: t.TempDir(),
ConfigModified: func() {},
})
require.NoError(t, err)
for _, l := range leases {
w := handleLease(t, l, s.handleDHCPAddStaticLease)
assert.Equal(t, http.StatusOK, w.Code)
}
testCases := []struct {
lease *leaseStatic
name string
want string
}{{
name: "v4_unknown_mac",
lease: &leaseStatic{
HWAddr: "aa:aa:aa:aa:aa:aa",
IP: leaseV4IP,
Hostname: leaseV4Name,
},
want: "dhcpv4: updating static lease: can't find lease aa:aa:aa:aa:aa:aa\n",
}, {
name: "update_v4_same_ip",
lease: &leaseStatic{
HWAddr: leaseV4MAC,
IP: anotherV4IP,
Hostname: leaseV4Name,
},
want: "dhcpv4: updating static lease: ip address is not unique\n",
}, {
name: "update_v4_same_name",
lease: &leaseStatic{
HWAddr: leaseV4MAC,
IP: leaseV4IP,
Hostname: anotherV4Name,
},
want: "dhcpv4: updating static lease: hostname is not unique\n",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
w := handleLease(t, tc.lease, s.handleDHCPUpdateStaticLease)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Equal(t, tc.want, w.Body.String())
})
}
}

View File

@@ -1,159 +0,0 @@
//go:build darwin || freebsd || linux || openbsd
package dhcpd
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServer_handleDHCPStatus(t *testing.T) {
const (
staticName = "static-client"
staticMAC = "aa:aa:aa:aa:aa:aa"
)
staticIP := netip.MustParseAddr("192.168.10.10")
staticLease := &leaseStatic{
HWAddr: staticMAC,
IP: staticIP,
Hostname: staticName,
}
s, err := Create(&ServerConfig{
Enabled: true,
Conf4: *defaultV4ServerConf(),
DataDir: t.TempDir(),
ConfigModified: func() {},
})
require.NoError(t, err)
// checkStatus is a helper that asserts the response of
// [*server.handleDHCPStatus].
checkStatus := func(t *testing.T, want *dhcpStatusResponse) {
w := httptest.NewRecorder()
var req *http.Request
req, err = http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)
b := &bytes.Buffer{}
err = json.NewEncoder(b).Encode(&want)
require.NoError(t, err)
s.handleDHCPStatus(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, b.String(), w.Body.String())
}
// defaultResponse is a helper that returns the response with default
// configuration.
defaultResponse := func() *dhcpStatusResponse {
conf4 := defaultV4ServerConf()
conf4.LeaseDuration = 86400
resp := &dhcpStatusResponse{
V4: *conf4,
V6: V6ServerConf{},
Leases: []*leaseDynamic{},
StaticLeases: []*leaseStatic{},
Enabled: true,
}
return resp
}
ok := t.Run("status", func(t *testing.T) {
resp := defaultResponse()
checkStatus(t, resp)
})
require.True(t, ok)
ok = t.Run("add_static_lease", func(t *testing.T) {
w := httptest.NewRecorder()
b := &bytes.Buffer{}
err = json.NewEncoder(b).Encode(staticLease)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", b)
require.NoError(t, err)
s.handleDHCPAddStaticLease(w, r)
assert.Equal(t, http.StatusOK, w.Code)
resp := defaultResponse()
resp.StaticLeases = []*leaseStatic{staticLease}
checkStatus(t, resp)
})
require.True(t, ok)
ok = t.Run("add_invalid_lease", func(t *testing.T) {
w := httptest.NewRecorder()
b := &bytes.Buffer{}
err = json.NewEncoder(b).Encode(&leaseStatic{})
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", b)
require.NoError(t, err)
s.handleDHCPAddStaticLease(w, r)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
require.True(t, ok)
ok = t.Run("remove_static_lease", func(t *testing.T) {
w := httptest.NewRecorder()
b := &bytes.Buffer{}
err = json.NewEncoder(b).Encode(staticLease)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", b)
require.NoError(t, err)
s.handleDHCPRemoveStaticLease(w, r)
assert.Equal(t, http.StatusOK, w.Code)
resp := defaultResponse()
checkStatus(t, resp)
})
require.True(t, ok)
ok = t.Run("set_config", func(t *testing.T) {
w := httptest.NewRecorder()
resp := defaultResponse()
resp.Enabled = false
b := &bytes.Buffer{}
err = json.NewEncoder(b).Encode(&resp)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", b)
require.NoError(t, err)
s.handleDHCPSetConfig(w, r)
assert.Equal(t, http.StatusOK, w.Code)
checkStatus(t, resp)
})
require.True(t, ok)
}

View File

@@ -43,6 +43,7 @@ func (s *server) registerHandlers() {
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/update_static_lease", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.notImplemented)
}

View File

@@ -19,6 +19,7 @@ func (winServer) GetLeases(_ GetLeasesFlags) (leases []*Lease) { return nil }
func (winServer) getLeasesRef() []*Lease { return nil }
func (winServer) AddStaticLease(_ *Lease) (err error) { return nil }
func (winServer) RemoveStaticLease(_ *Lease) (err error) { return nil }
func (winServer) UpdateStaticLease(_ *Lease) (err error) { return nil }
func (winServer) FindMACbyIP(_ netip.Addr) (mac net.HardwareAddr) { return nil }
func (winServer) WriteDiskConfig4(_ *V4ServerConf) {}
func (winServer) WriteDiskConfig6(_ *V6ServerConf) {}

View File

@@ -148,6 +148,9 @@ func (s *v4Server) ResetLeases(leases []*Lease) (err error) {
return nil
}
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
s.leasedOffsets = newBitSet()
s.hostsIndex = make(map[string]*Lease, len(leases))
s.ipIndex = make(map[netip.Addr]*Lease, len(leases))
@@ -182,16 +185,13 @@ func (s *v4Server) isBlocklisted(l *Lease) (ok bool) {
return false
}
ok = true
for _, b := range l.HWAddr {
if b != 0 {
ok = false
break
return false
}
}
return ok
return true
}
// GetLeases returns the list of current DHCP leases. It is safe for concurrent
@@ -309,9 +309,15 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
return nil
}
// ErrDupHostname is returned by addLease when the added lease has a not empty
// non-unique hostname.
const ErrDupHostname = errors.Error("hostname is not unique")
const (
// ErrDupHostname is returned by addLease, validateStaticLease when the
// modified lease has a not empty non-unique hostname.
ErrDupHostname = errors.Error("hostname is not unique")
// ErrDupIP is returned by addLease, validateStaticLease when the modified
// lease has a non-unique IP address.
ErrDupIP = errors.Error("ip address is not unique")
)
// addLease adds a dynamic or static lease.
func (s *v4Server) addLease(l *Lease) (err error) {
@@ -428,6 +434,81 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
return nil
}
// UpdateStaticLease updates IP, hostname of the static lease.
func (s *v4Server) UpdateStaticLease(l *Lease) (err error) {
defer func() {
if err != nil {
err = errors.Annotate(err, "dhcpv4: updating static lease: %w")
return
}
s.conf.notify(LeaseChangedDBStore)
s.conf.notify(LeaseChangedRemovedStatic)
}()
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
found := s.findLease(l.HWAddr)
if found == nil {
return fmt.Errorf("can't find lease %s", l.HWAddr)
}
err = s.validateStaticLease(l)
if err != nil {
return err
}
err = s.rmLease(found)
if err != nil {
return fmt.Errorf("removing previous lease for %s (%s): %w", l.IP, l.HWAddr, err)
}
err = s.addLease(l)
if err != nil {
return fmt.Errorf("adding updated static lease for %s (%s): %w", l.IP, l.HWAddr, err)
}
return nil
}
// validateStaticLease returns an error if the static lease is invalid.
func (s *v4Server) validateStaticLease(l *Lease) (err error) {
hostname, err := normalizeHostname(l.Hostname)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
err = netutil.ValidateHostname(hostname)
if err != nil {
return fmt.Errorf("validating hostname: %w", err)
}
dup, ok := s.hostsIndex[hostname]
if ok && !bytes.Equal(dup.HWAddr, l.HWAddr) {
return ErrDupHostname
}
dup, ok = s.ipIndex[l.IP]
if ok && !bytes.Equal(dup.HWAddr, l.HWAddr) {
return ErrDupIP
}
l.Hostname = hostname
if gwIP := s.conf.GatewayIP; gwIP == l.IP {
return fmt.Errorf("can't assign the gateway IP %q to the lease", gwIP)
}
if sn := s.conf.subnet; !sn.Contains(l.IP) {
return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP)
}
return nil
}
// updateStaticLease safe removes dynamic lease with the same properties and
// then adds a static lease l.
func (s *v4Server) updateStaticLease(l *Lease) (err error) {

View File

@@ -90,6 +90,9 @@ func (s *v6Server) IPByHost(host string) (ip netip.Addr) {
func (s *v6Server) ResetLeases(leases []*Lease) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
s.leases = nil
for _, l := range leases {
ip := net.IP(l.IP.AsSlice())
@@ -232,6 +235,37 @@ func (s *v6Server) AddStaticLease(l *Lease) (err error) {
return nil
}
// UpdateStaticLease updates IP, hostname of the static lease.
func (s *v6Server) UpdateStaticLease(l *Lease) (err error) {
defer func() {
if err != nil {
err = errors.Annotate(err, "dhcpv6: updating static lease: %w")
return
}
s.conf.notify(LeaseChangedDBStore)
s.conf.notify(LeaseChangedRemovedStatic)
}()
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
found := s.findLease(l.HWAddr)
if found == nil {
return fmt.Errorf("can't find lease %s", l.HWAddr)
}
err = s.rmLease(found)
if err != nil {
return fmt.Errorf("removing previous lease for %s (%s): %w", l.IP, l.HWAddr, err)
}
s.addLease(l)
return nil
}
// RemoveStaticLease removes a static lease. It is safe for concurrent use.
func (s *v6Server) RemoveStaticLease(l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
@@ -283,16 +317,14 @@ func (s *v6Server) rmLease(lease *Lease) (err error) {
return fmt.Errorf("lease not found")
}
// Find lease by MAC
func (s *v6Server) findLease(mac net.HardwareAddr) *Lease {
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
// Find lease by MAC.
func (s *v6Server) findLease(mac net.HardwareAddr) (lease *Lease) {
for i := range s.leases {
if bytes.Equal(mac, s.leases[i].HWAddr) {
return s.leases[i]
}
}
return nil
}
@@ -474,7 +506,14 @@ func (s *v6Server) process(msg *dhcpv6.Message, req, resp dhcpv6.DHCPv6) bool {
return false
}
lease := s.findLease(mac)
var lease *Lease
func() {
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
lease = s.findLease(mac)
}()
if lease == nil {
log.Debug("dhcpv6: no lease for: %s", mac)

View File

@@ -12,10 +12,12 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
@@ -392,6 +394,122 @@ func (s *Server) prepareIpsetListSettings() (err error) {
return s.ipset.init(ipsets)
}
// collectListenAddr adds addrPort to addrs. It also adds its port to
// unspecPorts if its address is unspecified.
func collectListenAddr(
addrPort netip.AddrPort,
addrs map[netip.AddrPort]unit,
unspecPorts map[uint16]unit,
) {
if addrPort == (netip.AddrPort{}) {
return
}
addrs[addrPort] = unit{}
if addrPort.Addr().IsUnspecified() {
unspecPorts[addrPort.Port()] = unit{}
}
}
// collectDNSAddrs returns configured set of listening addresses. It also
// returns a set of ports of each unspecified listening address.
func (conf *ServerConfig) collectDNSAddrs() (
addrs map[netip.AddrPort]unit,
unspecPorts map[uint16]unit,
) {
// TODO(e.burkov): Perhaps, we shouldn't allocate as much memory, since the
// TCP and UDP listening addresses are currently the same.
addrs = make(map[netip.AddrPort]unit, len(conf.TCPListenAddrs)+len(conf.UDPListenAddrs))
unspecPorts = map[uint16]unit{}
for _, laddr := range conf.TCPListenAddrs {
collectListenAddr(laddr.AddrPort(), addrs, unspecPorts)
}
for _, laddr := range conf.UDPListenAddrs {
collectListenAddr(laddr.AddrPort(), addrs, unspecPorts)
}
return addrs, unspecPorts
}
// defaultPlainDNSPort is the default port for plain DNS.
const defaultPlainDNSPort uint16 = 53
// addrPortMatcher is a function that matches an IP address with port.
type addrPortMatcher func(addr netip.AddrPort) (ok bool)
// filterOut filters out all the upstreams that match um. It returns all the
// closing errors joined.
func (m addrPortMatcher) filterOut(upsConf *proxy.UpstreamConfig) (err error) {
var errs []error
delFunc := func(u upstream.Upstream) (ok bool) {
// TODO(e.burkov): We should probably consider the protocol of u to
// only filter out the listening addresses of the same protocol.
addr, parseErr := aghnet.ParseAddrPort(u.Address(), defaultPlainDNSPort)
if parseErr != nil || !m(addr) {
// Don't filter out the upstream if it either cannot be parsed, or
// does not match um.
return false
}
errs = append(errs, u.Close())
return true
}
upsConf.Upstreams = slices.DeleteFunc(upsConf.Upstreams, delFunc)
for d, ups := range upsConf.DomainReservedUpstreams {
upsConf.DomainReservedUpstreams[d] = slices.DeleteFunc(ups, delFunc)
}
for d, ups := range upsConf.SpecifiedDomainUpstreams {
upsConf.SpecifiedDomainUpstreams[d] = slices.DeleteFunc(ups, delFunc)
}
return errors.Join(errs...)
}
// ourAddrsMatcher returns a matcher that matches all the configured listening
// addresses.
func (conf *ServerConfig) ourAddrsMatcher() (m addrPortMatcher, err error) {
addrs, unspecPorts := conf.collectDNSAddrs()
if len(addrs) == 0 {
log.Debug("dnsforward: no listen addresses")
// Match no addresses.
return func(_ netip.AddrPort) (ok bool) { return false }, nil
}
if len(unspecPorts) == 0 {
log.Debug("dnsforward: filtering out addresses %s", addrs)
m = func(a netip.AddrPort) (ok bool) {
_, ok = addrs[a]
return ok
}
} else {
var ifaceAddrs []netip.Addr
ifaceAddrs, err = aghnet.CollectAllIfacesAddrs()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
log.Debug("dnsforward: filtering out addresses %s on ports %d", ifaceAddrs, unspecPorts)
m = func(a netip.AddrPort) (ok bool) {
if _, ok = unspecPorts[a.Port()]; ok {
return slices.Contains(ifaceAddrs, a.Addr())
}
return false
}
}
return m, nil
}
// prepareTLS - prepares TLS configuration for the DNS proxy
func (s *Server) prepareTLS(proxyConfig *proxy.Config) (err error) {
if len(s.conf.CertificateChainData) == 0 || len(s.conf.PrivateKeyData) == 0 {

View File

@@ -25,8 +25,10 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/sysresolv"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// DefaultTimeout is the default upstream timeout
@@ -72,6 +74,11 @@ type DHCP interface {
Enabled() (ok bool)
}
type SystemResolvers interface {
// Addrs returns the list of system resolvers' addresses.
Addrs() (addrs []netip.AddrPort)
}
// Server is the main way to start a DNS server.
//
// Example:
@@ -126,7 +133,7 @@ type Server struct {
// sysResolvers used to fetch system resolvers to use by default for private
// PTR resolving.
sysResolvers aghnet.SystemResolvers
sysResolvers SystemResolvers
// recDetector is a cache for recursive requests. It is used to detect
// and prevent recursive requests only for private upstreams.
@@ -225,9 +232,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
anonymizer: p.Anonymizer,
}
// TODO(e.burkov): Enable the refresher after the actual implementation
// passes the public testing.
s.sysResolvers, err = aghnet.NewSystemResolvers(nil)
s.sysResolvers, err = sysresolv.NewSystemResolvers(nil, defaultPlainDNSPort)
if err != nil {
return nil, fmt.Errorf("initializing system resolvers: %w", err)
}
@@ -439,73 +444,28 @@ func (s *Server) startLocked() error {
// faster than ordinary upstreams.
const defaultLocalTimeout = 1 * time.Second
// collectDNSIPAddrs returns IP addresses the server is listening on without
// port numbers. For internal use only.
func (s *Server) collectDNSIPAddrs() (addrs []string, err error) {
addrs = make([]string, len(s.conf.TCPListenAddrs)+len(s.conf.UDPListenAddrs))
var i int
var ip net.IP
for _, addr := range s.conf.TCPListenAddrs {
if addr == nil {
continue
}
if ip = addr.IP; ip.IsUnspecified() {
return aghnet.CollectAllIfacesAddrs()
}
addrs[i] = ip.String()
i++
}
for _, addr := range s.conf.UDPListenAddrs {
if addr == nil {
continue
}
if ip = addr.IP; ip.IsUnspecified() {
return aghnet.CollectAllIfacesAddrs()
}
addrs[i] = ip.String()
i++
}
return addrs[:i], nil
}
func (s *Server) filterOurDNSAddrs(addrs []string) (filtered []string, err error) {
var ourAddrs []string
ourAddrs, err = s.collectDNSIPAddrs()
if err != nil {
return nil, err
}
ourAddrsSet := stringutil.NewSet(ourAddrs...)
log.Debug("dnsforward: filtering out %s", ourAddrsSet.String())
// TODO(e.burkov): The approach of subtracting sets of strings is not
// really applicable here since in case of listening on all network
// interfaces we should check the whole interface's network to cut off
// all the loopback addresses as well.
return stringutil.FilterOut(addrs, ourAddrsSet.Has), nil
}
// setupLocalResolvers initializes the resolvers for local addresses. For
// internal use only.
func (s *Server) setupLocalResolvers() (err error) {
bootstraps := s.conf.BootstrapDNS
resolvers := s.conf.LocalPTRResolvers
if len(resolvers) == 0 {
resolvers = s.sysResolvers.Get()
bootstraps = nil
} else {
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
matcher, err := s.conf.ourAddrsMatcher()
if err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
resolvers, err = s.filterOurDNSAddrs(resolvers)
if err != nil {
return err
bootstraps := s.conf.BootstrapDNS
resolvers := s.conf.LocalPTRResolvers
filterConfig := false
if len(resolvers) == 0 {
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher)
resolvers = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
resolvers = append(resolvers, r.String())
}
} else {
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
filterConfig = true
}
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
@@ -514,13 +474,18 @@ func (s *Server) setupLocalResolvers() (err error) {
Bootstrap: bootstraps,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
})
if err != nil {
return fmt.Errorf("preparing private upstreams: %w", err)
}
if filterConfig {
if err = matcher.filterOut(uc); err != nil {
return fmt.Errorf("filtering private upstreams: %w", err)
}
}
s.localResolvers = &proxy.Proxy{
Config: proxy.Config{
UpstreamConfig: uc,

View File

@@ -65,6 +65,9 @@ type jsonDNSConfig struct {
// UpstreamMode defines the way DNS requests are constructed.
UpstreamMode *string `json:"upstream_mode"`
// BlockedResponseTTL is the TTL for blocked responses.
BlockedResponseTTL *uint32 `json:"blocked_response_ttl"`
// CacheSize in bytes.
CacheSize *uint32 `json:"cache_size"`
@@ -115,6 +118,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
bootstraps := stringutil.CloneSliceOrEmpty(s.conf.BootstrapDNS)
fallbacks := stringutil.CloneSliceOrEmpty(s.conf.FallbackDNS)
blockingMode, blockingIPv4, blockingIPv6 := s.dnsFilter.BlockingMode()
blockedResponseTTL := s.dnsFilter.BlockedResponseTTL()
ratelimit := s.conf.Ratelimit
customIP := s.conf.EDNSClientSubnet.CustomIP
@@ -138,9 +142,9 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
upstreamMode = "parallel"
}
defLocalPTRUps, err := s.filterOurDNSAddrs(s.sysResolvers.Get())
defPTRUps, err := s.defaultLocalPTRUpstreams()
if err != nil {
log.Debug("getting dns configuration: %s", err)
log.Error("dnsforward: %s", err)
}
return &jsonDNSConfig{
@@ -158,6 +162,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
EDNSCSUseCustom: &useCustom,
DNSSECEnabled: &enableDNSSEC,
DisableIPv6: &aaaaDisabled,
BlockedResponseTTL: &blockedResponseTTL,
CacheSize: &cacheSize,
CacheMinTTL: &cacheMinTTL,
CacheMaxTTL: &cacheMaxTTL,
@@ -166,11 +171,29 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
ResolveClients: &resolveClients,
UsePrivateRDNS: &usePrivateRDNS,
LocalPTRUpstreams: &localPTRUpstreams,
DefaultLocalPTRUpstreams: defLocalPTRUps,
DefaultLocalPTRUpstreams: defPTRUps,
DisabledUntil: protectionDisabledUntil,
}
}
// defaultLocalPTRUpstreams returns the list of default local PTR resolvers
// filtered of AdGuard Home's own DNS server addresses. It may appear empty.
func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) {
matcher, err := s.conf.ourAddrsMatcher()
if err != nil {
// Don't wrap the error because it's informative enough as is.
return nil, err
}
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher)
ups = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
ups = append(ups, r.String())
}
return ups, nil
}
// handleGetConfig handles requests to the GET /control/dns_info endpoint.
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
resp := s.getDNSConfig()
@@ -204,7 +227,7 @@ func (req *jsonDNSConfig) checkBootstrap() (err error) {
return errors.Error("empty")
}
if _, err = upstream.NewResolver(b, nil); err != nil {
if _, err = upstream.NewUpstreamResolver(b, nil); err != nil {
return err
}
}
@@ -321,6 +344,10 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
s.dnsFilter.SetBlockingMode(*dc.BlockingMode, dc.BlockingIPv4, dc.BlockingIPv6)
}
if dc.BlockedResponseTTL != nil {
s.dnsFilter.SetBlockedResponseTTL(*dc.BlockedResponseTTL)
}
if dc.ProtectionEnabled != nil {
s.dnsFilter.SetProtectionEnabled(*dc.ProtectionEnabled)
}

View File

@@ -28,17 +28,12 @@ import (
"github.com/stretchr/testify/require"
)
// fakeSystemResolvers is a mock aghnet.SystemResolvers implementation for
// tests.
type fakeSystemResolvers struct {
// SystemResolvers is embedded here simply to make *fakeSystemResolvers
// an aghnet.SystemResolvers without actually implementing all methods.
aghnet.SystemResolvers
}
// emptySysResolvers is an empty [SystemResolvers] implementation that always
// returns nil.
type emptySysResolvers struct{}
// Get implements the aghnet.SystemResolvers interface for *fakeSystemResolvers.
// It always returns nil.
func (fsr *fakeSystemResolvers) Get() (rs []string) {
// Addrs implements the aghnet.SystemResolvers interface for emptySysResolvers.
func (emptySysResolvers) Addrs() (addrs []netip.AddrPort) {
return nil
}
@@ -60,6 +55,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
filterConf := &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
BlockedResponseTTL: 10,
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
@@ -78,7 +74,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
ConfigModified: func() {},
}
s := createTestServer(t, filterConf, forwardConf, nil)
s.sysResolvers = &fakeSystemResolvers{}
s.sysResolvers = &emptySysResolvers{}
require.NoError(t, s.Start())
testutil.CleanupAndRequireSuccess(t, s.Stop)
@@ -137,6 +133,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
filterConf := &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
BlockedResponseTTL: 10,
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
@@ -154,7 +151,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
ConfigModified: func() {},
}
s := createTestServer(t, filterConf, forwardConf, nil)
s.sysResolvers = &fakeSystemResolvers{}
s.sysResolvers = &emptySysResolvers{}
defaultConf := s.conf
@@ -229,6 +226,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
}, {
name: "fallbacks",
wantSet: "",
}, {
name: "blocked_response_ttl",
wantSet: "",
}}
var data map[string]struct {
@@ -480,7 +480,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
hostsListener := newLocalUpstreamListener(t, 0, goodHandler)
hostsUps := (&url.URL{
Scheme: "tcp",
Host: netutil.JoinHostPort(upstreamHost, int(hostsListener.Port())),
Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()),
}).String()
hc, err := aghnet.NewHostsContainer(

View File

@@ -6,8 +6,8 @@ import (
"os"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/ipset"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
@@ -15,14 +15,14 @@ import (
// ipsetCtx is the ipset context. ipsetMgr can be nil.
type ipsetCtx struct {
ipsetMgr aghnet.IpsetManager
ipsetMgr ipset.Manager
}
// init initializes the ipset context. It is not safe for concurrent use.
//
// TODO(a.garipov): Rewrite into a simple constructor?
func (c *ipsetCtx) init(ipsetConf []string) (err error) {
c.ipsetMgr, err = aghnet.NewIpsetManager(ipsetConf)
c.ipsetMgr, err = ipset.NewManager(ipsetConf)
if errors.Is(err, os.ErrInvalid) || errors.Is(err, os.ErrPermission) {
// ipset cannot currently be initialized if the server was installed
// from Snap or when the user or the binary doesn't have the required

View File

@@ -114,3 +114,74 @@ func TestIpsetCtx_process(t *testing.T) {
assert.NoError(t, err)
})
}
func TestIpsetCtx_SkipIpsetProcessing(t *testing.T) {
req4 := createTestMessage("example.com")
resp4 := &dns.Msg{
Answer: []dns.RR{&dns.A{
A: net.IP{1, 2, 3, 4},
}},
}
m := &fakeIpsetMgr{}
ictx := &ipsetCtx{
ipsetMgr: m,
}
testCases := []struct {
dctx *dnsContext
name string
want bool
}{{
name: "basic",
want: false,
dctx: &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req4,
Res: resp4,
},
responseFromUpstream: true,
},
}, {
name: "rewrite",
want: true,
dctx: &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req4,
Res: resp4,
},
responseFromUpstream: false,
},
}, {
name: "empty_req",
want: true,
dctx: &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: nil,
Res: resp4,
},
responseFromUpstream: true,
},
}, {
name: "empty_res",
want: true,
dctx: &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req4,
Res: nil,
},
responseFromUpstream: true,
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := ictx.skipIpsetProcessing(tc.dctx)
assert.Equal(t, tc.want, got)
})
}
}

View File

@@ -20,6 +20,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -55,6 +56,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -90,6 +92,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,

View File

@@ -25,6 +25,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -62,6 +63,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -100,6 +102,7 @@
"blocking_mode": "refused",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -138,6 +141,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -176,6 +180,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -214,6 +219,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": true,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -254,6 +260,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": true,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -294,6 +301,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -332,6 +340,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": true,
"disable_ipv6": false,
@@ -370,6 +379,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -408,6 +418,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -446,6 +457,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -486,6 +498,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -526,6 +539,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -565,6 +579,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -603,6 +618,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -643,6 +659,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -686,6 +703,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -724,6 +742,7 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
@@ -766,6 +785,46 @@
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 10,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,
"upstream_mode": "",
"cache_size": 0,
"cache_ttl_min": 0,
"cache_ttl_max": 0,
"cache_optimistic": false,
"resolve_clients": false,
"use_private_ptr_resolvers": false,
"local_ptr_upstreams": [],
"edns_cs_use_custom": false,
"edns_cs_custom_ip": ""
}
},
"blocked_response_ttl": {
"req": {
"blocked_response_ttl": 11
},
"want": {
"upstream_dns": [
"8.8.8.8:53",
"8.8.4.4:53"
],
"upstream_dns_file": "",
"bootstrap_dns": [
"9.9.9.10",
"149.112.112.10",
"2620:fe::10",
"2620:fe::fe:10"
],
"fallback_dns": [],
"protection_enabled": true,
"protection_disabled_until": null,
"ratelimit": 0,
"blocking_mode": "default",
"blocking_ipv4": "",
"blocking_ipv6": "",
"blocked_response_ttl": 11,
"edns_cs_enabled": false,
"dnssec_enabled": false,
"disable_ipv6": false,

View File

@@ -69,8 +69,8 @@ func (s *Server) prepareUpstreamSettings() (err error) {
return nil
}
// prepareUpstreamConfig sets upstream configuration based on upstreams and
// configuration of s.
// prepareUpstreamConfig returns the upstream configuration based on upstreams
// and configuration of s.
func (s *Server) prepareUpstreamConfig(
upstreams []string,
defaultUpstreams []string,

View File

@@ -504,7 +504,7 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
}
defer func() { err = errors.WithDeferred(err, r.Close()) }()
bufPtr := d.bufPool.Get().(*[]byte)
bufPtr := d.bufPool.Get()
defer d.bufPool.Put(bufPtr)
p := rulelist.NewParser()
@@ -607,7 +607,7 @@ func (d *DNSFilter) load(flt *FilterYAML) (err error) {
log.Debug("filtering: file %q, id %d, length %d", fileName, flt.ID, st.Size())
bufPtr := d.bufPool.Get().(*[]byte)
bufPtr := d.bufPool.Get()
defer d.bufPool.Put(bufPtr)
p := rulelist.NewParser()

View File

@@ -26,6 +26,7 @@ import (
"github.com/AdguardTeam/golibs/mathutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/syncutil"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist"
"github.com/AdguardTeam/urlfilter/rules"
@@ -232,7 +233,7 @@ type Checker interface {
// DNSFilter matches hostnames and DNS requests against filtering rules.
type DNSFilter struct {
// bufPool is a pool of buffers used for filtering-rule list parsing.
bufPool *sync.Pool
bufPool *syncutil.Pool[[]byte]
rulesStorage *filterlist.RuleStorage
filteringEngine *urlfilter.DNSEngine
@@ -514,8 +515,19 @@ func (d *DNSFilter) BlockingMode() (mode BlockingMode, bIPv4, bIPv6 netip.Addr)
return d.conf.BlockingMode, d.conf.BlockingIPv4, d.conf.BlockingIPv6
}
// SetBlockedResponseTTL sets TTL for blocked responses.
func (d *DNSFilter) SetBlockedResponseTTL(ttl uint32) {
d.confMu.Lock()
defer d.confMu.Unlock()
d.conf.BlockedResponseTTL = ttl
}
// BlockedResponseTTL returns TTL for blocked responses.
func (d *DNSFilter) BlockedResponseTTL() (ttl uint32) {
d.confMu.Lock()
defer d.confMu.Unlock()
return d.conf.BlockedResponseTTL
}
@@ -1050,13 +1062,7 @@ func InitModule() {
// be non-nil.
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
d = &DNSFilter{
bufPool: &sync.Pool{
New: func() (buf any) {
bufVal := make([]byte, rulelist.DefaultRuleBufSize)
return &bufVal
},
},
bufPool: syncutil.NewSlicePool[byte](rulelist.DefaultRuleBufSize),
refreshLock: &sync.Mutex{},
safeBrowsingChecker: c.SafeBrowsingChecker,
parentalControlChecker: c.ParentalControlChecker,

View File

@@ -118,23 +118,22 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
// 2. wildcard > exact;
// 3. lower level wildcard > higher level wildcard;
func (rw *LegacyRewrite) Compare(b *LegacyRewrite) (res int) {
if rw.Type == dns.TypeCNAME && b.Type != dns.TypeCNAME {
return -1
} else if rw.Type != dns.TypeCNAME && b.Type == dns.TypeCNAME {
if rw.Type == dns.TypeCNAME {
if b.Type != dns.TypeCNAME {
return -1
}
} else if b.Type == dns.TypeCNAME {
return 1
}
aIsWld, bIsWld := isWildcard(rw.Domain), isWildcard(b.Domain)
if aIsWld == bIsWld {
if aIsWld, bIsWld := isWildcard(rw.Domain), isWildcard(b.Domain); aIsWld == bIsWld {
// Both are either wildcards or both aren't.
return len(rw.Domain) - len(b.Domain)
}
if aIsWld {
return len(b.Domain) - len(rw.Domain)
} else if aIsWld {
return 1
} else {
return -1
}
return -1
}
// prepareRewrites normalizes and validates all legacy DNS rewrites.

View File

@@ -80,6 +80,12 @@ func TestRewrites(t *testing.T) {
}, {
Domain: "*.issue4016.com",
Answer: "sub.issue4016.com",
}, {
Domain: "*.sub.issue6226.com",
Answer: addr2v4.String(),
}, {
Domain: "*.issue6226.com",
Answer: addr1v4.String(),
}}
require.NoError(t, d.prepareRewrites())
@@ -182,6 +188,20 @@ func TestRewrites(t *testing.T) {
wantIPs: nil,
wantReason: NotFilteredNotFound,
dtyp: dns.TypeA,
}, {
name: "issue6226",
host: "www.issue6226.com",
wantCName: "",
wantIPs: []netip.Addr{addr1v4},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "issue6226_sub",
host: "www.sub.issue6226.com",
wantCName: "",
wantIPs: []netip.Addr{addr2v4},
wantReason: Rewritten,
dtyp: dns.TypeA,
}}
for _, tc := range testCases {

View File

@@ -340,7 +340,7 @@ var blockedServices = []blockedService{{
}, {
ID: "bilibili",
Name: "Bilibili",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 48 48\"><path fill=\"currentColor\" d=\"M36.5,12h-7.086l3.793-3.793c0.391-0.391,0.391-1.023,0-1.414s-1.023-0.391-1.414,0L26.586,12 h-5.172l-5.207-5.207c-0.391-0.391-1.023-0.391-1.414,0s-0.391,1.023,0,1.414L18.586,12H12.5C9.467,12,7,14.467,7,17.5v15 c0,3.033,2.467,5.5,5.5,5.5h2c0,0.829,0.671,1.5,1.5,1.5s1.5-0.671,1.5-1.5h14c0,0.829,0.671,1.5,1.5,1.5s1.5-0.671,1.5-1.5h2 c3.033,0,5.5-2.467,5.5-5.5v-15C42,14.467,39.533,12,36.5,12z M39,32.5c0,1.378-1.122,2.5-2.5,2.5h-24c-1.378,0-2.5-1.122-2.5-2.5 v-15c0-1.378,1.122-2.5,2.5-2.5h24c1.378,0,2.5,1.122,2.5,2.5V32.5z\"/><rect width=\"2.75\" height=\"7.075\" x=\"30.625\" y=\"18.463\" fill=\"currentColor\" transform=\"rotate(-71.567 32.001 22)\"/><rect width=\"7.075\" height=\"2.75\" x=\"14.463\" y=\"20.625\" fill=\"currentColor\" transform=\"rotate(-18.432 17.998 21.997)\"/><path fill=\"currentColor\" d=\"M28.033,27.526c-0.189,0.593-0.644,0.896-1.326,0.896c-0.076-0.013-0.139-0.013-0.24-0.025 c-0.013,0-0.05-0.013-0.063,0c-0.341-0.05-0.745-0.177-1.061-0.467c-0.366-0.265-0.808-0.745-0.947-1.477 c0,0-0.29,1.174-0.896,1.49c-0.076,0.05-0.164,0.114-0.253,0.164l-0.038,0.025c-0.303,0.164-0.682,0.265-1.086,0.278 c-0.568-0.051-0.947-0.328-1.136-0.821l-0.063-0.164l-1.427,0.656l0.05,0.139c0.467,1.124,1.465,1.768,2.74,1.768 c0.922,0,1.667-0.303,2.209-0.909c0.556,0.606,1.288,0.909,2.209,0.909c1.856,0,2.55-1.288,2.765-1.843l0.051-0.126l-1.427-0.657 L28.033,27.526z\"/></svg>"),
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 48 48\" fill=\"currentColor\"><path d=\"M36.5 12h-7.09l3.8-3.8a1 1 0 1 0-1.42-1.4L26.6 12H21.4l-5.2-5.2a1 1 0 1 0-1.42 1.4l3.8 3.8H12.5A5.5 5.5 0 0 0 7 17.5v15a5.5 5.5 0 0 0 5.5 5.5h2a1.5 1.5 0 1 0 3 0h14a1.5 1.5 0 1 0 3 0h2a5.5 5.5 0 0 0 5.5-5.5v-15a5.5 5.5 0 0 0-5.5-5.5ZM39 32.5a2.5 2.5 0 0 1-2.5 2.5h-24a2.5 2.5 0 0 1-2.5-2.5v-15a2.5 2.5 0 0 1 2.5-2.5h24a2.5 2.5 0 0 1 2.5 2.5v15Z\"/><path d=\"m29.08 19.58-.87 2.6 6.71 2.24.87-2.6-6.71-2.24Zm-8.16 0-6.7 2.23.86 2.61 6.71-2.23-.87-2.61Zm7.11 7.95c-.19.59-.64.9-1.32.9l-.24-.03c-.02 0-.05-.02-.07 0a1.99 1.99 0 0 1-1.06-.47 2.37 2.37 0 0 1-.94-1.48s-.3 1.18-.9 1.5l-.25.16-.04.02a2.47 2.47 0 0 1-1.09.28c-.56-.05-.94-.33-1.13-.82l-.07-.17-1.42.66.05.14a2.82 2.82 0 0 0 2.74 1.77c.92 0 1.66-.3 2.2-.91.56.6 1.3.9 2.22.9a2.82 2.82 0 0 0 2.76-1.84l.05-.12-1.43-.66-.06.17Z\"/></svg>"),
Rules: []string{
"|upos-hz-mirrorakam.akamaized.net^",
"||acg.tv^",
@@ -1709,7 +1709,6 @@ var blockedServices = []blockedService{{
"||mastodon.world^",
"||mastodon.zaclys.com^",
"||mastodonapp.uk^",
"||mastodonners.nl^",
"||mastodont.cat^",
"||mastodontech.de^",
"||mastodontti.fi^",
@@ -1741,6 +1740,7 @@ var blockedServices = []blockedService{{
"||social.anoxinon.de^",
"||social.cologne^",
"||social.dev-wiki.de^",
"||social.linux.pizza^",
"||social.politicaconciencia.org^",
"||social.vivaldi.net^",
"||stranger.social^",
@@ -1985,7 +1985,7 @@ var blockedServices = []blockedService{{
}, {
ID: "qq",
Name: "QQ",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 32 32\"><g fill=\"none\" fillRule=\"evenodd\"><path d=\"M0 0h32v32H0z\" /><g fill=\"currentColor\" fillRule=\"nonzero\"><path d=\"M11.25 32C8.342 32 6 30.74 6 29.242c0-1.497 2.342-2.757 5.25-2.757s5.25 1.26 5.25 2.757S14.158 32 11.25 32zM27 29.242c0-1.497-2.342-2.757-5.25-2.757s-5.25 1.26-5.25 2.757S18.842 32 21.75 32 27 30.74 27 29.242zM14.885 7.182c0 .63-.323 1.182-.808 1.182-.485 0-.808-.552-.808-1.182 0-.63.323-1.182.808-1.182.485 0 .808.552.808 1.182zM18.923 6c-.485 0-.808.552-.808 1.182 0 .63.323-.394.808-.394.485 0 .808 1.024.808.394S19.408 6 18.923 6z\" /><path d=\"M6.653 12.638s4.685 2.465 9.926 2.465c5.242 0 9.927-2.465 9.927-2.465.112-.09.217-.161.316-.212-.002-1.088-.078-2.026-.078-2.808C26.744 4.292 22.138 0 16.5 0S6.176 4.292 6.176 9.618v2.78c.146.042.3.113.477.24zm12.626-8.664c1.112 0 1.986 1.272 1.986 2.782s-.874 2.782-1.986 2.782c-1.111 0-1.985-1.271-1.985-2.782 0-1.51.874-2.782 1.985-2.782zm-5.558 0c1.111 0 1.985 1.272 1.985 2.782s-.874 2.782-1.985 2.782c-1.112 0-1.986-1.271-1.986-2.782 0-1.51.874-2.782 1.986-2.782zm2.779 6.624c2.912 0 5.294.464 5.294.994s-2.382 1.656-5.294 1.656c-2.912 0-5.294-1.126-5.294-1.656s2.382-.994 5.294-.994zm11.374 5.182c-.058.038-.108.076-.177.117-.159.08-5.241 3.18-11.038 3.18-1.43 0-2.7-.239-3.97-.477-.239 1.67-.239 3.259-.239 3.974 0 1.272-1.032 1.193-2.303 1.272-1.27 0-2.223.16-2.303-1.033 0-.16-.08-2.782.397-5.564-1.588-.716-2.62-1.272-2.7-1.352a3.293 3.293 0 01-.335-.216C4.012 17.55 3 19.598 3 21.223c0 3.815 1.112 3.418 1.112 3.418.476 0 1.27-.795 1.985-1.67C7.765 27.662 11.735 31 16.5 31c4.765 0 8.735-3.338 10.403-8.028.715.874 1.509 1.669 1.985 1.669 0 0 1.112.397 1.112-3.418 0-1.588-.968-3.631-2.126-5.443z\" /></g></g></svg>"),
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 32 32\" fill=\"currentColor\"><path d=\"M11.25 32C8.35 32 6 30.74 6 29.24c0-1.5 2.34-2.75 5.25-2.75s5.25 1.26 5.25 2.75S14.16 32 11.25 32ZM27 29.24c0-1.5-2.34-2.75-5.25-2.75s-5.25 1.26-5.25 2.75S18.84 32 21.75 32 27 30.74 27 29.24ZM14.88 7.18c0 .63-.32 1.18-.8 1.18-.49 0-.81-.55-.81-1.18 0-.63.32-1.18.8-1.18.5 0 .81.55.81 1.18ZM18.93 6c-.48 0-.8.55-.8 1.18 0 .63.32-.4.8-.4.49 0 .81 1.03.81.4S19.41 6 18.93 6Z\"/><path d=\"M6.65 12.64s4.69 2.46 9.93 2.46c5.24 0 9.93-2.46 9.93-2.46.1-.1.21-.16.31-.21 0-1.1-.08-2.03-.08-2.81C26.74 4.29 22.14 0 16.5 0S6.18 4.3 6.18 9.62v2.78c.14.04.3.11.47.24Zm12.63-8.67c1.11 0 1.98 1.28 1.98 2.79s-.87 2.78-1.98 2.78c-1.11 0-1.99-1.27-1.99-2.78 0-1.51.88-2.79 1.99-2.79Zm-5.56 0c1.11 0 1.99 1.28 1.99 2.79s-.88 2.78-1.99 2.78c-1.11 0-1.99-1.27-1.99-2.78 0-1.51.88-2.79 2-2.79Zm2.78 6.63c2.91 0 5.3.46 5.3 1s-2.39 1.65-5.3 1.65c-2.91 0-5.3-1.13-5.3-1.66s2.39-1 5.3-1Zm11.37 5.18-.17.12c-.16.08-5.24 3.18-11.04 3.18-1.43 0-2.7-.24-3.97-.48-.24 1.67-.24 3.26-.24 3.97 0 1.28-1.03 1.2-2.3 1.28-1.27 0-2.23.16-2.3-1.04 0-.16-.09-2.78.4-5.56a23.87 23.87 0 0 1-2.7-1.35 3.3 3.3 0 0 1-.34-.22C4 17.55 3 19.6 3 21.22c0 3.82 1.11 3.42 1.11 3.42.48 0 1.27-.8 1.99-1.67C7.77 27.67 11.73 31 16.5 31c4.76 0 8.73-3.34 10.4-8.03.72.88 1.51 1.67 1.99 1.67 0 0 1.11.4 1.11-3.42 0-1.58-.97-3.63-2.13-5.44Z\"/></svg>"),
Rules: []string{
"||qq-video.cdn-go.cn^",
"||qq.com^$denyallow=wx.qq.com|weixin.qq.com",
@@ -2437,7 +2437,7 @@ var blockedServices = []blockedService{{
}, {
ID: "xboxlive",
Name: "Xbox Live",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 84 84\" xml:space=\"preserve\"><g transform=\"translate(-59.355 -42.513)\"><g transform=\"matrix(.26458 0 0 .26458 -145.88 -71.438)\" fill=\"currentColor\"><path d=\"m936.09 559.48s0.4 0 0 0c48.4 36.8 130.4 127.2 105.6 152.8-28.4 24.8-65.2 39.6-105.6 39.6s-77.6-14.8-105.6-39.6c-25.2-25.6 57.2-116 105.2-152.4 0-0.4 0.4-0.4 0.4-0.4zm83.6-105.2c-24.4-14.8-51.2-23.6-83.6-23.6s-59.2 8.8-83.6 23.6c-0.4 0-0.4 0.4-0.4 0.8s0.4 0.4 0.8 0.4c31.2-6.8 78.4 20 82.8 22.8h0.8c4.4-2.8 51.6-29.6 82.8-22.8 0.4 0 0.8 0 0.8-0.4s0-0.8-0.4-0.8zm-196 22.4c-0.4 0-0.4 0.4-0.8 0.4-29.2 29.2-47.2 69.6-47.2 114 0 36.4 12.4 70.4 32.8 97.2 0 0.4 0.4 0.4 0.8 0.4s0.4-0.4 0-0.8c-12.4-38 50.4-129.6 82.8-168l0.4-0.4c0-0.4 0-0.4-0.4-0.4-49.2-48.8-65.6-43.6-68.4-42.4zm156.4 42-0.4 0.4s0 0.4 0.4 0.4c32.4 38.4 94.8 130 82.8 168v0.8c0.4 0 0.8 0 0.8-0.4 20.4-26.8 32.8-60.8 32.8-97.2 0-44.4-18-84.8-47.6-114-0.4-0.4-0.4-0.4-0.8-0.4-2.4-0.8-18.8-6-68 42.4z\"/></g></g></svg>"),
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 84 84\" fill=\"currentColor\"><path d=\"M42.44 34.08c12.8 9.73 34.5 33.65 27.94 40.42a42.24 42.24 0 0 1-27.94 10.48A42.03 42.03 0 0 1 14.5 74.5c-6.67-6.77 15.13-30.69 27.83-40.32 0-.1.1-.1.1-.1ZM64.56 6.24A41.32 41.32 0 0 0 42.43 0a41.32 41.32 0 0 0-22.11 6.24c-.1 0-.1.1-.1.21s.1.11.2.11c8.26-1.8 20.75 5.3 21.91 6.03h.21c1.17-.74 13.65-7.83 21.9-6.03.12 0 .22 0 .22-.1s0-.22-.1-.22ZM12.7 12.17c-.1 0-.1.1-.21.1a42.56 42.56 0 0 0-3.81 55.88c0 .11.1.11.2.11s.11-.1 0-.21C5.62 57.99 22.23 33.75 30.8 23.6l.11-.1c0-.11 0-.11-.1-.11-13.02-12.91-17.36-11.54-18.1-11.22Zm41.38 11.11-.1.1s0 .11.1.11c8.57 10.16 25.08 34.4 21.9 44.45v.22c.11 0 .22 0 .22-.11a42.53 42.53 0 0 0 8.67-25.72 42.21 42.21 0 0 0-12.59-30.16c-.1-.1-.1-.1-.21-.1-.64-.22-4.97-1.6-18 11.21Z\"/></svg>"),
Rules: []string{
"||gamepass.com^",
"||xbox-global.ifs.windows.com^",

View File

@@ -644,24 +644,27 @@ func optionalAuthHandler(handler http.Handler) http.Handler {
return &authHandler{handler}
}
// UserAdd - add new user
func (a *Auth) UserAdd(u *webUser, password string) {
// Add adds a new user with the given password.
func (a *Auth) Add(u *webUser, password string) (err error) {
if len(password) == 0 {
return
return errors.Error("empty password")
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
log.Error("bcrypt.GenerateFromPassword: %s", err)
return
return fmt.Errorf("generating hash: %w", err)
}
u.PasswordHash = string(hash)
a.lock.Lock()
a.users = append(a.users, *u)
a.lock.Unlock()
defer a.lock.Unlock()
log.Debug("auth: added user: %s", u.Name)
a.users = append(a.users, *u)
log.Debug("auth: added user with login %q", u.Name)
return nil
}
// findUser returns a user if there is one.

View File

@@ -47,7 +47,8 @@ func TestAuth(t *testing.T) {
s := session{}
user := webUser{Name: "name"}
a.UserAdd(&user, "password")
err := a.Add(&user, "password")
require.NoError(t, err)
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
a.RemoveSession("notfound")

View File

@@ -9,7 +9,7 @@ import (
"os"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/ioutil"
"github.com/AdguardTeam/golibs/log"
"github.com/josharian/native"
)
@@ -83,12 +83,7 @@ func glGetTokenDate(file string) uint32 {
}
}()
fileReader, err := aghio.LimitReader(f, MaxFileSize)
if err != nil {
log.Error("creating limited reader: %s", err)
return 0
}
fileReader := ioutil.LimitReader(f, MaxFileSize)
var dateToken uint32

View File

@@ -182,7 +182,7 @@ type httpPprofConfig struct {
// not absolutely necessary.
type dnsConfig struct {
BindHosts []netip.Addr `yaml:"bind_hosts"`
Port int `yaml:"port"`
Port uint16 `yaml:"port"`
// AnonymizeClientIP defines if clients' IP addresses should be anonymized
// in query log and statistics.
@@ -232,13 +232,13 @@ type tlsConfigSettings struct {
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
PortHTTPS int `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
PortDNSOverTLS int `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DoT will be disabled
PortDNSOverQUIC int `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"` // DNS-over-QUIC port. If 0, DoQ will be disabled
PortHTTPS uint16 `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
PortDNSOverTLS uint16 `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DoT will be disabled
PortDNSOverQUIC uint16 `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"` // DNS-over-QUIC port. If 0, DoQ will be disabled
// PortDNSCrypt is the port for DNSCrypt requests. If it's zero,
// DNSCrypt is disabled.
PortDNSCrypt int `yaml:"port_dnscrypt" json:"port_dnscrypt"`
PortDNSCrypt uint16 `yaml:"port_dnscrypt" json:"port_dnscrypt"`
// DNSCryptConfigFile is the path to the DNSCrypt config file. Must be
// set if PortDNSCrypt is not zero.
//
@@ -262,7 +262,7 @@ type queryLogConfig struct {
// MemSize is the number of entries kept in memory before they are flushed
// to disk.
MemSize uint32 `yaml:"size_memory"`
MemSize int `yaml:"size_memory"`
// Enabled defines if the query log is enabled.
Enabled bool `yaml:"enabled"`
@@ -554,10 +554,10 @@ func validateConfig() (err error) {
}
// udpPort is the port number for UDP protocol.
type udpPort int
type udpPort uint16
// tcpPort is the port number for TCP protocol.
type tcpPort int
type tcpPort uint16
// addPorts is a helper for ports validation that skips zero ports.
func addPorts[T tcpPort | udpPort](uc aghalg.UniqChecker[T], ports ...T) {

View File

@@ -24,11 +24,9 @@ import (
// addresses to a slice of strings.
func appendDNSAddrs(dst []string, addrs ...netip.Addr) (res []string) {
for _, addr := range addrs {
var hostport string
if config.DNS.Port != defaultPortDNS {
hostport = netip.AddrPortFrom(addr, uint16(config.DNS.Port)).String()
} else {
hostport = addr.String()
hostport := addr.String()
if p := config.DNS.Port; p != defaultPortDNS {
hostport = netutil.JoinHostPort(hostport, p)
}
dst = append(dst, hostport)
@@ -102,7 +100,7 @@ type statusResponse struct {
Version string `json:"version"`
Language string `json:"language"`
DNSAddrs []string `json:"dns_addresses"`
DNSPort int `json:"dns_port"`
DNSPort uint16 `json:"dns_port"`
HTTPPort uint16 `json:"http_port"`
// ProtectionDisabledDuration is the duration of the protection pause in
@@ -340,7 +338,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool)
var (
forceHTTPS bool
serveHTTP3 bool
portHTTPS int
portHTTPS uint16
)
func() {
config.RLock()
@@ -394,7 +392,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool)
// httpsURL returns a copy of u for redirection to the HTTPS version, taking the
// hostname and the HTTPS port into account.
func httpsURL(u *url.URL, host string, portHTTPS int) (redirectURL *url.URL) {
func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL) {
hostPort := host
if portHTTPS != defaultPortHTTPS {
hostPort = netutil.JoinHostPort(host, portHTTPS)

View File

@@ -43,8 +43,8 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
data := getAddrsResponse{
Version: version.Version(),
WebPort: defaultPortHTTP,
DNSPort: defaultPortDNS,
WebPort: int(defaultPortHTTP),
DNSPort: int(defaultPortDNS),
}
ifaces, err := aghnet.GetValidNetInterfacesForWeb()
@@ -64,7 +64,7 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
type checkConfReqEnt struct {
IP netip.Addr `json:"ip"`
Port int `json:"port"`
Port uint16 `json:"port"`
Autofix bool `json:"autofix"`
}
@@ -97,7 +97,7 @@ func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
// TODO(a.garipov): Declare all port variables anywhere as uint16.
reqPort := uint16(req.Web.Port)
reqPort := req.Web.Port
port := tcpPort(reqPort)
addPorts(tcpPorts, port)
if err = tcpPorts.Validate(); err != nil {
@@ -128,7 +128,7 @@ func (req *checkConfReq) validateDNS(
) (canAutofix bool, err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
port := uint16(req.DNS.Port)
port := req.DNS.Port
switch port {
case 0:
return false, nil
@@ -142,13 +142,13 @@ func (req *checkConfReq) validateDNS(
return false, err
}
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, port))
if err != nil {
return false, err
}
}
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, port))
if !aghnet.IsAddrInUse(err) {
return false, err
}
@@ -160,7 +160,7 @@ func (req *checkConfReq) validateDNS(
log.Error("disabling DNSStubListener: %s", err)
}
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, port))
canAutofix = false
}
@@ -305,7 +305,7 @@ func disableDNSStubListener() error {
type applyConfigReqEnt struct {
IP netip.Addr `json:"ip"`
Port int `json:"port"`
Port uint16 `json:"port"`
}
type applyConfigReq struct {
@@ -395,14 +395,14 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
return
}
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(req.DNS.Port)))
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, req.DNS.Port))
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, uint16(req.DNS.Port)))
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, req.DNS.Port))
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -413,10 +413,22 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
copyInstallSettings(curConfig, config)
Context.firstRun = false
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port))
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
config.DNS.Port = req.DNS.Port
u := &webUser{
Name: req.Username,
}
err = Context.auth.Add(u, req.Password)
if err != nil {
Context.firstRun = true
copyInstallSettings(config, curConfig)
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "%s", err)
return
}
// TODO(e.burkov): StartMods() should be put in a separate goroutine at the
// moment we'll allow setting up TLS in the initial configuration or the
// configuration itself will use HTTPS protocol, because the underlying
@@ -430,11 +442,6 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
return
}
u := &webUser{
Name: req.Username,
}
Context.auth.UserAdd(u, req.Password)
err = config.write()
if err != nil {
Context.firstRun = true
@@ -445,8 +452,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
}
web.conf.firstRun = false
web.conf.BindHost = req.Web.IP
web.conf.BindPort = req.Web.Port
web.conf.BindAddr = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
registerControlHandlers(web)
@@ -487,9 +493,9 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
}
addrPort := config.HTTPConfig.Address
restartHTTP = addrPort.Addr() != req.Web.IP || int(addrPort.Port()) != req.Web.Port
restartHTTP = addrPort.Addr() != req.Web.IP || addrPort.Port() != req.Web.Port
if restartHTTP {
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port)))
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, req.Web.Port))
if err != nil {
return nil, false, fmt.Errorf(
"checking address %s:%d: %w",

View File

@@ -27,11 +27,11 @@ import (
// Default listening ports.
const (
defaultPortDNS = 53
defaultPortHTTP = 80
defaultPortHTTPS = 443
defaultPortQUIC = 853
defaultPortTLS = 853
defaultPortDNS uint16 = 53
defaultPortHTTP uint16 = 80
defaultPortHTTPS uint16 = 443
defaultPortQUIC uint16 = 853
defaultPortTLS uint16 = 853
)
// Called by other modules when configuration is changed
@@ -196,27 +196,27 @@ func isRunning() bool {
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
}
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
func ipsToTCPAddrs(ips []netip.Addr, port uint16) (tcpAddrs []*net.TCPAddr) {
if ips == nil {
return nil
}
tcpAddrs = make([]*net.TCPAddr, 0, len(ips))
for _, ip := range ips {
tcpAddrs = append(tcpAddrs, net.TCPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16(port))))
tcpAddrs = append(tcpAddrs, net.TCPAddrFromAddrPort(netip.AddrPortFrom(ip, port)))
}
return tcpAddrs
}
func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) {
if ips == nil {
return nil
}
udpAddrs = make([]*net.UDPAddr, 0, len(ips))
for _, ip := range ips {
udpAddrs = append(udpAddrs, net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16(port))))
udpAddrs = append(udpAddrs, net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, port)))
}
return udpAddrs
@@ -346,8 +346,8 @@ func getDNSEncryption() (de dnsEncryption) {
hostname := tlsConf.ServerName
if tlsConf.PortHTTPS != 0 {
addr := hostname
if tlsConf.PortHTTPS != defaultPortHTTPS {
addr = netutil.JoinHostPort(addr, tlsConf.PortHTTPS)
if p := tlsConf.PortHTTPS; p != defaultPortHTTPS {
addr = netutil.JoinHostPort(addr, p)
}
de.https = (&url.URL{
@@ -357,17 +357,17 @@ func getDNSEncryption() (de dnsEncryption) {
}).String()
}
if tlsConf.PortDNSOverTLS != 0 {
if p := tlsConf.PortDNSOverTLS; p != 0 {
de.tls = (&url.URL{
Scheme: "tls",
Host: netutil.JoinHostPort(hostname, tlsConf.PortDNSOverTLS),
Host: netutil.JoinHostPort(hostname, p),
}).String()
}
if tlsConf.PortDNSOverQUIC != 0 {
if p := tlsConf.PortDNSOverQUIC; p != 0 {
de.quic = (&url.URL{
Scheme: "quic",
Host: netutil.JoinHostPort(hostname, tlsConf.PortDNSOverQUIC),
Host: netutil.JoinHostPort(hostname, p),
}).String()
}
}
@@ -494,7 +494,9 @@ func closeDNSServer() {
Context.dnsServer = nil
}
Context.filters.Close()
if Context.filters != nil {
Context.filters.Close()
}
if Context.stats != nil {
err := Context.stats.Close()

View File

@@ -9,8 +9,10 @@ import (
"net"
"net/http"
"net/netip"
"net/url"
"os"
"os/signal"
"path"
"path/filepath"
"runtime"
"sync"
@@ -327,7 +329,7 @@ func setupBindOpts(opts options) (err error) {
if opts.bindPort != 0 {
config.HTTPConfig.Address = netip.AddrPortFrom(
config.HTTPConfig.Address.Addr(),
uint16(opts.bindPort),
opts.bindPort,
)
err = checkPorts()
@@ -495,8 +497,7 @@ func initWeb(opts options, clientBuildFS fs.FS, upd *updater.Updater) (web *webA
clientFS: clientFS,
BindHost: config.HTTPConfig.Address.Addr(),
BindPort: int(config.HTTPConfig.Address.Port()),
BindAddr: config.HTTPConfig.Address,
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHdrTimeout,
@@ -559,16 +560,28 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
err = setupOpts(opts)
fatalOnError(err)
execPath, err := os.Executable()
fatalOnError(errors.Annotate(err, "getting executable path: %w"))
u := &url.URL{
Scheme: "https",
// TODO(a.garipov): Make configurable.
Host: "static.adtidy.org",
Path: path.Join("adguardhome", version.Channel(), "version.json"),
}
upd := updater.NewUpdater(&updater.Config{
Client: config.Filtering.HTTPClient,
Version: version.Version(),
Channel: version.Channel(),
GOARCH: runtime.GOARCH,
GOOS: runtime.GOOS,
GOARM: version.GOARM(),
GOMIPS: version.GOMIPS(),
WorkDir: Context.workDir,
ConfName: config.getConfigFilename(),
Client: config.Filtering.HTTPClient,
Version: version.Version(),
Channel: version.Channel(),
GOARCH: runtime.GOARCH,
GOOS: runtime.GOOS,
GOARM: version.GOARM(),
GOMIPS: version.GOMIPS(),
WorkDir: Context.workDir,
ConfName: config.getConfigFilename(),
ExecPath: execPath,
VersionCheckURL: u.String(),
})
// TODO(e.burkov): This could be made earlier, probably as the option's
@@ -839,7 +852,7 @@ func loadCmdLineOpts() (opts options) {
// example:
//
// go to http://127.0.0.1:80
func printWebAddrs(proto, addr string, port int) {
func printWebAddrs(proto, addr string, port uint16) {
log.Printf("go to %s://%s", proto, netutil.JoinHostPort(addr, port))
}
@@ -851,7 +864,7 @@ func printHTTPAddresses(proto string) {
Context.tls.WriteDiskConfig(&tlsConf)
}
port := int(config.HTTPConfig.Address.Port())
port := config.HTTPConfig.Address.Port()
if proto == aghhttp.SchemeHTTPS {
port = tlsConf.PortHTTPS
}

View File

@@ -4,9 +4,7 @@ import (
"io"
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/ioutil"
)
// middlerware is a wrapper function signature.
@@ -23,12 +21,14 @@ func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Ha
return wrapped
}
// defaultReqBodySzLim is the default maximum request body size.
const defaultReqBodySzLim = 64 * 1024
const (
// defaultReqBodySzLim is the default maximum request body size.
defaultReqBodySzLim = 64 * 1024
// largerReqBodySzLim is the maximum request body size for APIs expecting larger
// requests.
const largerReqBodySzLim = 4 * 1024 * 1024
// largerReqBodySzLim is the maximum request body size for APIs expecting
// larger requests.
largerReqBodySzLim = 4 * 1024 * 1024
)
// expectsLargerRequests shows if this request should use a larger body size
// limit. These are exceptions for poorly designed current APIs as well as APIs
@@ -52,20 +52,12 @@ func expectsLargerRequests(r *http.Request) (ok bool) {
// method limited.
func limitRequestBody(h http.Handler) (limited http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var err error
var szLim int64 = defaultReqBodySzLim
var szLim uint64 = defaultReqBodySzLim
if expectsLargerRequests(r) {
szLim = largerReqBodySzLim
}
var reader io.Reader
reader, err = aghio.LimitReader(r.Body, szLim)
if err != nil {
log.Error("limitRequestBody: %s", err)
return
}
reader := ioutil.LimitReader(r.Body, szLim)
// HTTP handlers aren't supposed to call r.Body.Close(), so just
// replace the body in a clone.

View File

@@ -7,13 +7,13 @@ import (
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/ioutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLimitRequestBody(t *testing.T) {
errReqLimitReached := &aghio.LimitReachedError{
errReqLimitReached := &ioutil.LimitError{
Limit: defaultReqBodySzLim,
}

View File

@@ -42,7 +42,7 @@ type options struct {
// bindPort is the port on which to serve the HTTP UI.
//
// Deprecated: Use bindAddr.
bindPort int
bindPort uint16
// bindAddr is the address to serve the web UI on.
bindAddr netip.AddrPort
@@ -160,15 +160,11 @@ var cmdLineOpts = []cmdLineOpt{{
shortName: "h",
}, {
updateWithValue: func(o options, v string) (options, error) {
var err error
var p int
minPort, maxPort := 0, 1<<16-1
if p, err = strconv.Atoi(v); err != nil {
err = fmt.Errorf("port %q is not a number", v)
} else if p < minPort || p > maxPort {
err = fmt.Errorf("port %d not in range %d - %d", p, minPort, maxPort)
p, err := strconv.ParseUint(v, 10, 16)
if err != nil {
err = fmt.Errorf("parsing port: %w", err)
} else {
o.bindPort = p
o.bindPort = uint16(p)
}
return o, err
@@ -180,7 +176,7 @@ var cmdLineOpts = []cmdLineOpt{{
return "", false
}
return strconv.Itoa(o.bindPort), true
return strconv.Itoa(int(o.bindPort)), true
},
description: "Deprecated. Port to serve HTTP pages on. Use --web-addr.",
longName: "port",

View File

@@ -67,11 +67,11 @@ func TestParseBindHost(t *testing.T) {
}
func TestParseBindPort(t *testing.T) {
assert.Equal(t, 0, testParseOK(t).bindPort, "empty is port 0")
assert.Equal(t, 65535, testParseOK(t, "-p", "65535").bindPort, "-p is port")
assert.Equal(t, uint16(0), testParseOK(t).bindPort, "empty is port 0")
assert.Equal(t, uint16(65535), testParseOK(t, "-p", "65535").bindPort, "-p is port")
testParseParamMissing(t, "-p")
assert.Equal(t, 65535, testParseOK(t, "--port", "65535").bindPort, "--port is port")
assert.Equal(t, uint16(65535), testParseOK(t, "--port", "65535").bindPort, "--port is port")
testParseParamMissing(t, "--port")
testParseErr(t, "not an int", "-p", "x")

View File

@@ -39,8 +39,8 @@ type webConfig struct {
clientFS fs.FS
BindHost netip.Addr
BindPort int
// BindAddr is the binding address with port for plain HTTP web interface.
BindAddr netip.AddrPort
// ReadTimeout is an option to pass to http.Server for setting an
// appropriate field.
@@ -125,12 +125,12 @@ func newWebAPI(conf *webConfig) (w *webAPI) {
// available, unless the HTTPS server isn't active.
//
// TODO(a.garipov): Adapt for HTTP/3.
func webCheckPortAvailable(port int) (ok bool) {
func webCheckPortAvailable(port uint16) (ok bool) {
if Context.web.httpsServer.server != nil {
return true
}
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), uint16(port))
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), port)
return aghnet.CheckPort("tcp", addrPort) == nil
}
@@ -185,10 +185,9 @@ func (web *webAPI) start() {
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{})
// Create a new instance, because the Web is not usable after Shutdown.
hostStr := web.conf.BindHost.String()
web.httpServer = &http.Server{
ErrorLog: log.StdLog("web: plain", log.DEBUG),
Addr: netutil.JoinHostPort(hostStr, web.conf.BindPort),
Addr: web.conf.BindAddr.String(),
Handler: hdlr,
ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
@@ -249,7 +248,7 @@ func (web *webAPI) tlsServerLoop() {
web.httpsServer.cond.L.Unlock()
var portHTTPS int
var portHTTPS uint16
func() {
config.RLock()
defer config.RUnlock()
@@ -257,7 +256,7 @@ func (web *webAPI) tlsServerLoop() {
portHTTPS = config.TLS.PortHTTPS
}()
addr := netutil.JoinHostPort(web.conf.BindHost.String(), portHTTPS)
addr := netip.AddrPortFrom(web.conf.BindAddr.Addr(), portHTTPS).String()
web.httpsServer.server = &http.Server{
ErrorLog: log.StdLog("web: https", log.DEBUG),
Addr: addr,

View File

@@ -1,20 +1,22 @@
package aghnet
// Package ipset provides ipset functionality.
package ipset
import (
"net"
)
// IpsetManager is the ipset manager interface.
// Manager is the ipset manager interface.
//
// TODO(a.garipov): Perhaps generalize this into some kind of a NetFilter type,
// since ipset is exclusive to Linux?
type IpsetManager interface {
type Manager interface {
Add(host string, ip4s, ip6s []net.IP) (n int, err error)
Close() (err error)
}
// NewIpsetManager returns a new ipset. IPv4 addresses are added to an ipset
// with an ipv4 family; IPv6 addresses, to an ipv6 ipset. ipset must exist.
// NewManager returns a new ipset manager. IPv4 addresses are added to an
// ipset with an ipv4 family; IPv6 addresses, to an ipv6 ipset. ipset must
// exist.
//
// The syntax of the ipsetConf is:
//
@@ -22,10 +24,10 @@ type IpsetManager interface {
//
// If ipsetConf is empty, msg and err are nil. The error is of type
// *aghos.UnsupportedError if the OS is not supported.
func NewIpsetManager(ipsetConf []string) (mgr IpsetManager, err error) {
func NewManager(ipsetConf []string) (mgr Manager, err error) {
if len(ipsetConf) == 0 {
return nil, nil
}
return newIpsetMgr(ipsetConf)
return newManager(ipsetConf)
}

View File

@@ -1,6 +1,6 @@
//go:build linux
package aghnet
package ipset
import (
"fmt"
@@ -31,9 +31,9 @@ import (
// 6. Run "sudo ipset list example_set". The Members field should contain the
// resolved IP addresses.
// newIpsetMgr returns a new Linux ipset manager.
func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) {
return newIpsetMgrWithDialer(ipsetConf, defaultDial)
// newManager returns a new Linux ipset manager.
func newManager(ipsetConf []string) (set Manager, err error) {
return newManagerWithDialer(ipsetConf, defaultDial)
}
// defaultDial is the default netfilter dialing function.
@@ -53,50 +53,31 @@ type ipsetConn interface {
Header(name string) (p *ipset.HeaderPolicy, err error)
}
// ipsetDialer creates an ipsetConn.
type ipsetDialer func(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error)
// dialer creates an ipsetConn.
type dialer func(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error)
// ipsetProps contains one Linux Netfilter ipset properties.
type ipsetProps struct {
// props contains one Linux Netfilter ipset properties.
type props struct {
name string
family netfilter.ProtoFamily
}
// unit is a convenient alias for struct{}.
type unit = struct{}
// manager is the Linux Netfilter ipset manager.
type manager struct {
nameToIpset map[string]props
domainToIpsets map[string][]props
// ipsInIpset is the type of a set of IP-address-to-ipset mappings.
type ipsInIpset map[ipInIpsetEntry]unit
// ipInIpsetEntry is the type for entries in an ipsInIpset set.
type ipInIpsetEntry struct {
ipsetName string
ipArr [net.IPv6len]byte
}
// ipsetMgr is the Linux Netfilter ipset manager.
type ipsetMgr struct {
nameToIpset map[string]ipsetProps
domainToIpsets map[string][]ipsetProps
dial ipsetDialer
dial dialer
// mu protects all properties below.
mu *sync.Mutex
// TODO(a.garipov): Currently, the ipset list is static, and we don't
// read the IPs already in sets, so we can assume that all incoming IPs
// are either added to all corresponding ipsets or not. When that stops
// being the case, for example if we add dynamic reconfiguration of
// ipsets, this map will need to become a per-ipset-name one.
addedIPs ipsInIpset
ipv4Conn ipsetConn
ipv6Conn ipsetConn
}
// dialNetfilter establishes connections to Linux's netfilter module.
func (m *ipsetMgr) dialNetfilter(conf *netlink.Config) (err error) {
func (m *manager) dialNetfilter(conf *netlink.Config) (err error) {
// The kernel API does not actually require two sockets but package
// github.com/digineo/go-ipset does.
//
@@ -145,7 +126,7 @@ func parseIpsetConfig(confStr string) (hosts, ipsetNames []string, err error) {
}
// ipsetProps returns the properties of an ipset with the given name.
func (m *ipsetMgr) ipsetProps(name string) (set ipsetProps, err error) {
func (m *manager) ipsetProps(name string) (set props, err error) {
// The family doesn't seem to matter when we use a header query, so
// query only the IPv4 one.
//
@@ -165,14 +146,14 @@ func (m *ipsetMgr) ipsetProps(name string) (set ipsetProps, err error) {
return set, fmt.Errorf("unexpected ipset family %d", family)
}
return ipsetProps{
return props{
name: name,
family: family,
}, nil
}
// ipsets returns currently known ipsets.
func (m *ipsetMgr) ipsets(names []string) (sets []ipsetProps, err error) {
func (m *manager) ipsets(names []string) (sets []props, err error) {
for _, name := range names {
set, ok := m.nameToIpset[name]
if ok {
@@ -193,20 +174,18 @@ func (m *ipsetMgr) ipsets(names []string) (sets []ipsetProps, err error) {
return sets, nil
}
// newIpsetMgrWithDialer returns a new Linux ipset manager using the provided
// newManagerWithDialer returns a new Linux ipset manager using the provided
// dialer.
func newIpsetMgrWithDialer(ipsetConf []string, dial ipsetDialer) (mgr IpsetManager, err error) {
func newManagerWithDialer(ipsetConf []string, dial dialer) (mgr Manager, err error) {
defer func() { err = errors.Annotate(err, "ipset: %w") }()
m := &ipsetMgr{
m := &manager{
mu: &sync.Mutex{},
nameToIpset: make(map[string]ipsetProps),
domainToIpsets: make(map[string][]ipsetProps),
nameToIpset: make(map[string]props),
domainToIpsets: make(map[string][]props),
dial: dial,
addedIPs: make(ipsInIpset),
}
err = m.dialNetfilter(&netlink.Config{})
@@ -229,7 +208,7 @@ func newIpsetMgrWithDialer(ipsetConf []string, dial ipsetDialer) (mgr IpsetManag
return nil, fmt.Errorf("config line at idx %d: %w", i, err)
}
var ipsets []ipsetProps
var ipsets []props
ipsets, err = m.ipsets(ipsetNames)
if err != nil {
return nil, fmt.Errorf(
@@ -249,7 +228,7 @@ func newIpsetMgrWithDialer(ipsetConf []string, dial ipsetDialer) (mgr IpsetManag
// lookupHost find the ipsets for the host, taking subdomain wildcards into
// account.
func (m *ipsetMgr) lookupHost(host string) (sets []ipsetProps) {
func (m *manager) lookupHost(host string) (sets []props) {
// Search for matching ipset hosts starting with most specific domain.
// We could use a trie here but the simple, inefficient solution isn't
// that expensive: ~10 ns for TLD + SLD vs. ~140 ns for 10 subdomains on
@@ -274,25 +253,14 @@ func (m *ipsetMgr) lookupHost(host string) (sets []ipsetProps) {
// addIPs adds the IP addresses for the host to the ipset. set must be same
// family as set's family.
func (m *ipsetMgr) addIPs(host string, set ipsetProps, ips []net.IP) (n int, err error) {
func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error) {
if len(ips) == 0 {
return 0, nil
}
var entries []*ipset.Entry
var newAddedEntries []ipInIpsetEntry
for _, ip := range ips {
e := ipInIpsetEntry{
ipsetName: set.name,
}
copy(e.ipArr[:], ip.To16())
if _, added := m.addedIPs[e]; added {
continue
}
entries = append(entries, ipset.NewEntry(ipset.EntryIP(ip)))
newAddedEntries = append(newAddedEntries, e)
}
n = len(entries)
@@ -315,21 +283,15 @@ func (m *ipsetMgr) addIPs(host string, set ipsetProps, ips []net.IP) (n int, err
return 0, fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err)
}
// Only add these to the cache once we're sure that all of them were
// actually sent to the ipset.
for _, e := range newAddedEntries {
m.addedIPs[e] = unit{}
}
return n, nil
}
// addToSets adds the IP addresses to the corresponding ipset.
func (m *ipsetMgr) addToSets(
func (m *manager) addToSets(
host string,
ip4s []net.IP,
ip6s []net.IP,
sets []ipsetProps,
sets []props,
) (n int, err error) {
for _, set := range sets {
var nn int
@@ -356,8 +318,8 @@ func (m *ipsetMgr) addToSets(
return n, nil
}
// Add implements the IpsetManager interface for *ipsetMgr
func (m *ipsetMgr) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
// Add implements the [Manager] interface for *manager.
func (m *manager) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -371,8 +333,8 @@ func (m *ipsetMgr) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
return m.addToSets(host, ip4s, ip6s, sets)
}
// Close implements the IpsetManager interface for *ipsetMgr.
func (m *ipsetMgr) Close() (err error) {
// Close implements the [Manager] interface for *manager.
func (m *manager) Close() (err error) {
m.mu.Lock()
defer m.mu.Unlock()

View File

@@ -1,6 +1,6 @@
//go:build linux
package aghnet
package ipset
import (
"net"
@@ -15,16 +15,16 @@ import (
"github.com/ti-mo/netfilter"
)
// fakeIpsetConn is a fake ipsetConn for tests.
type fakeIpsetConn struct {
// fakeConn is a fake ipsetConn for tests.
type fakeConn struct {
ipv4Header *ipset.HeaderPolicy
ipv4Entries *[]*ipset.Entry
ipv6Header *ipset.HeaderPolicy
ipv6Entries *[]*ipset.Entry
}
// Add implements the ipsetConn interface for *fakeIpsetConn.
func (c *fakeIpsetConn) Add(name string, entries ...*ipset.Entry) (err error) {
// Add implements the [ipsetConn] interface for *fakeConn.
func (c *fakeConn) Add(name string, entries ...*ipset.Entry) (err error) {
if strings.Contains(name, "ipv4") {
*c.ipv4Entries = append(*c.ipv4Entries, entries...)
@@ -38,13 +38,13 @@ func (c *fakeIpsetConn) Add(name string, entries ...*ipset.Entry) (err error) {
return errors.Error("test: ipset not found")
}
// Close implements the ipsetConn interface for *fakeIpsetConn.
func (c *fakeIpsetConn) Close() (err error) {
// Close implements the [ipsetConn] interface for *fakeConn.
func (c *fakeConn) Close() (err error) {
return nil
}
// Header implements the ipsetConn interface for *fakeIpsetConn.
func (c *fakeIpsetConn) Header(name string) (p *ipset.HeaderPolicy, err error) {
// Header implements the [ipsetConn] interface for *fakeConn.
func (c *fakeConn) Header(name string) (p *ipset.HeaderPolicy, err error) {
if strings.Contains(name, "ipv4") {
return c.ipv4Header, nil
} else if strings.Contains(name, "ipv6") {
@@ -54,7 +54,7 @@ func (c *fakeIpsetConn) Header(name string) (p *ipset.HeaderPolicy, err error) {
return nil, errors.Error("test: ipset not found")
}
func TestIpsetMgr_Add(t *testing.T) {
func TestManager_Add(t *testing.T) {
ipsetConf := []string{
"example.com,example.net/ipv4set",
"example.org,example.biz/ipv6set",
@@ -67,7 +67,7 @@ func TestIpsetMgr_Add(t *testing.T) {
pf netfilter.ProtoFamily,
conf *netlink.Config,
) (conn ipsetConn, err error) {
return &fakeIpsetConn{
return &fakeConn{
ipv4Header: &ipset.HeaderPolicy{
Family: ipset.NewUInt8Box(uint8(netfilter.ProtoIPv4)),
},
@@ -79,7 +79,7 @@ func TestIpsetMgr_Add(t *testing.T) {
}, nil
}
m, err := newIpsetMgrWithDialer(ipsetConf, fakeDial)
m, err := newManagerWithDialer(ipsetConf, fakeDial)
require.NoError(t, err)
ip4 := net.IP{1, 2, 3, 4}
@@ -114,21 +114,22 @@ func TestIpsetMgr_Add(t *testing.T) {
assert.NoError(t, err)
}
var ipsetPropsSink []ipsetProps
// ipsetPropsSink is the typed sink for benchmark results.
var ipsetPropsSink []props
func BenchmarkIpsetMgr_lookupHost(b *testing.B) {
propsLong := []ipsetProps{{
func BenchmarkManager_LookupHost(b *testing.B) {
propsLong := []props{{
name: "example.com",
family: netfilter.ProtoIPv4,
}}
propsShort := []ipsetProps{{
propsShort := []props{{
name: "example.net",
family: netfilter.ProtoIPv4,
}}
m := &ipsetMgr{
domainToIpsets: map[string][]ipsetProps{
m := &manager{
domainToIpsets: map[string][]props{
"": propsLong,
"example.net": propsShort,
},

View File

@@ -1,11 +1,11 @@
//go:build !linux
package aghnet
package ipset
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func newIpsetMgr(_ []string) (mgr IpsetManager, err error) {
func newManager(_ []string) (mgr Manager, err error) {
return nil, aghos.Unsupported("ipset")
}

View File

@@ -7,6 +7,7 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
@@ -29,12 +30,12 @@ type queryLog struct {
findClient func(ids []string) (c *Client, err error)
// logFile is the path to the log file.
logFile string
// buffer contains recent log entries. The entries in this buffer must not
// be modified.
buffer []*logEntry
buffer *aghalg.RingBuffer[*logEntry]
// logFile is the path to the log file.
logFile string
// bufferLock protects buffer.
bufferLock sync.RWMutex
@@ -195,7 +196,7 @@ func newLogEntry(params *AddParams) (entry *logEntry) {
// Add implements the [QueryLog] interface for *queryLog.
func (l *queryLog) Add(params *AddParams) {
var isEnabled, fileIsEnabled bool
var memSize uint32
var memSize int
func() {
l.confMu.RLock()
defer l.confMu.RUnlock()
@@ -204,7 +205,7 @@ func (l *queryLog) Add(params *AddParams) {
memSize = l.conf.MemSize
}()
if !isEnabled {
if !isEnabled || memSize == 0 {
return
}
@@ -221,36 +222,18 @@ func (l *queryLog) Add(params *AddParams) {
entry := newLogEntry(params)
needFlush := false
func() {
l.bufferLock.Lock()
defer l.bufferLock.Unlock()
l.bufferLock.Lock()
defer l.bufferLock.Unlock()
l.buffer = append(l.buffer, entry)
l.buffer.Append(entry)
if !fileIsEnabled {
if len(l.buffer) > int(memSize) {
// Writing to file is disabled, so just remove the oldest entry
// from the slices.
//
// TODO(a.garipov): This should be replaced by a proper ring
// buffer, but it's currently difficult to do that.
l.buffer[0] = nil
l.buffer = l.buffer[1:]
}
} else if !l.flushPending {
needFlush = len(l.buffer) >= int(memSize)
if needFlush {
l.flushPending = true
}
}
}()
if !l.flushPending && fileIsEnabled && l.buffer.Len() >= memSize {
l.flushPending = true
if needFlush {
go func() {
flushErr := l.flushLogBuffer()
if flushErr != nil {
log.Error("querylog: flushing after adding: %s", err)
log.Error("querylog: flushing after adding: %s", flushErr)
}
}()
}

View File

@@ -7,6 +7,7 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -62,7 +63,7 @@ type Config struct {
// MemSize is the number of entries kept in a memory buffer before they are
// flushed to disk.
MemSize uint32
MemSize int
// Enabled tells if the query log is enabled.
Enabled bool
@@ -142,9 +143,15 @@ func newQueryLog(conf Config) (l *queryLog, err error) {
}
}
if conf.MemSize < 0 {
return nil, errors.Error("memory size must be greater or equal to zero")
}
l = &queryLog{
findClient: findClient,
buffer: aghalg.NewRingBuffer[*logEntry](conf.MemSize),
conf: &Config{},
confMu: &sync.RWMutex{},
logFile: filepath.Join(conf.BaseDir, queryLogFileName),

View File

@@ -14,68 +14,75 @@ import (
// flushLogBuffer flushes the current buffer to file and resets the current
// buffer.
func (l *queryLog) flushLogBuffer() (err error) {
defer func() { err = errors.Annotate(err, "flushing log buffer: %w") }()
l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock()
var flushBuffer []*logEntry
func() {
l.bufferLock.Lock()
defer l.bufferLock.Unlock()
b, err := l.encodeEntries()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
flushBuffer = l.buffer
l.buffer = nil
l.flushPending = false
}()
err = l.flushToFile(flushBuffer)
return errors.Annotate(err, "writing to file: %w")
return l.flushToFile(b)
}
// flushToFile saves the specified log entries to the query log file
func (l *queryLog) flushToFile(buffer []*logEntry) (err error) {
if len(buffer) == 0 {
log.Debug("querylog: nothing to write to a file")
// encodeEntries returns JSON encoded log entries, logs estimated time, clears
// the log buffer.
func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
l.bufferLock.Lock()
defer l.bufferLock.Unlock()
return nil
bufLen := l.buffer.Len()
if bufLen == 0 {
return nil, errors.Error("nothing to write to a file")
}
start := time.Now()
var b bytes.Buffer
e := json.NewEncoder(&b)
for _, entry := range buffer {
err = e.Encode(entry)
if err != nil {
log.Error("Failed to marshal entry: %s", err)
b = &bytes.Buffer{}
e := json.NewEncoder(b)
return err
}
l.buffer.Range(func(entry *logEntry) (cont bool) {
err = e.Encode(entry)
return err == nil
})
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
elapsed := time.Since(start)
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer)))
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", bufLen, elapsed, b.Len()/1024, float64(b.Len())/float64(bufLen), elapsed/time.Duration(bufLen))
var zb bytes.Buffer
filename := l.logFile
zb = b
l.buffer.Clear()
l.flushPending = false
return b, nil
}
// flushToFile saves the encoded log entries to the query log file.
func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) {
l.fileWriteLock.Lock()
defer l.fileWriteLock.Unlock()
filename := l.logFile
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
if err != nil {
log.Error("failed to create file \"%s\": %s", filename, err)
return err
return fmt.Errorf("creating file %q: %w", filename, err)
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()
n, err := f.Write(zb.Bytes())
n, err := f.Write(b.Bytes())
if err != nil {
log.Error("Couldn't write to file: %s", err)
return err
return fmt.Errorf("writing to file %q: %w", filename, err)
}
log.Debug("querylog: ok \"%s\": %v bytes written", filename, n)
log.Debug("querylog: ok %q: %v bytes written", filename, n)
return nil
}

View File

@@ -51,13 +51,12 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
l.bufferLock.Lock()
defer l.bufferLock.Unlock()
// Go through the buffer in the reverse order, from newer to older.
var err error
for i := len(l.buffer) - 1; i >= 0; i-- {
l.buffer.ReverseRange(func(entry *logEntry) (cont bool) {
// A shallow clone is enough, since the only thing that this loop
// modifies is the client field.
e := l.buffer[i].shallowClone()
e := entry.shallowClone()
var err error
e.client, err = l.client(e.ClientID, e.IP.String(), cache)
if err != nil {
msg := "querylog: enriching memory record at time %s" +
@@ -70,9 +69,11 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
if params.match(e) {
entries = append(entries, e)
}
}
return entries, len(l.buffer)
return true
})
return entries, l.buffer.Len()
}
// search - searches log entries in the query log using specified parameters

View File

@@ -8,8 +8,8 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/ioutil"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
@@ -51,11 +51,7 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
}
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
var r io.Reader
r, err = aghio.LimitReader(resp.Body, MaxResponseSize)
if err != nil {
return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err)
}
r := ioutil.LimitReader(resp.Body, MaxResponseSize)
// This use of ReadAll is safe, because we just limited the appropriate
// ReadCloser.

View File

@@ -0,0 +1,148 @@
package updater_test
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUpdater_VersionInfo(t *testing.T) {
const jsonData = `{
"version": "v0.103.0-beta.2",
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
"selfupdate_min_version": "v0.0",
"download_windows_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_windows_amd64.zip",
"download_windows_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_windows_386.zip",
"download_darwin_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_darwin_amd64.zip",
"download_darwin_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_darwin_386.zip",
"download_linux_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz",
"download_linux_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_386.tar.gz",
"download_linux_arm": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
"download_linux_armv5": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz",
"download_linux_armv6": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
"download_linux_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
"download_linux_arm64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz",
"download_linux_mips": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz",
"download_linux_mipsle": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz",
"download_linux_mips64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz",
"download_linux_mips64le": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz",
"download_freebsd_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz",
"download_freebsd_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz",
"download_freebsd_arm": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
"download_freebsd_armv5": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz",
"download_freebsd_armv6": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
"download_freebsd_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz",
"download_freebsd_arm64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz"
}`
counter := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
counter++
_, _ = w.Write([]byte(jsonData))
}))
t.Cleanup(srv.Close)
fakeURL, err := url.JoinPath(srv.URL, "adguardhome", version.ChannelBeta, "version.json")
require.NoError(t, err)
u := updater.NewUpdater(&updater.Config{
Client: srv.Client(),
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOARCH: "arm",
GOOS: "linux",
VersionCheckURL: fakeURL,
})
info, err := u.VersionInfo(false)
require.NoError(t, err)
assert.Equal(t, counter, 1)
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
t.Run("cache_check", func(t *testing.T) {
_, err = u.VersionInfo(false)
require.NoError(t, err)
assert.Equal(t, counter, 1)
})
t.Run("force_check", func(t *testing.T) {
_, err = u.VersionInfo(true)
require.NoError(t, err)
assert.Equal(t, counter, 2)
})
t.Run("api_fail", func(t *testing.T) {
srv.Close()
_, err = u.VersionInfo(true)
var urlErr *url.Error
assert.ErrorAs(t, err, &urlErr)
})
}
func TestUpdater_VersionInfo_others(t *testing.T) {
const jsonData = `{
"version": "v0.103.0-beta.2",
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
"selfupdate_min_version": "v0.0",
"download_linux_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
"download_linux_mips_softfloat": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz"
}`
fakeClient, fakeURL := aghtest.StartHTTPServer(t, []byte(jsonData))
fakeURL = fakeURL.JoinPath("adguardhome", version.ChannelBeta, "version.json")
testCases := []struct {
name string
arch string
arm string
mips string
}{{
name: "ARM",
arch: "arm",
arm: "7",
mips: "",
}, {
name: "MIPS",
arch: "mips",
mips: "softfloat",
arm: "",
}}
for _, tc := range testCases {
u := updater.NewUpdater(&updater.Config{
Client: fakeClient,
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOOS: "linux",
GOARCH: tc.arch,
GOARM: tc.arm,
GOMIPS: tc.mips,
VersionCheckURL: fakeURL.String(),
})
info, err := u.VersionInfo(false)
require.NoError(t, err)
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
}
}

Binary file not shown.

View File

@@ -8,18 +8,16 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/ioutil"
"github.com/AdguardTeam/golibs/log"
)
@@ -36,6 +34,7 @@ type Updater struct {
workDir string
confName string
execPath string
versionCheckURL string
// mu protects all fields below.
@@ -74,18 +73,19 @@ type Config struct {
// ConfName is the name of the current configuration file. Typically,
// "AdGuardHome.yaml".
ConfName string
// WorkDir is the working directory that is used for temporary files.
WorkDir string
// ExecPath is path to the executable file.
ExecPath string
// VersionCheckURL is url to the latest version announcement.
VersionCheckURL string
}
// NewUpdater creates a new Updater.
func NewUpdater(conf *Config) *Updater {
u := &url.URL{
Scheme: "https",
// TODO(a.garipov): Make configurable.
Host: "static.adtidy.org",
Path: path.Join("adguardhome", conf.Channel, "version.json"),
}
return &Updater{
client: conf.Client,
@@ -98,7 +98,8 @@ func NewUpdater(conf *Config) *Updater {
confName: conf.ConfName,
workDir: conf.WorkDir,
versionCheckURL: u.String(),
execPath: conf.ExecPath,
versionCheckURL: conf.VersionCheckURL,
mu: &sync.RWMutex{},
}
@@ -119,12 +120,7 @@ func (u *Updater) Update(firstRun bool) (err error) {
}
}()
execPath, err := os.Executable()
if err != nil {
return fmt.Errorf("getting executable path: %w", err)
}
err = u.prepare(execPath)
err = u.prepare()
if err != nil {
return fmt.Errorf("preparing: %w", err)
}
@@ -178,7 +174,7 @@ func (u *Updater) VersionCheckURL() (vcu string) {
}
// prepare fills all necessary fields in Updater object.
func (u *Updater) prepare(exePath string) (err error) {
func (u *Updater) prepare() (err error) {
u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion))
_, pkgNameOnly := filepath.Split(u.packageURL)
@@ -194,7 +190,7 @@ func (u *Updater) prepare(exePath string) (err error) {
updateExeName = "AdGuardHome.exe"
}
u.backupExeName = filepath.Join(u.backupDir, filepath.Base(exePath))
u.backupExeName = filepath.Join(u.backupDir, filepath.Base(u.execPath))
u.updateExeName = filepath.Join(u.updateDir, updateExeName)
log.Debug(
@@ -204,7 +200,7 @@ func (u *Updater) prepare(exePath string) (err error) {
u.packageURL,
)
u.currentExeName = exePath
u.currentExeName = u.execPath
_, err = os.Stat(u.currentExeName)
if err != nil {
return fmt.Errorf("checking %q: %w", u.currentExeName, err)
@@ -332,11 +328,7 @@ func (u *Updater) downloadPackageFile() (err error) {
}
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
var r io.Reader
r, err = aghio.LimitReader(resp.Body, MaxPackageFileSize)
if err != nil {
return fmt.Errorf("http request failed: %w", err)
}
r := ioutil.LimitReader(resp.Body, MaxPackageFileSize)
log.Debug("updater: reading http body")
// This use of ReadAll is now safe, because we limited body's Reader.

View File

@@ -0,0 +1,107 @@
package updater
import (
"os"
"path/filepath"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUpdater_internal(t *testing.T) {
wd := t.TempDir()
exePathUnix := filepath.Join(wd, "AdGuardHome.exe")
exePathWindows := filepath.Join(wd, "AdGuardHome")
yamlPath := filepath.Join(wd, "AdGuardHome.yaml")
readmePath := filepath.Join(wd, "README.md")
licensePath := filepath.Join(wd, "LICENSE.txt")
require.NoError(t, os.WriteFile(exePathUnix, []byte("AdGuardHome.exe"), 0o755))
require.NoError(t, os.WriteFile(exePathWindows, []byte("AdGuardHome"), 0o755))
require.NoError(t, os.WriteFile(yamlPath, []byte("AdGuardHome.yaml"), 0o644))
require.NoError(t, os.WriteFile(readmePath, []byte("README.md"), 0o644))
require.NoError(t, os.WriteFile(licensePath, []byte("LICENSE.txt"), 0o644))
testCases := []struct {
name string
exeName string
os string
archiveName string
}{{
name: "unix",
os: "linux",
exeName: "AdGuardHome",
archiveName: "AdGuardHome.tar.gz",
}, {
name: "windows",
os: "windows",
exeName: "AdGuardHome.exe",
archiveName: "AdGuardHome.zip",
}}
for _, tc := range testCases {
exePath := filepath.Join(wd, tc.exeName)
// start server for returning package file
pkgData, err := os.ReadFile(filepath.Join("testdata", tc.archiveName))
require.NoError(t, err)
fakeClient, fakeURL := aghtest.StartHTTPServer(t, pkgData)
fakeURL = fakeURL.JoinPath(tc.archiveName)
u := NewUpdater(&Config{
Client: fakeClient,
GOOS: tc.os,
Version: "v0.103.0",
ExecPath: exePath,
WorkDir: wd,
ConfName: yamlPath,
})
u.newVersion = "v0.103.1"
u.packageURL = fakeURL.String()
require.NoError(t, u.prepare())
require.NoError(t, u.downloadPackageFile())
require.NoError(t, u.unpack())
require.NoError(t, u.backup(false))
require.NoError(t, u.replace())
u.clean()
// check backup files
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
require.NoError(t, err)
assert.Equal(t, "AdGuardHome.yaml", string(d))
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", tc.exeName))
require.NoError(t, err)
assert.Equal(t, tc.exeName, string(d))
// check updated files
d, err = os.ReadFile(exePath)
require.NoError(t, err)
assert.Equal(t, "1", string(d))
d, err = os.ReadFile(readmePath)
require.NoError(t, err)
assert.Equal(t, "2", string(d))
d, err = os.ReadFile(licensePath)
require.NoError(t, err)
assert.Equal(t, "3", string(d))
d, err = os.ReadFile(yamlPath)
require.NoError(t, err)
assert.Equal(t, "AdGuardHome.yaml", string(d))
}
}

View File

@@ -1,105 +1,38 @@
package updater
package updater_test
import (
"net"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path"
"path/filepath"
"strconv"
"runtime"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TODO(a.garipov): Rewrite these tests.
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
func startHTTPServer(data string) (l net.Listener, portStr string) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte(data))
})
listener, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
go func() { _ = http.Serve(listener, mux) }()
return listener, strconv.FormatUint(uint64(listener.Addr().(*net.TCPAddr).Port), 10)
}
func TestUpdateGetVersion(t *testing.T) {
func TestUpdater_Update(t *testing.T) {
const jsonData = `{
"version": "v0.103.0-beta.2",
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
"selfupdate_min_version": "v0.0",
"download_windows_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_windows_amd64.zip",
"download_windows_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_windows_386.zip",
"download_darwin_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_darwin_amd64.zip",
"download_darwin_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_darwin_386.zip",
"download_linux_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz",
"download_linux_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_386.tar.gz",
"download_linux_arm": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
"download_linux_armv5": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz",
"download_linux_armv6": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
"download_linux_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
"download_linux_arm64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz",
"download_linux_mips": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz",
"download_linux_mipsle": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz",
"download_linux_mips64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz",
"download_linux_mips64le": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz",
"download_freebsd_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz",
"download_freebsd_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz",
"download_freebsd_arm": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
"download_freebsd_armv5": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz",
"download_freebsd_armv6": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
"download_freebsd_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz",
"download_freebsd_arm64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz"
"download_linux_amd64": "%s"
}`
l, lport := startHTTPServer(jsonData)
testutil.CleanupAndRequireSuccess(t, l.Close)
const packagePath = "/AdGuardHome.tar.gz"
u := NewUpdater(&Config{
Client: &http.Client{},
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOARCH: "arm",
GOOS: "linux",
})
fakeURL := &url.URL{
Scheme: "http",
Host: net.JoinHostPort("127.0.0.1", lport),
Path: path.Join("adguardhome", version.ChannelBeta, "version.json"),
}
u.versionCheckURL = fakeURL.String()
info, err := u.VersionInfo(false)
require.NoError(t, err)
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
// check cached
_, err = u.VersionInfo(false)
require.NoError(t, err)
}
func TestUpdate(t *testing.T) {
wd := t.TempDir()
exePath := filepath.Join(wd, "AdGuardHome")
@@ -112,55 +45,61 @@ func TestUpdate(t *testing.T) {
require.NoError(t, os.WriteFile(readmePath, []byte("README.md"), 0o644))
require.NoError(t, os.WriteFile(licensePath, []byte("LICENSE.txt"), 0o644))
// start server for returning package file
pkgData, err := os.ReadFile("testdata/AdGuardHome.tar.gz")
pkgData, err := os.ReadFile("testdata/AdGuardHome_unix.tar.gz")
require.NoError(t, err)
l, lport := startHTTPServer(string(pkgData))
testutil.CleanupAndRequireSuccess(t, l.Close)
u := NewUpdater(&Config{
Client: &http.Client{},
Version: "v0.103.0",
mux := http.NewServeMux()
mux.HandleFunc(packagePath, func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write(pkgData)
})
fakeURL := &url.URL{
Scheme: "http",
Host: net.JoinHostPort("127.0.0.1", lport),
Path: "AdGuardHome.tar.gz",
}
versionPath := path.Join("/adguardhome", version.ChannelBeta, "version.json")
mux.HandleFunc(versionPath, func(w http.ResponseWriter, r *http.Request) {
var u string
u, err = url.JoinPath("http://", r.Host, packagePath)
require.NoError(t, err)
u.workDir = wd
u.confName = yamlPath
u.newVersion = "v0.103.1"
u.packageURL = fakeURL.String()
_, _ = fmt.Fprintf(w, jsonData, u)
})
require.NoError(t, u.prepare(exePath))
require.NoError(t, u.downloadPackageFile())
require.NoError(t, u.unpack())
// require.NoError(t, u.check())
require.NoError(t, u.backup(false))
require.NoError(t, u.replace())
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)
u.clean()
versionCheckURL, err := url.JoinPath(srv.URL, versionPath)
require.NoError(t, err)
u := updater.NewUpdater(&updater.Config{
Client: srv.Client(),
GOARCH: "amd64",
GOOS: "linux",
Version: "v0.103.0",
ConfName: yamlPath,
WorkDir: wd,
ExecPath: exePath,
VersionCheckURL: versionCheckURL,
})
_, err = u.VersionInfo(false)
require.NoError(t, err)
err = u.Update(true)
require.NoError(t, err)
// check backup files
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "LICENSE.txt"))
require.NoError(t, err)
assert.Equal(t, "AdGuardHome.yaml", string(d))
assert.Equal(t, "LICENSE.txt", string(d))
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome"))
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "README.md"))
require.NoError(t, err)
assert.Equal(t, "AdGuardHome", string(d))
assert.Equal(t, "README.md", string(d))
// check updated files
d, err = os.ReadFile(exePath)
_, err = os.Stat(exePath)
require.NoError(t, err)
assert.Equal(t, "1", string(d))
d, err = os.ReadFile(readmePath)
require.NoError(t, err)
@@ -175,157 +114,22 @@ func TestUpdate(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "AdGuardHome.yaml", string(d))
}
func TestUpdateWindows(t *testing.T) {
wd := t.TempDir()
t.Run("config_check", func(t *testing.T) {
// TODO(s.chzhen): Test on Windows also.
if runtime.GOOS == "windows" {
t.Skip("skipping config check test on windows")
}
exePath := filepath.Join(wd, "AdGuardHome.exe")
yamlPath := filepath.Join(wd, "AdGuardHome.yaml")
readmePath := filepath.Join(wd, "README.md")
licensePath := filepath.Join(wd, "LICENSE.txt")
require.NoError(t, os.WriteFile(exePath, []byte("AdGuardHome.exe"), 0o755))
require.NoError(t, os.WriteFile(yamlPath, []byte("AdGuardHome.yaml"), 0o644))
require.NoError(t, os.WriteFile(readmePath, []byte("README.md"), 0o644))
require.NoError(t, os.WriteFile(licensePath, []byte("LICENSE.txt"), 0o644))
// start server for returning package file
pkgData, err := os.ReadFile("testdata/AdGuardHome.zip")
require.NoError(t, err)
l, lport := startHTTPServer(string(pkgData))
testutil.CleanupAndRequireSuccess(t, l.Close)
u := NewUpdater(&Config{
Client: &http.Client{},
GOOS: "windows",
Version: "v0.103.0",
err = u.Update(false)
assert.NoError(t, err)
})
fakeURL := &url.URL{
Scheme: "http",
Host: net.JoinHostPort("127.0.0.1", lport),
Path: "AdGuardHome.zip",
}
t.Run("api_fail", func(t *testing.T) {
srv.Close()
u.workDir = wd
u.confName = yamlPath
u.newVersion = "v0.103.1"
u.packageURL = fakeURL.String()
require.NoError(t, u.prepare(exePath))
require.NoError(t, u.downloadPackageFile())
require.NoError(t, u.unpack())
// assert.Nil(t, u.check())
require.NoError(t, u.backup(false))
require.NoError(t, u.replace())
u.clean()
// check backup files
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
require.NoError(t, err)
assert.Equal(t, "AdGuardHome.yaml", string(d))
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.exe"))
require.NoError(t, err)
assert.Equal(t, "AdGuardHome.exe", string(d))
// check updated files
d, err = os.ReadFile(exePath)
require.NoError(t, err)
assert.Equal(t, "1", string(d))
d, err = os.ReadFile(readmePath)
require.NoError(t, err)
assert.Equal(t, "2", string(d))
d, err = os.ReadFile(licensePath)
require.NoError(t, err)
assert.Equal(t, "3", string(d))
d, err = os.ReadFile(yamlPath)
require.NoError(t, err)
assert.Equal(t, "AdGuardHome.yaml", string(d))
}
func TestUpdater_VersionInto_ARM(t *testing.T) {
const jsonData = `{
"version": "v0.103.0-beta.2",
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
"selfupdate_min_version": "v0.0",
"download_linux_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz"
}`
l, lport := startHTTPServer(jsonData)
testutil.CleanupAndRequireSuccess(t, l.Close)
u := NewUpdater(&Config{
Client: &http.Client{},
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOARCH: "arm",
GOOS: "linux",
GOARM: "7",
err = u.Update(true)
var urlErr *url.Error
assert.ErrorAs(t, err, &urlErr)
})
fakeURL := &url.URL{
Scheme: "http",
Host: net.JoinHostPort("127.0.0.1", lport),
Path: path.Join("adguardhome", version.ChannelBeta, "version.json"),
}
u.versionCheckURL = fakeURL.String()
info, err := u.VersionInfo(false)
require.NoError(t, err)
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
}
func TestUpdater_VersionInto_MIPS(t *testing.T) {
const jsonData = `{
"version": "v0.103.0-beta.2",
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
"selfupdate_min_version": "v0.0",
"download_linux_mips_softfloat": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz"
}`
l, lport := startHTTPServer(jsonData)
testutil.CleanupAndRequireSuccess(t, l.Close)
u := NewUpdater(&Config{
Client: &http.Client{},
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOARCH: "mips",
GOOS: "linux",
GOMIPS: "softfloat",
})
fakeURL := &url.URL{
Scheme: "http",
Host: net.JoinHostPort("127.0.0.1", lport),
Path: path.Join("adguardhome", version.ChannelBeta, "version.json"),
}
u.versionCheckURL = fakeURL.String()
info, err := u.VersionInfo(false)
require.NoError(t, err)
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
}

View File

@@ -12,9 +12,9 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/ioutil"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
@@ -62,7 +62,7 @@ type Config struct {
CacheTTL time.Duration
// MaxConnReadSize is an upper limit in bytes for reading from net.Conn.
MaxConnReadSize int64
MaxConnReadSize uint64
// MaxRedirects is the maximum redirects count.
MaxRedirects int
@@ -102,7 +102,7 @@ type Default struct {
cacheTTL time.Duration
// maxConnReadSize is an upper limit in bytes for reading from net.Conn.
maxConnReadSize int64
maxConnReadSize uint64
// maxRedirects is the maximum redirects count.
maxRedirects int
@@ -208,11 +208,7 @@ func (w *Default) query(ctx context.Context, target, serverAddr string) (data []
}
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
r, err := aghio.LimitReader(conn, w.maxConnReadSize)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
r := ioutil.LimitReader(conn, w.maxConnReadSize)
_ = conn.SetDeadline(time.Now().Add(w.timeout))
_, err = io.WriteString(conn, target+"\r\n")