all: sync with master; upd chlog

This commit is contained in:
Ainar Garipov
2023-07-26 13:18:44 +03:00
parent ec83d0eb86
commit 48ee2f8a42
99 changed files with 3202 additions and 1886 deletions

View File

@@ -1,10 +1,11 @@
package aghio
package aghio_test
import (
"io"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -31,7 +32,7 @@ func TestLimitReader(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := LimitReader(nil, tc.n)
_, err := aghio.LimitReader(nil, tc.n)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
@@ -57,7 +58,7 @@ func TestLimitedReader_Read(t *testing.T) {
limit: 3,
want: 0,
}, {
err: &LimitReachedError{
err: &aghio.LimitReachedError{
Limit: 0,
},
name: "limit_reached",
@@ -74,7 +75,7 @@ func TestLimitedReader_Read(t *testing.T) {
for _, tc := range testCases {
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
lreader, err := LimitReader(readCloser, tc.limit)
lreader, err := aghio.LimitReader(readCloser, tc.limit)
require.NoError(t, err)
require.NotNil(t, lreader)
@@ -89,7 +90,7 @@ func TestLimitedReader_Read(t *testing.T) {
}
func TestLimitedReader_LimitReachedError(t *testing.T) {
testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &LimitReachedError{
testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &aghio.LimitReachedError{
Limit: 0,
})
}

View File

@@ -141,9 +141,9 @@ type HostsRecord struct {
Canonical string
}
// equal returns true if all fields of rec are equal to field in other or they
// Equal returns true if all fields of rec are equal to field in other or they
// both are nil.
func (rec *HostsRecord) equal(other *HostsRecord) (ok bool) {
func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) {
if rec == nil {
return other == nil
} else if other == nil {
@@ -495,7 +495,7 @@ func (hc *HostsContainer) refresh() (err error) {
}
// hc.last is nil on the first refresh, so let that one through.
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) {
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).Equal) {
log.Debug("%s: no changes detected", hostsContainerPrefix)
return nil

View File

@@ -0,0 +1,144 @@
package aghnet
import (
"io/fs"
"net/netip"
"path"
"testing"
"testing/fstest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil/fakefs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const nl = "\n"
func TestHostsContainer_PathsToPatterns(t *testing.T) {
gsfs := fstest.MapFS{
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
}
testCases := []struct {
name string
paths []string
want []string
}{{
name: "no_paths",
paths: nil,
want: nil,
}, {
name: "single_file",
paths: []string{"dir_0/file_1"},
want: []string{"dir_0/file_1"},
}, {
name: "several_files",
paths: []string{"dir_0/file_1", "dir_0/file_2"},
want: []string{"dir_0/file_1", "dir_0/file_2"},
}, {
name: "whole_dir",
paths: []string{"dir_0"},
want: []string{"dir_0/*"},
}, {
name: "file_and_dir",
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
}, {
name: "non-existing",
paths: []string{path.Join("dir_0", "file_3")},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
patterns, err := pathsToPatterns(gsfs, tc.paths)
require.NoError(t, err)
assert.Equal(t, tc.want, patterns)
})
}
t.Run("bad_file", func(t *testing.T) {
const errStat errors.Error = "bad file"
badFS := &fakefs.StatFS{
OnOpen: func(_ string) (f fs.File, err error) { panic("not implemented") },
OnStat: func(name string) (fi fs.FileInfo, err error) {
return nil, errStat
},
}
_, err := pathsToPatterns(badFS, []string{""})
assert.ErrorIs(t, err, errStat)
})
}
func TestUniqueRules_ParseLine(t *testing.T) {
ip := netutil.IPv4Localhost()
ipStr := ip.String()
testCases := []struct {
name string
line string
wantIP netip.Addr
wantHosts []string
}{{
name: "simple",
line: ipStr + ` hostname`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "aliases",
line: ipStr + ` hostname alias`,
wantIP: ip,
wantHosts: []string{"hostname", "alias"},
}, {
name: "invalid_line",
line: ipStr,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "invalid_line_hostname",
line: ipStr + ` # hostname`,
wantIP: ip,
wantHosts: nil,
}, {
name: "commented_aliases",
line: ipStr + ` hostname # alias`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "whole_comment",
line: `# ` + ipStr + ` hostname`,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "partial_comment",
line: ipStr + ` host#name`,
wantIP: ip,
wantHosts: []string{"host"},
}, {
name: "empty",
line: ``,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "bad_hosts",
line: ipStr + ` bad..host bad._tld empty.tld. ok.host`,
wantIP: ip,
wantHosts: []string{"ok.host"},
}}
for _, tc := range testCases {
hp := hostsParser{}
t.Run(tc.name, func(t *testing.T) {
got, hosts := hp.parseLine(tc.line)
assert.Equal(t, tc.wantIP, got)
assert.Equal(t, tc.wantHosts, hosts)
})
}
}

View File

@@ -1,9 +1,7 @@
package aghnet
package aghnet_test
import (
"io/fs"
"net"
"net/netip"
"path"
"strings"
"sync/atomic"
@@ -12,6 +10,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
@@ -24,10 +23,7 @@ import (
"github.com/stretchr/testify/require"
)
const (
nl = "\n"
sp = " "
)
const nl = "\n"
func TestNewHostsContainer(t *testing.T) {
const dirname = "dir"
@@ -48,11 +44,11 @@ func TestNewHostsContainer(t *testing.T) {
name: "one_file",
paths: []string{p},
}, {
wantErr: ErrNoHostsPaths,
wantErr: aghnet.ErrNoHostsPaths,
name: "no_files",
paths: []string{},
}, {
wantErr: ErrNoHostsPaths,
wantErr: aghnet.ErrNoHostsPaths,
name: "non-existent_file",
paths: []string{path.Join(dirname, filename+"2")},
}, {
@@ -77,7 +73,7 @@ func TestNewHostsContainer(t *testing.T) {
return eventsCh
}
hc, err := NewHostsContainer(0, testFS, &aghtest.FSWatcher{
hc, err := aghnet.NewHostsContainer(0, testFS, &aghtest.FSWatcher{
OnEvents: onEvents,
OnAdd: onAdd,
OnClose: func() (err error) { return nil },
@@ -103,7 +99,7 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("nil_fs", func(t *testing.T) {
require.Panics(t, func() {
_, _ = NewHostsContainer(0, nil, &aghtest.FSWatcher{
_, _ = aghnet.NewHostsContainer(0, nil, &aghtest.FSWatcher{
// Those shouldn't panic.
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(name string) (err error) { return nil },
@@ -114,7 +110,7 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("nil_watcher", func(t *testing.T) {
require.Panics(t, func() {
_, _ = NewHostsContainer(0, testFS, nil, p)
_, _ = aghnet.NewHostsContainer(0, testFS, nil, p)
})
})
@@ -127,7 +123,7 @@ func TestNewHostsContainer(t *testing.T) {
OnClose: func() (err error) { return nil },
}
hc, err := NewHostsContainer(0, testFS, errWatcher, p)
hc, err := aghnet.NewHostsContainer(0, testFS, errWatcher, p)
require.ErrorIs(t, err, errOnAdd)
assert.Nil(t, hc)
@@ -158,11 +154,11 @@ func TestHostsContainer_refresh(t *testing.T) {
OnClose: func() (err error) { return nil },
}
hc, err := NewHostsContainer(0, testFS, w, "dir")
hc, err := aghnet.NewHostsContainer(0, testFS, w, "dir")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
checkRefresh := func(t *testing.T, want *HostsRecord) {
checkRefresh := func(t *testing.T, want *aghnet.HostsRecord) {
t.Helper()
upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
@@ -175,11 +171,11 @@ func TestHostsContainer_refresh(t *testing.T) {
require.True(t, ok)
require.NotNil(t, rec)
assert.Truef(t, rec.equal(want), "%+v != %+v", rec, want)
assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want)
}
t.Run("initial_refresh", func(t *testing.T) {
checkRefresh(t, &HostsRecord{
checkRefresh(t, &aghnet.HostsRecord{
Aliases: stringutil.NewSet(),
Canonical: "hostname",
})
@@ -189,7 +185,7 @@ func TestHostsContainer_refresh(t *testing.T) {
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
eventsCh <- event{}
checkRefresh(t, &HostsRecord{
checkRefresh(t, &aghnet.HostsRecord{
Aliases: stringutil.NewSet("alias"),
Canonical: "hostname",
})
@@ -228,66 +224,6 @@ func TestHostsContainer_refresh(t *testing.T) {
})
}
func TestHostsContainer_PathsToPatterns(t *testing.T) {
gsfs := fstest.MapFS{
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
}
testCases := []struct {
name string
paths []string
want []string
}{{
name: "no_paths",
paths: nil,
want: nil,
}, {
name: "single_file",
paths: []string{"dir_0/file_1"},
want: []string{"dir_0/file_1"},
}, {
name: "several_files",
paths: []string{"dir_0/file_1", "dir_0/file_2"},
want: []string{"dir_0/file_1", "dir_0/file_2"},
}, {
name: "whole_dir",
paths: []string{"dir_0"},
want: []string{"dir_0/*"},
}, {
name: "file_and_dir",
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
}, {
name: "non-existing",
paths: []string{path.Join("dir_0", "file_3")},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
patterns, err := pathsToPatterns(gsfs, tc.paths)
require.NoError(t, err)
assert.Equal(t, tc.want, patterns)
})
}
t.Run("bad_file", func(t *testing.T) {
const errStat errors.Error = "bad file"
badFS := &aghtest.StatFS{
OnStat: func(name string) (fs.FileInfo, error) {
return nil, errStat
},
}
_, err := pathsToPatterns(badFS, []string{""})
assert.ErrorIs(t, err, errStat)
})
}
func TestHostsContainer_Translate(t *testing.T) {
stubWatcher := aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) { return nil },
@@ -297,7 +233,7 @@ func TestHostsContainer_Translate(t *testing.T) {
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
hc, err := aghnet.NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
@@ -527,7 +463,7 @@ func TestHostsContainer(t *testing.T) {
OnClose: func() (err error) { return nil },
}
hc, err := NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts")
hc, err := aghnet.NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
@@ -558,69 +494,3 @@ func TestHostsContainer(t *testing.T) {
})
}
}
func TestUniqueRules_ParseLine(t *testing.T) {
ip := netutil.IPv4Localhost()
ipStr := ip.String()
testCases := []struct {
name string
line string
wantIP netip.Addr
wantHosts []string
}{{
name: "simple",
line: ipStr + ` hostname`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "aliases",
line: ipStr + ` hostname alias`,
wantIP: ip,
wantHosts: []string{"hostname", "alias"},
}, {
name: "invalid_line",
line: ipStr,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "invalid_line_hostname",
line: ipStr + ` # hostname`,
wantIP: ip,
wantHosts: nil,
}, {
name: "commented_aliases",
line: ipStr + ` hostname # alias`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "whole_comment",
line: `# ` + ipStr + ` hostname`,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "partial_comment",
line: ipStr + ` host#name`,
wantIP: ip,
wantHosts: []string{"host"},
}, {
name: "empty",
line: ``,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "bad_hosts",
line: ipStr + ` bad..host bad._tld empty.tld. ok.host`,
wantIP: ip,
wantHosts: []string{"ok.host"},
}}
for _, tc := range testCases {
hp := hostsParser{}
t.Run(tc.name, func(t *testing.T) {
got, hosts := hp.parseLine(tc.line)
assert.Equal(t, tc.wantIP, got)
assert.Equal(t, tc.wantHosts, hosts)
})
}
}

View File

@@ -3,6 +3,7 @@ package aghnet
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -15,6 +16,10 @@ import (
"github.com/AdguardTeam/golibs/log"
)
// DialContextFunc is the semantic alias for dialing functions, such as
// [http.Transport.DialContext].
type DialContextFunc = func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// Variables and functions to substitute in tests.
var (
// aghosRunCommand is the function to run shell commands.

View File

@@ -5,9 +5,9 @@ import (
"testing"
"testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakefs"
"github.com/stretchr/testify/assert"
)
@@ -118,7 +118,7 @@ func TestIfaceSetStaticIP(t *testing.T) {
Data: []byte(`nameserver 1.1.1.1`),
},
}
panicFsys := &aghtest.FS{
panicFsys := &fakefs.FS{
OnOpen: func(name string) (fs.File, error) { panic("not implemented") },
}

View File

@@ -0,0 +1,334 @@
package aghnet
import (
"bytes"
"encoding/json"
"fmt"
"io/fs"
"net"
"net/netip"
"os"
"strings"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata")
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
// package with fsys for tests ran under t.
func substRootDirFS(t testing.TB, fsys fs.FS) {
t.Helper()
prev := rootDirFS
t.Cleanup(func() { rootDirFS = prev })
rootDirFS = fsys
}
// RunCmdFunc is the signature of aghos.RunCommand function.
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
// substShell replaces the the aghos.RunCommand function used throughout the
// package with rc for tests ran under t.
func substShell(t testing.TB, rc RunCmdFunc) {
t.Helper()
prev := aghosRunCommand
t.Cleanup(func() { aghosRunCommand = prev })
aghosRunCommand = rc
}
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
// execution result. It's only needed to simplify testing.
//
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
type mapShell map[string]struct {
err error
out string
code int
}
// theOnlyCmd returns mapShell that only handles a single command and arguments
// combination from cmd.
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
return mapShell{cmd: {code: code, out: out, err: err}}
}
// RunCmd is a RunCmdFunc handled by s.
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
key := strings.Join(append([]string{cmd}, args...), " ")
ret, ok := s[key]
if !ok {
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
}
return ret.code, []byte(ret.out), ret.err
}
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
// throughout the package with f for tests ran under t.
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
t.Helper()
prev := netInterfaceAddrs
t.Cleanup(func() { netInterfaceAddrs = prev })
netInterfaceAddrs = f
}
func TestGatewayIP(t *testing.T) {
const ifaceName = "ifaceName"
const cmd = "ip route show dev " + ifaceName
testCases := []struct {
shell mapShell
want netip.Addr
name string
}{{
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: netip.MustParseAddr("1.2.3.4"),
name: "success_v4",
}, {
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: netip.MustParseAddr("::ffff"),
name: "success_v6",
}, {
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: netip.Addr{},
name: "bad_output",
}, {
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: netip.Addr{},
name: "err_runcmd",
}, {
shell: theOnlyCmd(cmd, 1, "", nil),
want: netip.Addr{},
name: "bad_code",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
assert.Equal(t, tc.want, GatewayIP(ifaceName))
})
}
}
func TestInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb()
require.NoError(t, err)
require.NotEmpty(t, ifaces)
for _, iface := range ifaces {
t.Run(iface.Name, func(t *testing.T) {
require.NotEmpty(t, iface.Addresses)
for _, ip := range iface.Addresses {
ifaceName := InterfaceByIP(ip)
require.Equal(t, iface.Name, ifaceName)
}
})
}
}
func TestBroadcastFromIPNet(t *testing.T) {
known4 := netip.MustParseAddr("192.168.0.1")
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
testCases := []struct {
pref netip.Prefix
want netip.Addr
name string
}{{
pref: netip.PrefixFrom(known4, 0),
want: fullBroadcast4,
name: "full",
}, {
pref: netip.PrefixFrom(known4, 20),
want: netip.MustParseAddr("192.168.15.255"),
name: "full",
}, {
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
want: known6,
name: "ipv6_no_mask",
}, {
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
want: known4,
name: "ipv4_no_mask",
}, {
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
want: fullBroadcast4,
name: "unspecified",
}, {
pref: netip.Prefix{},
want: netip.Addr{},
name: "invalid",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
})
}
}
func TestCheckPort(t *testing.T) {
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("tcp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("udp_bound", func(t *testing.T) {
conn, err := net.ListenPacket("udp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, conn.Close)
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("udp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("bad_network", func(t *testing.T) {
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
assert.NoError(t, err)
})
t.Run("can_bind", func(t *testing.T) {
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
assert.NoError(t, err)
})
}
func TestCollectAllIfacesAddrs(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
}{{
name: "success",
wantErrMsg: ``,
addrs: []net.Addr{&net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
}, &net.IPNet{
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
wantAddrs: nil,
}, {
name: "empty",
wantErrMsg: ``,
addrs: []net.Addr{},
wantAddrs: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
addrs, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantAddrs, addrs)
})
}
t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
}
func TestIsAddrInUse(t *testing.T) {
t.Run("addr_in_use", func(t *testing.T) {
l, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
assert.True(t, IsAddrInUse(err))
})
t.Run("another", func(t *testing.T) {
const anotherErr errors.Error = "not addr in use"
assert.False(t, IsAddrInUse(anotherErr))
})
}
func TestNetInterface_MarshalJSON(t *testing.T) {
const want = `{` +
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
`"flags":"up|multicast",` +
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
`"name":"iface0",` +
`"mtu":1500` +
`}` + "\n"
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
require.True(t, ok)
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
require.True(t, ok)
net4 := netip.PrefixFrom(ip4, 24)
net6 := netip.PrefixFrom(ip6, 8)
iface := &NetInterface{
Addresses: []netip.Addr{ip4, ip6},
Subnets: []netip.Prefix{net4, net6},
Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast,
MTU: 1500,
}
b := &bytes.Buffer{}
err := json.NewEncoder(b).Encode(iface)
require.NoError(t, err)
assert.Equal(t, want, b.String())
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/renameio/maybe"
"github.com/google/renameio/v2/maybe"
"golang.org/x/sys/unix"
)

View File

@@ -1,21 +1,11 @@
package aghnet
package aghnet_test
import (
"bytes"
"encoding/json"
"fmt"
"io/fs"
"net"
"net/netip"
"os"
"strings"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
@@ -24,315 +14,3 @@ func TestMain(m *testing.M) {
// testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata")
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
// package with fsys for tests ran under t.
func substRootDirFS(t testing.TB, fsys fs.FS) {
t.Helper()
prev := rootDirFS
t.Cleanup(func() { rootDirFS = prev })
rootDirFS = fsys
}
// RunCmdFunc is the signature of aghos.RunCommand function.
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
// substShell replaces the the aghos.RunCommand function used throughout the
// package with rc for tests ran under t.
func substShell(t testing.TB, rc RunCmdFunc) {
t.Helper()
prev := aghosRunCommand
t.Cleanup(func() { aghosRunCommand = prev })
aghosRunCommand = rc
}
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
// execution result. It's only needed to simplify testing.
//
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
type mapShell map[string]struct {
err error
out string
code int
}
// theOnlyCmd returns mapShell that only handles a single command and arguments
// combination from cmd.
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
return mapShell{cmd: {code: code, out: out, err: err}}
}
// RunCmd is a RunCmdFunc handled by s.
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
key := strings.Join(append([]string{cmd}, args...), " ")
ret, ok := s[key]
if !ok {
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
}
return ret.code, []byte(ret.out), ret.err
}
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
// throughout the package with f for tests ran under t.
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
t.Helper()
prev := netInterfaceAddrs
t.Cleanup(func() { netInterfaceAddrs = prev })
netInterfaceAddrs = f
}
func TestGatewayIP(t *testing.T) {
const ifaceName = "ifaceName"
const cmd = "ip route show dev " + ifaceName
testCases := []struct {
shell mapShell
want netip.Addr
name string
}{{
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: netip.MustParseAddr("1.2.3.4"),
name: "success_v4",
}, {
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: netip.MustParseAddr("::ffff"),
name: "success_v6",
}, {
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: netip.Addr{},
name: "bad_output",
}, {
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: netip.Addr{},
name: "err_runcmd",
}, {
shell: theOnlyCmd(cmd, 1, "", nil),
want: netip.Addr{},
name: "bad_code",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
assert.Equal(t, tc.want, GatewayIP(ifaceName))
})
}
}
func TestInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb()
require.NoError(t, err)
require.NotEmpty(t, ifaces)
for _, iface := range ifaces {
t.Run(iface.Name, func(t *testing.T) {
require.NotEmpty(t, iface.Addresses)
for _, ip := range iface.Addresses {
ifaceName := InterfaceByIP(ip)
require.Equal(t, iface.Name, ifaceName)
}
})
}
}
func TestBroadcastFromIPNet(t *testing.T) {
known4 := netip.MustParseAddr("192.168.0.1")
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
testCases := []struct {
pref netip.Prefix
want netip.Addr
name string
}{{
pref: netip.PrefixFrom(known4, 0),
want: fullBroadcast4,
name: "full",
}, {
pref: netip.PrefixFrom(known4, 20),
want: netip.MustParseAddr("192.168.15.255"),
name: "full",
}, {
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
want: known6,
name: "ipv6_no_mask",
}, {
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
want: known4,
name: "ipv4_no_mask",
}, {
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
want: fullBroadcast4,
name: "unspecified",
}, {
pref: netip.Prefix{},
want: netip.Addr{},
name: "invalid",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
})
}
}
func TestCheckPort(t *testing.T) {
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("tcp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("udp_bound", func(t *testing.T) {
conn, err := net.ListenPacket("udp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, conn.Close)
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("udp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("bad_network", func(t *testing.T) {
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
assert.NoError(t, err)
})
t.Run("can_bind", func(t *testing.T) {
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
assert.NoError(t, err)
})
}
func TestCollectAllIfacesAddrs(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
}{{
name: "success",
wantErrMsg: ``,
addrs: []net.Addr{&net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
}, &net.IPNet{
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
wantAddrs: nil,
}, {
name: "empty",
wantErrMsg: ``,
addrs: []net.Addr{},
wantAddrs: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
addrs, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantAddrs, addrs)
})
}
t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
}
func TestIsAddrInUse(t *testing.T) {
t.Run("addr_in_use", func(t *testing.T) {
l, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
assert.True(t, IsAddrInUse(err))
})
t.Run("another", func(t *testing.T) {
const anotherErr errors.Error = "not addr in use"
assert.False(t, IsAddrInUse(anotherErr))
})
}
func TestNetInterface_MarshalJSON(t *testing.T) {
const want = `{` +
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
`"flags":"up|multicast",` +
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
`"name":"iface0",` +
`"mtu":1500` +
`}` + "\n"
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
require.True(t, ok)
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
require.True(t, ok)
net4 := netip.PrefixFrom(ip4, 24)
net6 := netip.PrefixFrom(ip6, 8)
iface := &NetInterface{
Addresses: []netip.Addr{ip4, ip6},
Subnets: []netip.Prefix{net4, net6},
Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast,
MTU: 1500,
}
b := &bytes.Buffer{}
err := json.NewEncoder(b).Encode(iface)
require.NoError(t, err)
assert.Equal(t, want, b.String())
}

View File

@@ -0,0 +1,52 @@
// Package aghrenameio is a wrapper around package github.com/google/renameio/v2
// that provides a similar stream-based API for both Unix and Windows systems.
// While the Windows API is not technically atomic, it still provides a
// consistent stream-based interface, and atomic renames of files do not seem to
// be possible in all cases anyway.
//
// See https://github.com/google/renameio/issues/1.
//
// TODO(a.garipov): Consider moving to golibs/renameioutil once tried and
// tested.
package aghrenameio
import (
"io/fs"
"github.com/AdguardTeam/golibs/errors"
)
// PendingFile is the interface for pending temporary files.
type PendingFile interface {
// Cleanup closes the file, and removes it without performing the renaming.
// To close and rename the file, use CloseReplace.
Cleanup() (err error)
// CloseReplace closes the temporary file and replaces the destination file
// with it, possibly atomically.
//
// This method is not safe for concurrent use by multiple goroutines.
CloseReplace() (err error)
// Write writes len(b) bytes from b to the File. It returns the number of
// bytes written and an error, if any. Write returns a non-nil error when n
// != len(b).
Write(b []byte) (n int, err error)
}
// NewPendingFile is a wrapper around [renameio.NewPendingFile] on Unix systems
// and [os.CreateTemp] on Windows.
func NewPendingFile(filePath string, mode fs.FileMode) (f PendingFile, err error) {
return newPendingFile(filePath, mode)
}
// WithDeferredCleanup is a helper that performs the necessary cleanups and
// finalizations of the temporary files based on the returned error.
func WithDeferredCleanup(returned error, file PendingFile) (err error) {
// Make sure that any error returned from here is marked as a deferred one.
if returned != nil {
return errors.WithDeferred(returned, file.Cleanup())
}
return errors.WithDeferred(nil, file.CloseReplace())
}

View File

@@ -0,0 +1,101 @@
package aghrenameio_test
import (
"io/fs"
"os"
"path/filepath"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghrenameio"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testPerm is the common permission mode for tests.
const testPerm fs.FileMode = 0o644
// Common file data for tests.
var (
initialData = []byte("initial data\n")
newData = []byte("new data\n")
)
func TestPendingFile(t *testing.T) {
t.Parallel()
targetPath := newInitialFile(t)
f, err := aghrenameio.NewPendingFile(targetPath, testPerm)
require.NoError(t, err)
_, err = f.Write(newData)
require.NoError(t, err)
err = f.CloseReplace()
require.NoError(t, err)
gotData, err := os.ReadFile(targetPath)
require.NoError(t, err)
assert.Equal(t, newData, gotData)
}
// newInitialFile is a test helper that returns the path to the file containing
// [initialData].
func newInitialFile(t *testing.T) (targetPath string) {
t.Helper()
dir := t.TempDir()
targetPath = filepath.Join(dir, "target")
err := os.WriteFile(targetPath, initialData, 0o644)
require.NoError(t, err)
return targetPath
}
func TestWithDeferredCleanup(t *testing.T) {
t.Parallel()
const testError errors.Error = "test error"
testCases := []struct {
error error
name string
wantErrMsg string
wantData []byte
}{{
name: "success",
error: nil,
wantErrMsg: "",
wantData: newData,
}, {
name: "error",
error: testError,
wantErrMsg: testError.Error(),
wantData: initialData,
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
targetPath := newInitialFile(t)
f, err := aghrenameio.NewPendingFile(targetPath, testPerm)
require.NoError(t, err)
_, err = f.Write(newData)
require.NoError(t, err)
err = aghrenameio.WithDeferredCleanup(tc.error, f)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
gotData, err := os.ReadFile(targetPath)
require.NoError(t, err)
assert.Equal(t, tc.wantData, gotData)
})
}
}

View File

@@ -0,0 +1,48 @@
//go:build unix
package aghrenameio
import (
"io/fs"
"github.com/google/renameio/v2"
)
// pendingFile is a wrapper around [*renameio.PendingFile] making it an
// [io.WriteCloser].
type pendingFile struct {
file *renameio.PendingFile
}
// type check
var _ PendingFile = pendingFile{}
// Cleanup implements the [PendingFile] interface for pendingFile.
func (f pendingFile) Cleanup() (err error) {
return f.file.Cleanup()
}
// CloseReplace implements the [PendingFile] interface for pendingFile.
func (f pendingFile) CloseReplace() (err error) {
return f.file.CloseAtomicallyReplace()
}
// Write implements the [PendingFile] interface for pendingFile.
func (f pendingFile) Write(b []byte) (n int, err error) {
return f.file.Write(b)
}
// NewPendingFile is a wrapper around [renameio.NewPendingFile].
//
// f.Close must be called to finish the renaming.
func newPendingFile(filePath string, mode fs.FileMode) (f PendingFile, err error) {
file, err := renameio.NewPendingFile(filePath, renameio.WithPermissions(mode))
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
return pendingFile{
file: file,
}, nil
}

View File

@@ -0,0 +1,74 @@
//go:build windows
package aghrenameio
import (
"fmt"
"io/fs"
"os"
"path/filepath"
"github.com/AdguardTeam/golibs/errors"
)
// pendingFile is a wrapper around [*os.File] calling [os.Rename] in its Close
// method.
type pendingFile struct {
file *os.File
targetPath string
}
// type check
var _ PendingFile = (*pendingFile)(nil)
// Cleanup implements the [PendingFile] interface for *pendingFile.
func (f *pendingFile) Cleanup() (err error) {
closeErr := f.file.Close()
err = os.Remove(f.file.Name())
// Put closeErr into the deferred error because that's where it is usually
// expected.
return errors.WithDeferred(err, closeErr)
}
// CloseReplace implements the [PendingFile] interface for *pendingFile.
func (f *pendingFile) CloseReplace() (err error) {
err = f.file.Close()
if err != nil {
return fmt.Errorf("closing: %w", err)
}
err = os.Rename(f.file.Name(), f.targetPath)
if err != nil {
return fmt.Errorf("renaming: %w", err)
}
return nil
}
// Write implements the [PendingFile] interface for *pendingFile.
func (f *pendingFile) Write(b []byte) (n int, err error) {
return f.file.Write(b)
}
// NewPendingFile is a wrapper around [os.CreateTemp].
//
// f.Close must be called to finish the renaming.
func newPendingFile(filePath string, mode fs.FileMode) (f PendingFile, err error) {
// Use the same directory as the file itself, because moves across
// filesystems can be especially problematic.
file, err := os.CreateTemp(filepath.Dir(filePath), "")
if err != nil {
return nil, fmt.Errorf("opening pending file: %w", err)
}
err = file.Chmod(mode)
if err != nil {
return nil, fmt.Errorf("preparing pending file: %w", err)
}
return &pendingFile{
file: file,
targetPath: filePath,
}, nil
}

View File

@@ -2,7 +2,9 @@
package aghtest
import (
"crypto/sha256"
"io"
"net"
"testing"
"github.com/AdguardTeam/golibs/log"
@@ -34,3 +36,10 @@ func ReplaceLogLevel(t testing.TB, l log.Level) {
t.Cleanup(func() { log.SetLevel(prev) })
log.SetLevel(l)
}
// HostToIPs is a helper that generates one IPv4 and one IPv6 address from host.
func HostToIPs(host string) (ipv4, ipv6 net.IP) {
hash := sha256.Sum256([]byte(host))
return net.IP(hash[:4]), net.IP(hash[4:20])
}

View File

@@ -2,11 +2,15 @@ package aghtest
import (
"context"
"io"
"io/fs"
"net"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
@@ -15,67 +19,6 @@ import (
//
// Keep entities in this file in alphabetic order.
// Standard Library
// Package fs
// FS is a fake [fs.FS] implementation for tests.
type FS struct {
OnOpen func(name string) (fs.File, error)
}
// type check
var _ fs.FS = (*FS)(nil)
// Open implements the [fs.FS] interface for *FS.
func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}
// type check
var _ fs.GlobFS = (*GlobFS)(nil)
// GlobFS is a fake [fs.GlobFS] implementation for tests.
type GlobFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnGlob func(pattern string) ([]string, error)
}
// Glob implements the [fs.GlobFS] interface for *GlobFS.
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}
// type check
var _ fs.StatFS = (*StatFS)(nil)
// StatFS is a fake [fs.StatFS] implementation for tests.
type StatFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnStat func(name string) (fs.FileInfo, error)
}
// Stat implements the [fs.StatFS] interface for *StatFS.
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}
// Package io
// Writer is a fake [io.Writer] implementation for tests.
type Writer struct {
OnWrite func(b []byte) (n int, err error)
}
var _ io.Writer = (*Writer)(nil)
// Write implements the [io.Writer] interface for *Writer.
func (w *Writer) Write(b []byte) (n int, err error) {
return w.OnWrite(b)
}
// Module adguard-home
// Package aghos
@@ -135,6 +78,71 @@ func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) {
return s.OnConfig()
}
// Package client
// AddressProcessor is a fake [client.AddressProcessor] implementation for
// tests.
type AddressProcessor struct {
OnProcess func(ip netip.Addr)
OnClose func() (err error)
}
// type check
var _ client.AddressProcessor = (*AddressProcessor)(nil)
// Process implements the [client.AddressProcessor] interface for
// *AddressProcessor.
func (p *AddressProcessor) Process(ip netip.Addr) {
p.OnProcess(ip)
}
// Close implements the [client.AddressProcessor] interface for
// *AddressProcessor.
func (p *AddressProcessor) Close() (err error) {
return p.OnClose()
}
// AddressUpdater is a fake [client.AddressUpdater] implementation for tests.
type AddressUpdater struct {
OnUpdateAddress func(ip netip.Addr, host string, info *whois.Info)
}
// type check
var _ client.AddressUpdater = (*AddressUpdater)(nil)
// UpdateAddress implements the [client.AddressUpdater] interface for
// *AddressUpdater.
func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
p.OnUpdateAddress(ip, host, info)
}
// Package filtering
// Resolver is a fake [filtering.Resolver] implementation for tests.
type Resolver struct {
OnLookupIP func(ctx context.Context, network, host string) (ips []net.IP, err error)
}
// LookupIP implements the [filtering.Resolver] interface for *Resolver.
func (r *Resolver) LookupIP(ctx context.Context, network, host string) (ips []net.IP, err error) {
return r.OnLookupIP(ctx, network, host)
}
// Package rdns
// Exchanger is a fake [rdns.Exchanger] implementation for tests.
type Exchanger struct {
OnExchange func(ip netip.Addr) (host string, ttl time.Duration, err error)
}
// type check
var _ rdns.Exchanger = (*Exchanger)(nil)
// Exchange implements [rdns.Exchanger] interface for *Exchanger.
func (e *Exchanger) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) {
return e.OnExchange(ip)
}
// Module dnsproxy
// Package upstream

View File

@@ -1,3 +1,11 @@
package aghtest_test
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
)
// Put interface checks that cause import cycles here.
// type check
var _ filtering.Resolver = (*aghtest.Resolver)(nil)

View File

@@ -1,57 +0,0 @@
package aghtest
import (
"context"
"crypto/sha256"
"net"
"sync"
)
// TestResolver is a Resolver for tests.
type TestResolver struct {
counter int
counterLock sync.Mutex
}
// HostToIPs generates IPv4 and IPv6 from host.
func (r *TestResolver) HostToIPs(host string) (ipv4, ipv6 net.IP) {
hash := sha256.Sum256([]byte(host))
return net.IP(hash[:4]), net.IP(hash[4:20])
}
// LookupIP implements Resolver interface for *testResolver. It returns the
// slice of net.IP with IPv4 and IPv6 instances.
func (r *TestResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) {
ipv4, ipv6 := r.HostToIPs(host)
addrs := []net.IP{ipv4, ipv6}
r.counterLock.Lock()
defer r.counterLock.Unlock()
r.counter++
return addrs, nil
}
// LookupHost implements Resolver interface for *testResolver. It returns the
// slice of IPv4 and IPv6 instances converted to strings.
func (r *TestResolver) LookupHost(host string) (addrs []string, err error) {
ipv4, ipv6 := r.HostToIPs(host)
r.counterLock.Lock()
defer r.counterLock.Unlock()
r.counter++
return []string{
ipv4.String(),
ipv6.String(),
}, nil
}
// Counter returns the number of requests handled.
func (r *TestResolver) Counter() int {
r.counterLock.Lock()
defer r.counterLock.Unlock()
return r.counter
}

294
internal/client/addrproc.go Normal file
View File

@@ -0,0 +1,294 @@
package client
import (
"context"
"net/netip"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
)
// ErrClosed is returned from [AddressProcessor.Close] if it's closed more than
// once.
const ErrClosed errors.Error = "use of closed address processor"
// AddressProcessor is the interface for types that can process clients.
type AddressProcessor interface {
Process(ip netip.Addr)
Close() (err error)
}
// EmptyAddrProc is an [AddressProcessor] that does nothing.
type EmptyAddrProc struct{}
// type check
var _ AddressProcessor = EmptyAddrProc{}
// Process implements the [AddressProcessor] interface for EmptyAddrProc.
func (EmptyAddrProc) Process(_ netip.Addr) {}
// Close implements the [AddressProcessor] interface for EmptyAddrProc.
func (EmptyAddrProc) Close() (_ error) { return nil }
// DefaultAddrProcConfig is the configuration structure for address processors.
type DefaultAddrProcConfig struct {
// DialContext is used to create TCP connections to WHOIS servers.
// DialContext must not be nil if [DefaultAddrProcConfig.UseWHOIS] is true.
DialContext aghnet.DialContextFunc
// Exchanger is used to perform rDNS queries. Exchanger must not be nil if
// [DefaultAddrProcConfig.UseRDNS] is true.
Exchanger rdns.Exchanger
// PrivateSubnets are used to determine if an incoming IP address is
// private. It must not be nil.
PrivateSubnets netutil.SubnetSet
// AddressUpdater is used to update the information about a client's IP
// address. It must not be nil.
AddressUpdater AddressUpdater
// InitialAddresses are the addresses that are queued for processing
// immediately by [NewDefaultAddrProc].
InitialAddresses []netip.Addr
// UseRDNS, if true, enables resolving of client IP addresses using reverse
// DNS.
UseRDNS bool
// UsePrivateRDNS, if true, enables resolving of private client IP addresses
// using reverse DNS. See [DefaultAddrProcConfig.PrivateSubnets].
UsePrivateRDNS bool
// UseWHOIS, if true, enables resolving of client IP addresses using WHOIS.
UseWHOIS bool
}
// AddressUpdater is the interface for storages of DNS clients that can update
// information about them.
//
// TODO(a.garipov): Consider using the actual client storage once it is moved
// into this package.
type AddressUpdater interface {
// UpdateAddress updates information about an IP address, setting host (if
// not empty) and WHOIS information (if not nil).
UpdateAddress(ip netip.Addr, host string, info *whois.Info)
}
// DefaultAddrProc processes incoming client addresses with rDNS and WHOIS, if
// configured, and updates that information in a client storage.
type DefaultAddrProc struct {
// clientIPsMu serializes closure of clientIPs and access to isClosed.
clientIPsMu *sync.Mutex
// clientIPs is the channel queueing client processing tasks.
clientIPs chan netip.Addr
// rdns is used to perform rDNS lookups of clients' IP addresses.
rdns rdns.Interface
// whois is used to perform WHOIS lookups of clients' IP addresses.
whois whois.Interface
// addrUpdater is used to update the information about a client's IP
// address.
addrUpdater AddressUpdater
// privateSubnets are used to determine if an incoming IP address is
// private.
privateSubnets netutil.SubnetSet
// isClosed is set to true once the address processor is closed.
isClosed bool
// usePrivateRDNS, if true, enables resolving of private client IP addresses
// using reverse DNS.
usePrivateRDNS bool
}
const (
// defaultQueueSize is the size of queue of IPs for rDNS and WHOIS
// processing.
defaultQueueSize = 255
// defaultCacheSize is the maximum size of the cache for rDNS and WHOIS
// processing. It must be greater than zero.
defaultCacheSize = 10_000
// defaultIPTTL is the Time to Live duration for IP addresses cached by
// rDNS and WHOIS.
defaultIPTTL = 1 * time.Hour
)
// NewDefaultAddrProc returns a new running client address processor. c must
// not be nil.
func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
p = &DefaultAddrProc{
clientIPsMu: &sync.Mutex{},
clientIPs: make(chan netip.Addr, defaultQueueSize),
rdns: &rdns.Empty{},
addrUpdater: c.AddressUpdater,
whois: &whois.Empty{},
privateSubnets: c.PrivateSubnets,
usePrivateRDNS: c.UsePrivateRDNS,
}
if c.UseRDNS {
p.rdns = rdns.New(&rdns.Config{
Exchanger: c.Exchanger,
CacheSize: defaultCacheSize,
CacheTTL: defaultIPTTL,
})
}
if c.UseWHOIS {
p.whois = newWHOIS(c.DialContext)
}
go p.process()
for _, ip := range c.InitialAddresses {
p.Process(ip)
}
return p
}
// newWHOIS returns a whois.Interface instance using the given function for
// dialing.
func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) {
// TODO(s.chzhen): Consider making configurable.
const (
// defaultTimeout is the timeout for WHOIS requests.
defaultTimeout = 5 * time.Second
// defaultMaxConnReadSize is an upper limit in bytes for reading from a
// net.Conn.
defaultMaxConnReadSize = 64 * 1024
// defaultMaxRedirects is the maximum redirects count.
defaultMaxRedirects = 5
// defaultMaxInfoLen is the maximum length of whois.Info fields.
defaultMaxInfoLen = 250
)
return whois.New(&whois.Config{
DialContext: dialFunc,
ServerAddr: whois.DefaultServer,
Port: whois.DefaultPort,
Timeout: defaultTimeout,
CacheSize: defaultCacheSize,
MaxConnReadSize: defaultMaxConnReadSize,
MaxRedirects: defaultMaxRedirects,
MaxInfoLen: defaultMaxInfoLen,
CacheTTL: defaultIPTTL,
})
}
// type check
var _ AddressProcessor = (*DefaultAddrProc)(nil)
// Process implements the [AddressProcessor] interface for *DefaultAddrProc.
func (p *DefaultAddrProc) Process(ip netip.Addr) {
p.clientIPsMu.Lock()
defer p.clientIPsMu.Unlock()
if p.isClosed {
return
}
select {
case p.clientIPs <- ip:
// Go on.
default:
log.Debug("clients: ip channel is full; len: %d", len(p.clientIPs))
}
}
// process processes the incoming client IP-address information. It is intended
// to be used as a goroutine. Once clientIPs is closed, process exits.
func (p *DefaultAddrProc) process() {
defer log.OnPanic("addrProcessor.process")
log.Info("clients: processing addresses")
for ip := range p.clientIPs {
host := p.processRDNS(ip)
info := p.processWHOIS(ip)
p.addrUpdater.UpdateAddress(ip, host, info)
}
log.Info("clients: finished processing addresses")
}
// processRDNS resolves the clients' IP addresses using reverse DNS. host is
// empty if there were errors or if the information hasn't changed.
func (p *DefaultAddrProc) processRDNS(ip netip.Addr) (host string) {
start := time.Now()
log.Debug("clients: processing %s with rdns", ip)
defer func() {
log.Debug("clients: finished processing %s with rdns in %s", ip, time.Since(start))
}()
ok := p.shouldResolve(ip)
if !ok {
return
}
host, changed := p.rdns.Process(ip)
if !changed {
host = ""
}
return host
}
// shouldResolve returns false if ip is a loopback address, or ip is private and
// resolving of private addresses is disabled.
func (p *DefaultAddrProc) shouldResolve(ip netip.Addr) (ok bool) {
return !ip.IsLoopback() &&
(p.usePrivateRDNS || !p.privateSubnets.Contains(ip.AsSlice()))
}
// processWHOIS looks up the information about clients' IP addresses in the
// WHOIS databases. info is nil if there were errors or if the information
// hasn't changed.
func (p *DefaultAddrProc) processWHOIS(ip netip.Addr) (info *whois.Info) {
start := time.Now()
log.Debug("clients: processing %s with whois", ip)
defer func() {
log.Debug("clients: finished processing %s with whois in %s", ip, time.Since(start))
}()
// TODO(s.chzhen): Move the timeout logic from WHOIS configuration to the
// context.
info, changed := p.whois.Process(context.Background(), ip)
if !changed {
info = nil
}
return info
}
// Close implements the [AddressProcessor] interface for *DefaultAddrProc.
func (p *DefaultAddrProc) Close() (err error) {
p.clientIPsMu.Lock()
defer p.clientIPsMu.Unlock()
if p.isClosed {
return ErrClosed
}
close(p.clientIPs)
p.isClosed = true
return nil
}

View File

@@ -0,0 +1,259 @@
package client_test
import (
"context"
"io"
"net"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakenet"
"github.com/stretchr/testify/assert"
)
func TestEmptyAddrProc(t *testing.T) {
t.Parallel()
p := client.EmptyAddrProc{}
assert.NotPanics(t, func() {
p.Process(testIP)
})
assert.NotPanics(t, func() {
err := p.Close()
assert.NoError(t, err)
})
}
func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
t.Parallel()
privateIP := netip.MustParseAddr("192.168.0.1")
testCases := []struct {
rdnsErr error
ip netip.Addr
name string
host string
usePrivate bool
wantUpd bool
}{{
rdnsErr: nil,
ip: testIP,
name: "success",
host: testHost,
usePrivate: false,
wantUpd: true,
}, {
rdnsErr: nil,
ip: testIP,
name: "no_host",
host: "",
usePrivate: false,
wantUpd: false,
}, {
rdnsErr: nil,
ip: netip.MustParseAddr("127.0.0.1"),
name: "localhost",
host: "",
usePrivate: false,
wantUpd: false,
}, {
rdnsErr: nil,
ip: privateIP,
name: "private_ignored",
host: "",
usePrivate: false,
wantUpd: false,
}, {
rdnsErr: nil,
ip: privateIP,
name: "private_processed",
host: "private.example",
usePrivate: true,
wantUpd: true,
}, {
rdnsErr: errors.Error("rdns error"),
ip: testIP,
name: "rdns_error",
host: "",
usePrivate: false,
wantUpd: false,
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
updIPCh := make(chan netip.Addr, 1)
updHostCh := make(chan string, 1)
updInfoCh := make(chan *whois.Info, 1)
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{
DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) {
panic("not implemented")
},
Exchanger: &aghtest.Exchanger{
OnExchange: func(ip netip.Addr) (host string, ttl time.Duration, err error) {
return tc.host, 0, tc.rdnsErr
},
},
PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
AddressUpdater: &aghtest.AddressUpdater{
OnUpdateAddress: newOnUpdateAddress(tc.wantUpd, updIPCh, updHostCh, updInfoCh),
},
UseRDNS: true,
UsePrivateRDNS: tc.usePrivate,
UseWHOIS: false,
})
testutil.CleanupAndRequireSuccess(t, p.Close)
p.Process(tc.ip)
if !tc.wantUpd {
return
}
gotIP, _ := testutil.RequireReceive(t, updIPCh, testTimeout)
assert.Equal(t, tc.ip, gotIP)
gotHost, _ := testutil.RequireReceive(t, updHostCh, testTimeout)
assert.Equal(t, tc.host, gotHost)
gotInfo, _ := testutil.RequireReceive(t, updInfoCh, testTimeout)
assert.Nil(t, gotInfo)
})
}
}
// newOnUpdateAddress is a test helper that returns a new OnUpdateAddress
// callback using the provided channels if an update is expected and panicking
// otherwise.
func newOnUpdateAddress(
want bool,
ips chan<- netip.Addr,
hosts chan<- string,
infos chan<- *whois.Info,
) (f func(ip netip.Addr, host string, info *whois.Info)) {
return func(ip netip.Addr, host string, info *whois.Info) {
if !want {
panic("got unexpected update")
}
ips <- ip
hosts <- host
infos <- info
}
}
func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
t.Parallel()
testCases := []struct {
wantInfo *whois.Info
exchErr error
name string
wantUpd bool
}{{
wantInfo: &whois.Info{
City: testWHOISCity,
},
exchErr: nil,
name: "success",
wantUpd: true,
}, {
wantInfo: nil,
exchErr: nil,
name: "no_info",
wantUpd: false,
}, {
wantInfo: nil,
exchErr: errors.Error("whois error"),
name: "whois_error",
wantUpd: false,
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
whoisConn := &fakenet.Conn{
OnClose: func() (err error) { return nil },
OnRead: func(b []byte) (n int, err error) {
if tc.wantInfo == nil {
return 0, tc.exchErr
}
data := "city: " + tc.wantInfo.City + "\n"
copy(b, data)
return len(data), io.EOF
},
OnSetDeadline: func(_ time.Time) (err error) { return nil },
OnWrite: func(b []byte) (n int, err error) { return len(b), nil },
}
updIPCh := make(chan netip.Addr, 1)
updHostCh := make(chan string, 1)
updInfoCh := make(chan *whois.Info, 1)
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{
DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) {
return whoisConn, nil
},
Exchanger: &aghtest.Exchanger{
OnExchange: func(_ netip.Addr) (_ string, _ time.Duration, _ error) {
panic("not implemented")
},
},
PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
AddressUpdater: &aghtest.AddressUpdater{
OnUpdateAddress: newOnUpdateAddress(tc.wantUpd, updIPCh, updHostCh, updInfoCh),
},
UseRDNS: false,
UsePrivateRDNS: false,
UseWHOIS: true,
})
testutil.CleanupAndRequireSuccess(t, p.Close)
p.Process(testIP)
if !tc.wantUpd {
return
}
gotIP, _ := testutil.RequireReceive(t, updIPCh, testTimeout)
assert.Equal(t, testIP, gotIP)
gotHost, _ := testutil.RequireReceive(t, updHostCh, testTimeout)
assert.Empty(t, gotHost)
gotInfo, _ := testutil.RequireReceive(t, updInfoCh, testTimeout)
assert.Equal(t, tc.wantInfo, gotInfo)
})
}
}
func TestDefaultAddrProc_Close(t *testing.T) {
t.Parallel()
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{})
err := p.Close()
assert.NoError(t, err)
err = p.Close()
assert.ErrorIs(t, err, client.ErrClosed)
}

View File

@@ -0,0 +1,5 @@
// Package client contains types and logic dealing with AdGuard Home's DNS
// clients.
//
// TODO(a.garipov): Expand.
package client

View File

@@ -0,0 +1,25 @@
package client_test
import (
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/golibs/testutil"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testHost is the common hostname for tests.
const testHost = "client.example"
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second
// testWHOISCity is the common city for tests.
const testWHOISCity = "Brussels"
// testIP is the common IP address for tests.
var testIP = netip.MustParseAddr("1.2.3.4")

View File

@@ -9,7 +9,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/google/renameio/maybe"
"github.com/google/renameio/v2/maybe"
"golang.org/x/exp/slices"
)

View File

@@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
@@ -270,7 +271,13 @@ type ServerConfig struct {
UDPListenAddrs []*net.UDPAddr // UDP listen address
TCPListenAddrs []*net.TCPAddr // TCP listen address
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
OnDNSRequest func(d *proxy.DNSContext)
// AddrProcConf defines the configuration for the client IP processor.
// If nil, [client.EmptyAddrProc] is used.
//
// TODO(a.garipov): The use of [client.EmptyAddrProc] is a crutch for tests.
// Remove that.
AddrProcConf *client.DefaultAddrProcConfig
FilteringConfig
TLSConfig
@@ -298,9 +305,6 @@ type ServerConfig struct {
// DNS64Prefixes is a slice of NAT64 prefixes to be used for DNS64.
DNS64Prefixes []netip.Prefix
// ResolveClients signals if the RDNS should resolve clients' addresses.
ResolveClients bool
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
// locally-served networks should be resolved via private PTR resolvers.
UsePrivateRDNS bool
@@ -340,6 +344,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
UpstreamConfig: srvConf.UpstreamConfig,
BeforeRequestHandler: s.beforeRequestHandler,
RequestHandler: s.handleDNSRequest,
HTTPSServerName: aghhttp.UserAgent(),
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
MaxGoroutines: int(srvConf.MaxGoroutines),
UseDNS64: srvConf.UseDNS64,

View File

@@ -0,0 +1,57 @@
package dnsforward
import (
"context"
"fmt"
"net"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames.
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
log.Debug("dnsforward: dialing %q for network %q", addr, network)
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
dialer := &net.Dialer{
// TODO(a.garipov): Consider making configurable.
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil {
return dialer.DialContext(ctx, network, addr)
}
addrs, err := s.Resolve(host)
if err != nil {
return nil, fmt.Errorf("resolving %q: %w", host, err)
}
log.Debug("dnsforward: resolving %q: %v", host, addrs)
if len(addrs) == 0 {
return nil, fmt.Errorf("no addresses for host %q", host)
}
var dialErrs []error
for _, a := range addrs {
addr = net.JoinHostPort(a.String(), port)
conn, err = dialer.DialContext(ctx, network, addr)
if err != nil {
dialErrs = append(dialErrs, err)
continue
}
return conn, err
}
// TODO(a.garipov): Use errors.Join in Go 1.20.
return nil, errors.List(fmt.Sprintf("dialing %q", addr), dialErrs...)
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
@@ -99,8 +100,17 @@ type Server struct {
// must be a valid domain name plus dots on each side.
localDomainSuffix string
ipset ipsetCtx
privateNets netutil.SubnetSet
ipset ipsetCtx
privateNets netutil.SubnetSet
// addrProc, if not nil, is used to process clients' IP addresses with rDNS,
// WHOIS, etc.
addrProc client.AddressProcessor
// localResolvers is a DNS proxy instance used to resolve PTR records for
// addresses considered private as per the [privateNets].
//
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
localResolvers *proxy.Proxy
sysResolvers aghnet.SystemResolvers
@@ -170,6 +180,9 @@ const (
// NewServer creates a new instance of the dnsforward.Server
// Note: this function must be called only once
//
// TODO(a.garipov): How many constructors and initializers does this thing have?
// Refactor!
func NewServer(p DNSCreateParams) (s *Server, err error) {
var localDomainSuffix string
if p.LocalDomain == "" {
@@ -257,14 +270,25 @@ func (s *Server) WriteDiskConfig(c *FilteringConfig) {
c.UpstreamDNS = stringutil.CloneSlice(sc.UpstreamDNS)
}
// RDNSSettings returns the copy of actual RDNS configuration.
func (s *Server) RDNSSettings() (localPTRResolvers []string, resolveClients, resolvePTR bool) {
// LocalPTRResolvers returns the current local PTR resolver configuration.
func (s *Server) LocalPTRResolvers() (localPTRResolvers []string) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
return stringutil.CloneSlice(s.conf.LocalPTRResolvers),
s.conf.ResolveClients,
s.conf.UsePrivateRDNS
return stringutil.CloneSlice(s.conf.LocalPTRResolvers)
}
// AddrProcConfig returns the current address processing configuration. Only
// fields c.UsePrivateRDNS, c.UseRDNS, and c.UseWHOIS are filled.
func (s *Server) AddrProcConfig() (c *client.DefaultAddrProcConfig) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
return &client.DefaultAddrProcConfig{
UsePrivateRDNS: s.conf.UsePrivateRDNS,
UseRDNS: s.conf.AddrProcConf.UseRDNS,
UseWHOIS: s.conf.AddrProcConf.UseWHOIS,
}
}
// Resolve - get IP addresses by host name from an upstream server.
@@ -292,17 +316,13 @@ const (
var _ rdns.Exchanger = (*Server)(nil)
// Exchange implements the [rdns.Exchanger] interface for *Server.
func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if !s.conf.ResolveClients {
return "", nil
}
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
if err != nil {
return "", fmt.Errorf("reversing ip: %w", err)
return "", 0, fmt.Errorf("reversing ip: %w", err)
}
arpa = dns.Fqdn(arpa)
@@ -318,16 +338,17 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
Qclass: dns.ClassINET,
}},
}
ctx := &proxy.DNSContext{
dctx := &proxy.DNSContext{
Proto: "udp",
Req: req,
StartTime: time.Now(),
}
var resolver *proxy.Proxy
if s.isPrivateIP(ip) {
if s.privateNets.Contains(ip.AsSlice()) {
if !s.conf.UsePrivateRDNS {
return "", nil
return "", 0, nil
}
resolver = s.localResolvers
@@ -336,53 +357,48 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
resolver = s.internalProxy
}
if err = resolver.Resolve(ctx); err != nil {
return "", err
if err = resolver.Resolve(dctx); err != nil {
return "", 0, err
}
return hostFromPTR(ctx.Res)
return hostFromPTR(dctx.Res)
}
// hostFromPTR returns domain name from the PTR response or error.
func hostFromPTR(resp *dns.Msg) (host string, err error) {
func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) {
// Distinguish between NODATA response and a failed request.
if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError {
return "", fmt.Errorf(
return "", 0, fmt.Errorf(
"received %s response: %w",
dns.RcodeToString[resp.Rcode],
ErrRDNSFailed,
)
}
var ttlSec uint32
for _, ans := range resp.Answer {
ptr, ok := ans.(*dns.PTR)
if ok {
return strings.TrimSuffix(ptr.Ptr, "."), nil
if !ok {
continue
}
if ptr.Hdr.Ttl > ttlSec {
host = ptr.Ptr
ttlSec = ptr.Hdr.Ttl
}
}
return "", ErrRDNSNoData
}
if host != "" {
// NOTE: Don't use [aghnet.NormalizeDomain] to retain original letter
// case.
host = strings.TrimSuffix(host, ".")
ttl = time.Duration(ttlSec) * time.Second
// isPrivateIP returns true if the ip is private.
func (s *Server) isPrivateIP(ip netip.Addr) (ok bool) {
return s.privateNets.Contains(ip.AsSlice())
}
// ShouldResolveClient returns false if ip is a loopback address, or ip is
// private and resolving of private addresses is disabled.
func (s *Server) ShouldResolveClient(ip netip.Addr) (ok bool) {
if ip.IsLoopback() {
return false
return host, ttl, nil
}
isPrivate := s.isPrivateIP(ip)
s.serverLock.RLock()
defer s.serverLock.RUnlock()
return s.conf.ResolveClients &&
(s.conf.UsePrivateRDNS || !isPrivate)
return "", 0, ErrRDNSNoData
}
// Start starts the DNS server.
@@ -457,23 +473,27 @@ func (s *Server) filterOurDNSAddrs(addrs []string) (filtered []string, err error
return stringutil.FilterOut(addrs, ourAddrsSet.Has), nil
}
// setupResolvers initializes the resolvers for local addresses. For internal
// use only.
func (s *Server) setupResolvers(localAddrs []string) (err error) {
// setupLocalResolvers initializes the resolvers for local addresses. For
// internal use only.
func (s *Server) setupLocalResolvers() (err error) {
bootstraps := s.conf.BootstrapDNS
if len(localAddrs) == 0 {
localAddrs = s.sysResolvers.Get()
resolvers := s.conf.LocalPTRResolvers
if len(resolvers) == 0 {
resolvers = s.sysResolvers.Get()
bootstraps = nil
} else {
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
}
localAddrs, err = s.filterOurDNSAddrs(localAddrs)
resolvers, err = s.filterOurDNSAddrs(resolvers)
if err != nil {
return err
}
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", localAddrs)
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
upsConfig, err := s.prepareUpstreamConfig(localAddrs, nil, &upstream.Options{
uc, err := s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
Bootstrap: bootstraps,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
@@ -486,10 +506,17 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
s.localResolvers = &proxy.Proxy{
Config: proxy.Config{
UpstreamConfig: upsConfig,
UpstreamConfig: uc,
},
}
if s.conf.UsePrivateRDNS &&
// Only set the upstream config if there are any upstreams. It's safe
// to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
len(uc.Upstreams)+len(uc.DomainReservedUpstreams)+len(uc.SpecifiedDomainUpstreams) > 0 {
s.dnsProxy.PrivateRDNSUpstreamConfig = uc
}
return nil
}
@@ -539,25 +566,48 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return fmt.Errorf("preparing access: %w", err)
}
s.registerHandlers()
// Set the proxy here because [setupLocalResolvers] sets its values.
//
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
err = s.setupResolvers(s.conf.LocalPTRResolvers)
s.dnsProxy = &proxy.Proxy{Config: proxyConfig}
err = s.setupLocalResolvers()
if err != nil {
return fmt.Errorf("setting up resolvers: %w", err)
}
if s.conf.UsePrivateRDNS {
proxyConfig.PrivateRDNSUpstreamConfig = s.localResolvers.UpstreamConfig
}
s.dnsProxy = &proxy.Proxy{Config: proxyConfig}
s.recDetector.clear()
s.setupAddrProc()
s.registerHandlers()
return nil
}
// setupAddrProc initializes the address processor. For internal use only.
func (s *Server) setupAddrProc() {
// TODO(a.garipov): This is a crutch for tests; remove.
if s.conf.AddrProcConf == nil {
s.conf.AddrProcConf = &client.DefaultAddrProcConfig{}
}
if s.conf.AddrProcConf.AddressUpdater == nil {
s.addrProc = client.EmptyAddrProc{}
} else {
c := s.conf.AddrProcConf
c.DialContext = s.DialContext
c.PrivateSubnets = s.privateNets
c.UsePrivateRDNS = s.conf.UsePrivateRDNS
s.addrProc = client.NewDefaultAddrProc(s.conf.AddrProcConf)
// Clear the initial addresses to not resolve them again.
//
// TODO(a.garipov): Consider ways of removing this once more client
// logic is moved to package client.
c.InitialAddresses = nil
}
}
// validateBlockingMode returns an error if the blocking mode data aren't valid.
func validateBlockingMode(mode BlockingMode, blockingIPv4, blockingIPv6 net.IP) (err error) {
switch mode {
@@ -696,6 +746,11 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
// TODO(a.garipov): This whole piece of API is weird and needs to be remade.
if conf == nil {
conf = &s.conf
} else {
closeErr := s.addrProc.Close()
if closeErr != nil {
log.Error("dnsforward: closing address processor: %s", closeErr)
}
}
err = s.Prepare(conf)

View File

@@ -1,6 +1,7 @@
package dnsforward
import (
"context"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
@@ -39,11 +40,29 @@ func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests.
//
// TODO(a.garipov): Use more.
const testTimeout = 1 * time.Second
// testQuestionTarget is the common question target for tests.
//
// TODO(a.garipov): Use more.
const testQuestionTarget = "target.example"
const (
tlsServerName = "testdns.adguard.com"
testMessagesCount = 10
)
// testClientAddr is the common net.Addr for tests.
//
// TODO(a.garipov): Use more.
var testClientAddr net.Addr = &net.TCPAddr{
IP: net.IP{1, 2, 3, 4},
Port: 12345,
}
func startDeferStop(t *testing.T, s *Server) {
t.Helper()
@@ -53,6 +72,13 @@ func startDeferStop(t *testing.T, s *Server) {
testutil.CleanupAndRequireSuccess(t, s.Stop)
}
// packageUpstreamVariableMu is used to serialize access to the package-level
// variables of package upstream.
//
// TODO(s.chzhen): Move these parameters to upstream options and remove this
// crutch.
var packageUpstreamVariableMu = &sync.Mutex{}
func createTestServer(
t *testing.T,
filterConf *filtering.Config,
@@ -61,6 +87,9 @@ func createTestServer(
) (s *Server) {
t.Helper()
packageUpstreamVariableMu.Lock()
defer packageUpstreamVariableMu.Unlock()
rules := `||nxdomain.example.org
||NULL.example.org^
127.0.0.1 host.example.org
@@ -307,11 +336,9 @@ func TestServer(t *testing.T) {
}
func TestServer_timeout(t *testing.T) {
const timeout time.Duration = time.Second
t.Run("custom", func(t *testing.T) {
srvConf := &ServerConfig{
UpstreamTimeout: timeout,
UpstreamTimeout: testTimeout,
FilteringConfig: FilteringConfig{
BlockingMode: BlockingModeDefault,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
@@ -324,7 +351,7 @@ func TestServer_timeout(t *testing.T) {
err = s.Prepare(srvConf)
require.NoError(t, err)
assert.Equal(t, timeout, s.conf.UpstreamTimeout)
assert.Equal(t, testTimeout, s.conf.UpstreamTimeout)
})
t.Run("default", func(t *testing.T) {
@@ -441,7 +468,14 @@ func TestServerRace(t *testing.T) {
}
func TestSafeSearch(t *testing.T) {
resolver := &aghtest.TestResolver{}
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4, ip6}, nil
},
}
safeSearchConf := filtering.SafeSearchConfig{
Enabled: true,
Google: true,
@@ -480,7 +514,7 @@ func TestSafeSearch(t *testing.T) {
client := &dns.Client{}
yandexIP := net.IP{213, 180, 193, 56}
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
testCases := []struct {
host string
@@ -545,7 +579,7 @@ func TestInvalidRequest(t *testing.T) {
// Send a DNS request without question.
_, _, err := (&dns.Client{
Timeout: 500 * time.Millisecond,
Timeout: testTimeout,
}).Exchange(&req, addr)
assert.NoErrorf(t, err, "got a response to an invalid query")
@@ -928,7 +962,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
Upstream: aghtest.NewBlockUpstream(hostname, true),
})
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
ans4, _ := aghtest.HostToIPs(hostname)
filterConf := &filtering.Config{
SafeBrowsingEnabled: true,
@@ -1266,25 +1300,57 @@ func TestNewServer(t *testing.T) {
}
}
// doubleTTL is a helper function that returns a clone of DNS PTR with appended
// copy of first answer record with doubled TTL.
func doubleTTL(msg *dns.Msg) (resp *dns.Msg) {
if msg == nil {
return nil
}
if len(msg.Answer) == 0 {
return msg
}
rec := msg.Answer[0]
ptr, ok := rec.(*dns.PTR)
if !ok {
return msg
}
clone := *ptr
clone.Hdr.Ttl *= 2
msg.Answer = append(msg.Answer, &clone)
return msg
}
func TestServer_Exchange(t *testing.T) {
const (
onesHost = "one.one.one.one"
twosHost = "two.two.two.two"
localDomainHost = "local.domain"
defaultTTL = time.Second * 60
)
var (
onesIP = netip.MustParseAddr("1.1.1.1")
twosIP = netip.MustParseAddr("2.2.2.2")
localIP = netip.MustParseAddr("192.168.1.1")
)
revExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
require.NoError(t, err)
twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
require.NoError(t, err)
extUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "external.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, revExtIPv4, onesHost),
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost),
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
},
@@ -1320,53 +1386,65 @@ func TestServer_Exchange(t *testing.T) {
},
}
srv.conf.ResolveClients = true
srv.conf.UsePrivateRDNS = true
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
testCases := []struct {
name string
want string
req netip.Addr
wantErr error
locUpstream upstream.Upstream
req netip.Addr
name string
want string
wantTTL time.Duration
}{{
name: "external_good",
want: onesHost,
wantErr: nil,
locUpstream: nil,
req: onesIP,
wantTTL: defaultTTL,
}, {
name: "local_good",
want: localDomainHost,
wantErr: nil,
locUpstream: locUpstream,
req: localIP,
wantTTL: defaultTTL,
}, {
name: "upstream_error",
want: "",
wantErr: aghtest.ErrUpstream,
locUpstream: errUpstream,
req: localIP,
wantTTL: 0,
}, {
name: "empty_answer_error",
want: "",
wantErr: ErrRDNSNoData,
locUpstream: locUpstream,
req: netip.MustParseAddr("192.168.1.2"),
wantTTL: 0,
}, {
name: "invalid_answer",
want: "",
wantErr: ErrRDNSNoData,
locUpstream: nonPtrUpstream,
req: localIP,
wantTTL: 0,
}, {
name: "refused",
want: "",
wantErr: ErrRDNSFailed,
locUpstream: refusingUpstream,
req: localIP,
wantTTL: 0,
}, {
name: "longest_ttl",
want: twosHost,
wantErr: nil,
locUpstream: nil,
req: twosIP,
wantTTL: defaultTTL * 2,
}}
for _, tc := range testCases {
@@ -1380,73 +1458,20 @@ func TestServer_Exchange(t *testing.T) {
}
t.Run(tc.name, func(t *testing.T) {
host, eerr := srv.Exchange(tc.req)
host, ttl, eerr := srv.Exchange(tc.req)
require.ErrorIs(t, eerr, tc.wantErr)
assert.Equal(t, tc.want, host)
assert.Equal(t, tc.wantTTL, ttl)
})
}
t.Run("resolving_disabled", func(t *testing.T) {
srv.conf.UsePrivateRDNS = false
host, eerr := srv.Exchange(localIP)
host, _, eerr := srv.Exchange(localIP)
require.NoError(t, eerr)
assert.Empty(t, host)
})
}
func TestServer_ShouldResolveClient(t *testing.T) {
srv := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}
testCases := []struct {
ip netip.Addr
want require.BoolAssertionFunc
name string
resolve bool
usePrivate bool
}{{
name: "default",
ip: netip.MustParseAddr("1.1.1.1"),
want: require.True,
resolve: true,
usePrivate: true,
}, {
name: "no_rdns",
ip: netip.MustParseAddr("1.1.1.1"),
want: require.False,
resolve: false,
usePrivate: true,
}, {
name: "loopback",
ip: netip.MustParseAddr("127.0.0.1"),
want: require.False,
resolve: true,
usePrivate: true,
}, {
name: "private_resolve",
ip: netip.MustParseAddr("192.168.0.1"),
want: require.True,
resolve: true,
usePrivate: true,
}, {
name: "private_no_resolve",
ip: netip.MustParseAddr("192.168.0.1"),
want: require.False,
resolve: true,
usePrivate: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
srv.conf.ResolveClients = tc.resolve
srv.conf.UsePrivateRDNS = tc.usePrivate
ok := srv.ShouldResolveClient(tc.ip)
tc.want(t, ok)
})
}
}

View File

@@ -50,10 +50,10 @@ func (s *Server) beforeRequestHandler(
return true, nil
}
// getClientRequestFilteringSettings looks up client filtering settings using
// the client's IP address and ID, if any, from dctx.
func (s *Server) getClientRequestFilteringSettings(dctx *dnsContext) *filtering.Settings {
setts := s.dnsFilter.Settings()
// clientRequestFilteringSettings looks up client filtering settings using the
// client's IP address and ID, if any, from dctx.
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
setts = s.dnsFilter.Settings()
setts.ProtectionEnabled = dctx.protectionEnabled
if s.conf.FilterHandler != nil {
ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr)

View File

@@ -124,7 +124,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
cacheMinTTL := s.conf.CacheMinTTL
cacheMaxTTL := s.conf.CacheMaxTTL
cacheOptimistic := s.conf.CacheOptimistic
resolveClients := s.conf.ResolveClients
resolveClients := s.conf.AddrProcConf.UseRDNS
usePrivateRDNS := s.conf.UsePrivateRDNS
localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers)
@@ -314,8 +314,6 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
setIfNotNil(&s.conf.ProtectionEnabled, dc.ProtectionEnabled)
setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled)
setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6)
setIfNotNil(&s.conf.ResolveClients, dc.ResolveClients)
setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS)
return s.setConfigRestartable(dc)
}
@@ -335,6 +333,9 @@ func setIfNotNil[T any](currentPtr, newPtr *T) (hasSet bool) {
// setConfigRestartable sets the parameters which trigger a restart.
// shouldRestart is true if the server should be restarted to apply changes.
// s.serverLock is expected to be locked.
//
// TODO(a.garipov): Some of these could probably be updated without a restart.
// Inspect and consider refactoring.
func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
for _, hasSet := range []bool{
setIfNotNil(&s.conf.UpstreamDNS, dc.Upstreams),
@@ -347,6 +348,8 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
setIfNotNil(&s.conf.CacheMinTTL, dc.CacheMinTTL),
setIfNotNil(&s.conf.CacheMaxTTL, dc.CacheMaxTTL),
setIfNotNil(&s.conf.CacheOptimistic, dc.CacheOptimistic),
setIfNotNil(&s.conf.AddrProcConf.UseRDNS, dc.ResolveClients),
setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS),
} {
shouldRestart = shouldRestart || hasSet
if shouldRestart {

View File

@@ -30,6 +30,7 @@ type dnsContext struct {
setts *filtering.Settings
result *filtering.Result
// origResp is the response received from upstream. It is set when the
// response is modified by filters.
origResp *dns.Msg
@@ -48,13 +49,13 @@ type dnsContext struct {
// clientID is the ClientID from DoH, DoQ, or DoT, if provided.
clientID string
// startTime is the time at which the processing of the request has started.
startTime time.Time
// origQuestion is the question received from the client. It is set
// when the request is modified by rewrites.
origQuestion dns.Question
// startTime is the time at which the processing of the request has started.
startTime time.Time
// protectionEnabled shows if the filtering is enabled, and if the
// server's DNS filter is ready.
protectionEnabled bool
@@ -160,6 +161,22 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
// mozillaFQDN is the domain used to signal the Firefox browser to not use its
// own DoH server.
//
// See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet.
const mozillaFQDN = "use-application-dns.net."
// healthcheckFQDN is a reserved domain-name used for healthchecking.
//
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
// grant requests to register test names in the normal way to any person or
// entity, making domain names under the .test TLD free to use in internal
// purposes.
//
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
const healthcheckFQDN = "healthcheck.adguardhome.test."
// processInitial terminates the following processing for some requests if
// needed and enriches dctx with some client-specific information.
//
@@ -169,6 +186,8 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
defer log.Debug("dnsforward: finished processing initial")
pctx := dctx.proxyCtx
s.processClientIP(pctx.Addr)
q := pctx.Req.Question[0]
qt := q.Qtype
if s.conf.AAAADisabled && qt == dns.TypeAAAA {
@@ -177,28 +196,13 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
return resultCodeFinish
}
if s.conf.OnDNSRequest != nil {
s.conf.OnDNSRequest(pctx)
}
// Disable Mozilla DoH.
//
// See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet.
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == "use-application-dns.net." {
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish
}
// Handle a reserved domain healthcheck.adguardhome.test.
//
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
// grant requests to register test names in the normal way to any person or
// entity, making domain names under test. TLD free to use in internal
// purposes.
//
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
if q.Name == "healthcheck.adguardhome.test." {
if q.Name == healthcheckFQDN {
// Generate a NODATA negative response to make nslookup exit with 0.
pctx.Res = s.makeResponse(pctx.Req)
@@ -213,11 +217,28 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
// Get the client-specific filtering settings.
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus()
dctx.setts = s.getClientRequestFilteringSettings(dctx)
dctx.setts = s.clientRequestFilteringSettings(dctx)
return resultCodeSuccess
}
// processClientIP sends the client IP address to s.addrProc, if needed.
func (s *Server) processClientIP(addr net.Addr) {
clientIP := netutil.NetAddrToAddrPort(addr).Addr()
if clientIP == (netip.Addr{}) {
log.Info("dnsforward: warning: bad client addr %q", addr)
return
}
// Do not assign s.addrProc to a local variable to then use, since this lock
// also serializes the closure of s.addrProc.
s.serverLock.RLock()
defer s.serverLock.RUnlock()
s.addrProc.Process(clientIP)
}
func (s *Server) setTableHostToIP(t hostToIPTable) {
s.tableHostToIPLock.Lock()
defer s.tableHostToIPLock.Unlock()
@@ -698,6 +719,18 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
if s.conf.UsePrivateRDNS {
s.recDetector.add(*pctx.Req)
if err := s.localResolvers.Resolve(pctx); err != nil {
// Generate the server failure if the private upstream configuration
// is empty.
//
// TODO(e.burkov): Get rid of this crutch once the local resolvers
// logic is moved to the dnsproxy completely.
if errors.Is(err, upstream.ErrNoUpstreams) {
pctx.Res = s.genServerFailure(pctx.Req)
// Do not even put into query log.
return resultCodeFinish
}
dctx.err = err
return resultCodeError

View File

@@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -22,6 +23,96 @@ const (
ddrTestFQDN = ddrTestDomainName + "."
)
func TestServer_ProcessInitial(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
target string
wantRCode rules.RCode
qType rules.RRType
aaaaDisabled bool
wantRC resultCode
}{{
name: "success",
target: testQuestionTarget,
wantRCode: -1,
qType: dns.TypeA,
aaaaDisabled: false,
wantRC: resultCodeSuccess,
}, {
name: "aaaa_disabled",
target: testQuestionTarget,
wantRCode: dns.RcodeSuccess,
qType: dns.TypeAAAA,
aaaaDisabled: true,
wantRC: resultCodeFinish,
}, {
name: "aaaa_disabled_a",
target: testQuestionTarget,
wantRCode: -1,
qType: dns.TypeA,
aaaaDisabled: true,
wantRC: resultCodeSuccess,
}, {
name: "mozilla_canary",
target: mozillaFQDN,
wantRCode: dns.RcodeNameError,
qType: dns.TypeA,
aaaaDisabled: false,
wantRC: resultCodeFinish,
}, {
name: "adguardhome_healthcheck",
target: healthcheckFQDN,
wantRCode: dns.RcodeSuccess,
qType: dns.TypeA,
aaaaDisabled: false,
wantRC: resultCodeFinish,
}}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
c := ServerConfig{
FilteringConfig: FilteringConfig{
AAAADisabled: tc.aaaaDisabled,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}
s := createTestServer(t, &filtering.Config{}, c, nil)
var gotAddr netip.Addr
s.addrProc = &aghtest.AddressProcessor{
OnProcess: func(ip netip.Addr) { gotAddr = ip },
OnClose: func() (err error) { panic("not implemented") },
}
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: createTestMessageWithType(tc.target, tc.qType),
Addr: testClientAddr,
RequestID: 1234,
},
}
gotRC := s.processInitial(dctx)
assert.Equal(t, tc.wantRC, gotRC)
assert.Equal(t, netutil.NetAddrToAddrPort(testClientAddr).Addr(), gotAddr)
if tc.wantRCode > 0 {
gotResp := dctx.proxyCtx.Res
require.NotNil(t, gotResp)
assert.Equal(t, tc.wantRCode, gotResp.Rcode)
}
})
}
}
func TestServer_ProcessDDRQuery(t *testing.T) {
dohSVCB := &dns.SVCB{
Priority: 1,
@@ -64,7 +155,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
}{{
name: "pass_host",
wantRes: resultCodeSuccess,
host: "example.net.",
host: testQuestionTarget,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoH: 8043,
@@ -234,33 +325,33 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
knownIP := netip.MustParseAddr("1.2.3.4")
testCases := []struct {
wantIP netip.Addr
name string
host string
wantIP netip.Addr
wantRes resultCode
isLocalCli bool
}{{
wantIP: knownIP,
name: "local_client_success",
host: "example.lan",
wantIP: knownIP,
wantRes: resultCodeSuccess,
isLocalCli: true,
}, {
wantIP: netip.Addr{},
name: "local_client_unknown_host",
host: "wronghost.lan",
wantIP: netip.Addr{},
wantRes: resultCodeSuccess,
isLocalCli: true,
}, {
wantIP: netip.Addr{},
name: "external_client_known_host",
host: "example.lan",
wantIP: netip.Addr{},
wantRes: resultCodeFinish,
isLocalCli: false,
}, {
wantIP: netip.Addr{},
name: "external_client_unknown_host",
host: "wronghost.lan",
wantIP: netip.Addr{},
wantRes: resultCodeFinish,
isLocalCli: false,
}}
@@ -332,52 +423,52 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
knownIP := netip.MustParseAddr("1.2.3.4")
testCases := []struct {
wantIP netip.Addr
name string
host string
suffix string
wantIP netip.Addr
wantRes resultCode
qtyp uint16
}{{
wantIP: netip.Addr{},
name: "success_external",
host: examplecom,
suffix: defaultLocalDomainSuffix,
wantIP: netip.Addr{},
wantRes: resultCodeSuccess,
qtyp: dns.TypeA,
}, {
wantIP: netip.Addr{},
name: "success_external_non_a",
host: examplecom,
suffix: defaultLocalDomainSuffix,
wantIP: netip.Addr{},
wantRes: resultCodeSuccess,
qtyp: dns.TypeCNAME,
}, {
wantIP: knownIP,
name: "success_internal",
host: examplelan,
suffix: defaultLocalDomainSuffix,
wantIP: knownIP,
wantRes: resultCodeSuccess,
qtyp: dns.TypeA,
}, {
wantIP: netip.Addr{},
name: "success_internal_unknown",
host: "example-new.lan",
suffix: defaultLocalDomainSuffix,
wantIP: netip.Addr{},
wantRes: resultCodeSuccess,
qtyp: dns.TypeA,
}, {
wantIP: netip.Addr{},
name: "success_internal_aaaa",
host: examplelan,
suffix: defaultLocalDomainSuffix,
wantIP: netip.Addr{},
wantRes: resultCodeSuccess,
qtyp: dns.TypeAAAA,
}, {
wantIP: knownIP,
name: "success_custom_suffix",
host: "example.custom",
suffix: "custom",
wantIP: knownIP,
wantRes: resultCodeSuccess,
qtyp: dns.TypeA,
}}
@@ -560,10 +651,8 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
var dnsCtx *dnsContext
setup := func(use bool) {
proxyCtx = &proxy.DNSContext{
Addr: &net.TCPAddr{
IP: net.IP{127, 0, 0, 1},
},
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
Addr: testClientAddr,
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
}
dnsCtx = &dnsContext{
proxyCtx: proxyCtx,

View File

@@ -42,11 +42,13 @@ func (s *Server) loadUpstreams() (upstreams []string, err error) {
// prepareUpstreamSettings sets upstream DNS server settings.
func (s *Server) prepareUpstreamSettings() (err error) {
// We're setting a customized set of RootCAs. The reason is that Go default
// mechanism of loading TLS roots does not always work properly on some
// routers so we're loading roots manually and pass it here.
// Use a customized set of RootCAs, because Go's default mechanism of
// loading TLS roots does not always work properly on some routers so we're
// loading roots manually and pass it here.
//
// See [aghtls.SystemRootCAs].
//
// TODO(a.garipov): Investigate if that's true.
upstream.RootCAs = s.conf.TLSv12Roots
upstream.CipherSuites = s.conf.TLSCiphers
@@ -190,7 +192,7 @@ func (s *Server) resolveUpstreamsWithHosts(
// extractUpstreamHost returns the hostname of addr without port with an
// assumption that any address passed here has already been successfully parsed
// by [upstream.AddressToUpstream]. This function eesentially mirrors the logic
// by [upstream.AddressToUpstream]. This function essentially mirrors the logic
// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts].
func extractUpstreamHost(addr string) (host string) {
var err error

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghrenameio"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -83,53 +84,53 @@ func (d *DNSFilter) filterSetProperties(
filters = d.WhitelistFilters
}
i := slices.IndexFunc(filters, func(filt FilterYAML) bool { return filt.URL == listURL })
i := slices.IndexFunc(filters, func(flt FilterYAML) bool { return flt.URL == listURL })
if i == -1 {
return false, errFilterNotExist
}
filt := &filters[i]
flt := &filters[i]
log.Debug(
"filtering: set name to %q, url to %s, enabled to %t for filter %s",
newList.Name,
newList.URL,
newList.Enabled,
filt.URL,
flt.URL,
)
defer func(oldURL, oldName string, oldEnabled bool, oldUpdated time.Time, oldRulesCount int) {
if err != nil {
filt.URL = oldURL
filt.Name = oldName
filt.Enabled = oldEnabled
filt.LastUpdated = oldUpdated
filt.RulesCount = oldRulesCount
flt.URL = oldURL
flt.Name = oldName
flt.Enabled = oldEnabled
flt.LastUpdated = oldUpdated
flt.RulesCount = oldRulesCount
}
}(filt.URL, filt.Name, filt.Enabled, filt.LastUpdated, filt.RulesCount)
}(flt.URL, flt.Name, flt.Enabled, flt.LastUpdated, flt.RulesCount)
filt.Name = newList.Name
flt.Name = newList.Name
if filt.URL != newList.URL {
if flt.URL != newList.URL {
if d.filterExistsLocked(newList.URL) {
return false, errFilterExists
}
shouldRestart = true
filt.URL = newList.URL
filt.LastUpdated = time.Time{}
filt.unload()
flt.URL = newList.URL
flt.LastUpdated = time.Time{}
flt.unload()
}
if filt.Enabled != newList.Enabled {
filt.Enabled = newList.Enabled
if flt.Enabled != newList.Enabled {
flt.Enabled = newList.Enabled
shouldRestart = true
}
if filt.Enabled {
if flt.Enabled {
if shouldRestart {
// Download the filter contents.
shouldRestart, err = d.update(filt)
shouldRestart, err = d.update(flt)
}
} else {
// TODO(e.burkov): The validation of the contents of the new URL is
@@ -137,7 +138,7 @@ func (d *DNSFilter) filterSetProperties(
// possible to set a bad rules source, but the validation should still
// kick in when the filter is enabled. Consider changing this behavior
// to be stricter.
filt.unload()
flt.unload()
}
return shouldRestart, err
@@ -250,24 +251,24 @@ func assignUniqueFilterID() int64 {
// Sets up a timer that will be checking for filters updates periodically
func (d *DNSFilter) periodicallyRefreshFilters() {
const maxInterval = 1 * 60 * 60
intval := 5 // use a dynamically increasing time interval
ivl := 5 // use a dynamically increasing time interval
for {
isNetErr, ok := false, false
if d.FiltersUpdateIntervalHours != 0 {
_, isNetErr, ok = d.tryRefreshFilters(true, true, false)
if ok && !isNetErr {
intval = maxInterval
ivl = maxInterval
}
}
if isNetErr {
intval *= 2
if intval > maxInterval {
intval = maxInterval
ivl *= 2
if ivl > maxInterval {
ivl = maxInterval
}
}
time.Sleep(time.Duration(intval) * time.Second)
time.Sleep(time.Duration(ivl) * time.Second)
}
}
@@ -329,20 +330,20 @@ func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int,
return 0, nil, nil, false
}
nfail := 0
failNum := 0
for i := range updateFilters {
uf := &updateFilters[i]
updated, err := d.update(uf)
updateFlags = append(updateFlags, updated)
if err != nil {
nfail++
log.Info("filtering: updating filter from url %q: %s\n", uf.URL, err)
failNum++
log.Error("filtering: updating filter from url %q: %s\n", uf.URL, err)
continue
}
}
if nfail == len(updateFilters) {
if failNum == len(updateFilters) {
return 0, nil, nil, true
}
@@ -464,48 +465,6 @@ func (d *DNSFilter) update(filter *FilterYAML) (b bool, err error) {
return b, err
}
// finalizeUpdate closes and gets rid of temporary file f with filter's content
// according to updated. It also saves new values of flt's name, rules number
// and checksum if succeeded.
func (d *DNSFilter) finalizeUpdate(
file *os.File,
flt *FilterYAML,
updated bool,
res *rulelist.ParseResult,
) (err error) {
tmpFileName := file.Name()
// Close the file before renaming it because it's required on Windows.
//
// See https://github.com/adguardTeam/adGuardHome/issues/1553.
err = file.Close()
if err != nil {
return fmt.Errorf("closing temporary file: %w", err)
}
if !updated {
log.Debug("filtering: filter %d from url %q has no changes, skipping", flt.ID, flt.URL)
return os.Remove(tmpFileName)
}
fltPath := flt.Path(d.DataDir)
log.Info("filtering: saving contents of filter %d into %q", flt.ID, fltPath)
// Don't use renameio or maybe packages, since those will require loading
// the whole filter content to the memory on Windows.
err = os.Rename(tmpFileName, fltPath)
if err != nil {
return errors.WithDeferred(err, os.Remove(tmpFileName))
}
flt.Name = aghalg.Coalesce(flt.Name, res.Title)
flt.checksum, flt.RulesCount = res.Checksum, res.RulesCount
return nil
}
// updateIntl updates the flt rewriting it's actual file. It returns true if
// the actual update has been performed.
func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
@@ -513,63 +472,22 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
var res *rulelist.ParseResult
var tmpFile *os.File
tmpFile, err = os.CreateTemp(filepath.Join(d.DataDir, filterDir), "")
if err != nil {
return false, err
}
defer func() {
finErr := d.finalizeUpdate(tmpFile, flt, ok, res)
if ok && finErr == nil {
log.Info(
"filtering: updated filter %d: %d bytes, %d rules",
flt.ID,
res.BytesWritten,
res.RulesCount,
)
return
}
err = errors.WithDeferred(err, finErr)
}()
// Change the default 0o600 permission to something more acceptable by end
// users.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/3198.
if err = tmpFile.Chmod(0o644); err != nil {
return false, fmt.Errorf("changing file mode: %w", err)
tmpFile, err := aghrenameio.NewPendingFile(flt.Path(d.DataDir), 0o644)
if err != nil {
return false, err
}
defer func() { err = d.finalizeUpdate(tmpFile, flt, res, err, ok) }()
var r io.Reader
if !filepath.IsAbs(flt.URL) {
var resp *http.Response
resp, err = d.HTTPClient.Get(flt.URL)
if err != nil {
log.Info("filtering: requesting filter from %q: %s, skipping", flt.URL, err)
return false, err
}
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
if resp.StatusCode != http.StatusOK {
log.Info("filtering got status code %d from %q, skipping", resp.StatusCode, flt.URL)
return false, fmt.Errorf("got status code %d, want %d", resp.StatusCode, http.StatusOK)
}
r = resp.Body
} else {
var f *os.File
f, err = os.Open(flt.URL)
if err != nil {
return false, fmt.Errorf("open file: %w", err)
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()
r = f
r, err := d.reader(flt.URL)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return false, err
}
defer func() { err = errors.WithDeferred(err, r.Close()) }()
bufPtr := d.bufPool.Get().(*[]byte)
defer d.bufPool.Put(bufPtr)
@@ -580,6 +498,78 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
return res.Checksum != flt.checksum && err == nil, err
}
// finalizeUpdate closes and gets rid of temporary file f with filter's content
// according to updated. It also saves new values of flt's name, rules number
// and checksum if succeeded.
func (d *DNSFilter) finalizeUpdate(
file aghrenameio.PendingFile,
flt *FilterYAML,
res *rulelist.ParseResult,
returned error,
updated bool,
) (err error) {
id := flt.ID
if !updated {
if returned == nil {
log.Debug("filtering: filter %d from url %q has no changes, skipping", id, flt.URL)
}
return errors.WithDeferred(returned, file.Cleanup())
}
log.Info("filtering: saving contents of filter %d into %q", id, flt.Path(d.DataDir))
err = file.CloseReplace()
if err != nil {
return fmt.Errorf("finalizing update: %w", err)
}
rulesCount := res.RulesCount
log.Info("filtering: updated filter %d: %d bytes, %d rules", id, res.BytesWritten, rulesCount)
flt.Name = aghalg.Coalesce(flt.Name, res.Title)
flt.checksum = res.Checksum
flt.RulesCount = rulesCount
return nil
}
// reader returns an io.ReadCloser reading filtering-rule list data form either
// a file on the filesystem or the filter's HTTP URL.
func (d *DNSFilter) reader(fltURL string) (r io.ReadCloser, err error) {
if !filepath.IsAbs(fltURL) {
r, err = d.readerFromURL(fltURL)
if err != nil {
return nil, fmt.Errorf("reading from url: %w", err)
}
return r, nil
}
r, err = os.Open(fltURL)
if err != nil {
return nil, fmt.Errorf("opening file: %w", err)
}
return r, nil
}
// readerFromURL returns an io.ReadCloser reading filtering-rule list data form
// the filter's URL.
func (d *DNSFilter) readerFromURL(fltURL string) (r io.ReadCloser, err error) {
resp, err := d.HTTPClient.Get(fltURL)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("got status code %d, want %d", resp.StatusCode, http.StatusOK)
}
return resp.Body, nil
}
// loads filter contents from the file in dataDir
func (d *DNSFilter) load(flt *FilterYAML) (err error) {
fileName := flt.Path(d.DataDir)

View File

@@ -943,7 +943,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
d = &DNSFilter{
bufPool: &sync.Pool{
New: func() (buf any) {
bufVal := make([]byte, rulelist.MaxRuleLen)
bufVal := make([]byte, rulelist.DefaultRuleBufSize)
return &bufVal
},

View File

@@ -6,9 +6,9 @@ import (
"fmt"
"hash/crc32"
"io"
"unicode"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/exp/slices"
)
// Parser is a filtering-rule parser that collects data, such as the checksum
@@ -48,19 +48,29 @@ type ParseResult struct {
// nil.
func (p *Parser) Parse(dst io.Writer, src io.Reader, buf []byte) (r *ParseResult, err error) {
s := bufio.NewScanner(src)
s.Buffer(buf, MaxRuleLen)
lineIdx := 0
// Don't use [DefaultRuleBufSize] as the maximum size, since some
// filtering-rule lists compressed by e.g. HostlistsCompiler can have very
// large lines. The buffer optimization still works for the more common
// case of reasonably-sized lines.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/6003.
s.Buffer(buf, bufio.MaxScanTokenSize)
// Use a one-based index for lines and columns, since these errors end up in
// the frontend, and users are more familiar with one-based line and column
// indexes.
lineNum := 1
for s.Scan() {
var n int
n, err = p.processLine(dst, s.Bytes(), lineIdx)
n, err = p.processLine(dst, s.Bytes(), lineNum)
p.written += n
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return p.result(), err
}
lineIdx++
lineNum++
}
r = p.result()
@@ -81,7 +91,7 @@ func (p *Parser) result() (r *ParseResult) {
// processLine processes a single line. It may write to dst, and if it does, n
// is the number of bytes written.
func (p *Parser) processLine(dst io.Writer, line []byte, lineIdx int) (n int, err error) {
func (p *Parser) processLine(dst io.Writer, line []byte, lineNum int) (n int, err error) {
trimmed := bytes.TrimSpace(line)
if p.written == 0 && isHTMLLine(trimmed) {
return 0, ErrHTML
@@ -95,9 +105,10 @@ func (p *Parser) processLine(dst io.Writer, line []byte, lineIdx int) (n int, er
}
if badIdx != -1 {
return 0, fmt.Errorf(
"line at index %d: character at index %d: non-printable character",
lineIdx,
badIdx+bytes.Index(line, trimmed),
"line %d: character %d: likely binary character %q",
lineNum,
badIdx+bytes.Index(line, trimmed)+1,
trimmed[badIdx],
)
}
@@ -130,41 +141,37 @@ func hasPrefixFold(b, prefix []byte) (ok bool) {
}
// parseLine returns true if the parsed line is a filtering rule. line is
// assumed to be trimmed of whitespace characters. nonPrintIdx is the index of
// the first non-printable character, if any; if there are none, nonPrintIdx is
// -1.
// assumed to be trimmed of whitespace characters. badIdx is the index of the
// first character that may indicate that this is a binary file, or -1 if none.
//
// A line is considered a rule if it's not empty, not a comment, and contains
// only printable characters.
func parseLine(line []byte) (nonPrintIdx int, isRule bool) {
func parseLine(line []byte) (badIdx int, isRule bool) {
if len(line) == 0 || line[0] == '#' || line[0] == '!' {
return -1, false
}
nonPrintIdx = bytes.IndexFunc(line, isNotPrintable)
badIdx = slices.IndexFunc(line, likelyBinary)
return nonPrintIdx, nonPrintIdx == -1
return badIdx, badIdx == -1
}
// isNotPrintable returns true if r is not a printable character that can be
// contained in a filtering rule.
func isNotPrintable(r rune) (ok bool) {
// Tab isn't included into Unicode's graphic symbols, so include it here
// explicitly.
return r != '\t' && !unicode.IsGraphic(r)
// likelyBinary returns true if b is likely to be a byte from a binary file.
func likelyBinary(b byte) (ok bool) {
return (b < ' ' || b == 0x7f) && b != '\n' && b != '\r' && b != '\t'
}
// parseLineTitle is like [parseLine] but additionally looks for a title. line
// is assumed to be trimmed of whitespace characters.
func (p *Parser) parseLineTitle(line []byte) (nonPrintIdx int, isRule bool) {
func (p *Parser) parseLineTitle(line []byte) (badIdx int, isRule bool) {
if len(line) == 0 || line[0] == '#' {
return -1, false
}
if line[0] != '!' {
nonPrintIdx = bytes.IndexFunc(line, isNotPrintable)
badIdx = slices.IndexFunc(line, likelyBinary)
return nonPrintIdx, nonPrintIdx == -1
return badIdx, badIdx == -1
}
const titlePattern = "! Title: "

View File

@@ -6,10 +6,10 @@ import (
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakeio"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -17,6 +17,9 @@ import (
func TestParser_Parse(t *testing.T) {
t.Parallel()
longRule := strings.Repeat("a", rulelist.DefaultRuleBufSize+1) + "\n"
tooLongRule := strings.Repeat("a", bufio.MaxScanTokenSize+1) + "\n"
testCases := []struct {
name string
in string
@@ -74,26 +77,42 @@ func TestParser_Parse(t *testing.T) {
wantTitle: "Test Title",
wantRulesNum: 1,
wantWritten: len(testRuleTextBlocked),
}, {
name: "cosmetic_with_zwnj",
in: testRuleTextCosmetic,
wantDst: testRuleTextCosmetic,
wantErrMsg: "",
wantTitle: "",
wantRulesNum: 1,
wantWritten: len(testRuleTextCosmetic),
}, {
name: "bad_char",
in: "! Title: Test Title \n" +
testRuleTextBlocked +
">>>\x7F<<<",
wantDst: testRuleTextBlocked,
wantErrMsg: "line at index 2: " +
"character at index 3: " +
"non-printable character",
wantErrMsg: "line 3: " +
"character 4: " +
"likely binary character '\\x7f'",
wantTitle: "Test Title",
wantRulesNum: 1,
wantWritten: len(testRuleTextBlocked),
}, {
name: "too_long",
in: strings.Repeat("a", rulelist.MaxRuleLen+1),
in: tooLongRule,
wantDst: "",
wantErrMsg: "scanning filter contents: " + bufio.ErrTooLong.Error(),
wantErrMsg: "scanning filter contents: bufio.Scanner: token too long",
wantTitle: "",
wantRulesNum: 0,
wantWritten: 0,
}, {
name: "longer_than_default",
in: longRule,
wantDst: longRule,
wantErrMsg: "",
wantTitle: "",
wantRulesNum: 1,
wantWritten: len(longRule),
}, {
name: "bad_tab_and_comment",
in: testRuleTextBadTab,
@@ -118,7 +137,7 @@ func TestParser_Parse(t *testing.T) {
t.Parallel()
dst := &bytes.Buffer{}
buf := make([]byte, rulelist.MaxRuleLen)
buf := make([]byte, rulelist.DefaultRuleBufSize)
p := rulelist.NewParser()
r, err := p.Parse(dst, strings.NewReader(tc.in), buf)
@@ -140,12 +159,12 @@ func TestParser_Parse(t *testing.T) {
func TestParser_Parse_writeError(t *testing.T) {
t.Parallel()
dst := &aghtest.Writer{
dst := &fakeio.Writer{
OnWrite: func(b []byte) (n int, err error) {
return 1, errors.Error("test error")
},
}
buf := make([]byte, rulelist.MaxRuleLen)
buf := make([]byte, rulelist.DefaultRuleBufSize)
p := rulelist.NewParser()
r, err := p.Parse(dst, strings.NewReader(testRuleTextBlocked), buf)
@@ -165,7 +184,7 @@ func TestParser_Parse_checksums(t *testing.T) {
"# Another comment.\n"
)
buf := make([]byte, rulelist.MaxRuleLen)
buf := make([]byte, rulelist.DefaultRuleBufSize)
p := rulelist.NewParser()
r, err := p.Parse(&bytes.Buffer{}, strings.NewReader(withoutComments), buf)
@@ -192,7 +211,7 @@ var (
func BenchmarkParser_Parse(b *testing.B) {
dst := &bytes.Buffer{}
src := strings.NewReader(strings.Repeat(testRuleTextBlocked, 1000))
buf := make([]byte, rulelist.MaxRuleLen)
buf := make([]byte, rulelist.DefaultRuleBufSize)
p := rulelist.NewParser()
b.ReportAllocs()
@@ -204,6 +223,14 @@ func BenchmarkParser_Parse(b *testing.B) {
require.NoError(b, errSink)
require.NotNil(b, resSink)
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
//
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
// BenchmarkParser_Parse-16 100000000 128.0 ns/op 48 B/op 1 allocs/op
}
func FuzzParser_Parse(f *testing.F) {
@@ -215,15 +242,17 @@ func FuzzParser_Parse(f *testing.F) {
"! Comment",
"! Title ",
"! Title XXX",
testRuleTextBadTab,
testRuleTextBlocked,
testRuleTextCosmetic,
testRuleTextEtcHostsTab,
testRuleTextHTML,
testRuleTextBlocked,
testRuleTextBadTab,
"1.2.3.4",
"1.2.3.4 etc-hosts.example",
">>>\x00<<<",
">>>\x7F<<<",
strings.Repeat("a", n+1),
strings.Repeat("a", rulelist.DefaultRuleBufSize+1),
strings.Repeat("a", bufio.MaxScanTokenSize+1),
}
for _, tc := range testCases {

View File

@@ -4,8 +4,6 @@
// TODO(a.garipov): Expand.
package rulelist
// MaxRuleLen is the maximum length of a line with a filtering rule, in bytes.
//
// TODO(a.garipov): Consider changing this to a rune length, like AdGuardDNS
// does.
const MaxRuleLen = 1024
// DefaultRuleBufSize is the default length of a buffer used to read a line with
// a filtering rule, in bytes.
const DefaultRuleBufSize = 1024

View File

@@ -7,8 +7,13 @@ const testTimeout = 1 * time.Second
// Common texts for tests.
const (
testRuleTextHTML = "<!DOCTYPE html>\n"
testRuleTextBlocked = "||blocked.example^\n"
testRuleTextBadTab = "||bad-tab-and-comment.example^\t# A comment.\n"
testRuleTextBlocked = "||blocked.example^\n"
testRuleTextEtcHostsTab = "0.0.0.0 tab..example^\t# A comment.\n"
testRuleTextHTML = "<!DOCTYPE html>\n"
// testRuleTextCosmetic is a cosmetic rule with a zero-width non-joiner.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/6003.
testRuleTextCosmetic = "||cosmetic.example## :has-text(/\u200c/i)\n"
)

View File

@@ -89,37 +89,34 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
resolver := &aghtest.TestResolver{}
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4, ip6}, nil
},
}
ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
// Lookup for safesearch domain.
rewrite := ss.searchHost(domain, testQType)
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
require.NoError(t, err)
var foundIP net.IP
for _, ip := range ips {
if ip.To4() != nil {
foundIP = ip
break
}
}
wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME)
res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(foundIP))
assert.True(t, res.Rules[0].IP.Equal(wantIP))
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP))
assert.True(t, cachedValue.Rules[0].IP.Equal(wantIP))
}
const googleHost = "www.google.com"

View File

@@ -92,8 +92,15 @@ func TestDefault_CheckHost_yandexAAAA(t *testing.T) {
}
func TestDefault_CheckHost_google(t *testing.T) {
resolver := &aghtest.TestResolver{}
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4, ip6}, nil
},
}
wantIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
conf := testConf
conf.CustomResolver = resolver
@@ -119,7 +126,7 @@ func TestDefault_CheckHost_google(t *testing.T) {
require.Len(t, res.Rules, 1)
assert.Equal(t, ip, res.Rules[0].IP)
assert.Equal(t, wantIP, res.Rules[0].IP)
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
})
}

View File

@@ -1505,7 +1505,6 @@ var blockedServices = []blockedService{{
"||aus.social^",
"||awscommunity.social^",
"||climatejustice.social^",
"||cupoftea.social^",
"||cyberplace.social^",
"||defcon.social^",
"||det.social^",
@@ -1589,6 +1588,7 @@ var blockedServices = []blockedService{{
"||techhub.social^",
"||theblower.au^",
"||tkz.one^",
"||todon.eu^",
"||toot.aquilenet.fr^",
"||toot.community^",
"||toot.funami.tech^",
@@ -1661,6 +1661,7 @@ var blockedServices = []blockedService{{
"||nintendo.jp^",
"||nintendo.net^",
"||nintendo.nl^",
"||nintendo.pt^",
"||nintendoswitch.cn^",
"||nintendowifi.net^",
},
@@ -2160,6 +2161,20 @@ var blockedServices = []blockedService{{
Rules: []string{
"||voot.com^",
},
}, {
ID: "wargaming",
Name: "Wargaming",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 24 24\"><path d=\"M12 1.998c-5.52 0-10 4.481-10 9.988 0 5.52 4.48 9.996 10 9.996s10-4.476 10-9.996c0-5.507-4.48-9.988-10-9.988zm0 2c4.413 0 8 3.588 8 7.988 0 3.246-1.944 6.04-4.727 7.293.54-1.861.831-3.988.807-6.226l1.414.414a23.648 23.648 0 0 0-2-4.041c-.627 1.347-1.48 2.56-2.52 3.68l1.68-.133c-1.507 2.92-3.134 3.906-5.547 4.013-.386-4.213.12-7.014 2.827-9.04l.386 1.493c.653-.974 1.36-2.12 2.373-2.947-1.506-.6-2.999-.627-4.492-.334.386.16.76.588 1.014.828-3.485 1.662-5.643 4.202-6.744 7.68A7.95 7.95 0 0 1 4 11.986c0-4.4 3.587-7.988 8-7.988z\"/></svg>"),
Rules: []string{
"||wargaming.com^",
"||wargaming.net^",
"||wgcdn.co^",
"||wgcrowd.io^",
"||worldoftanks.com^",
"||worldofwarplanes.com^",
"||worldofwarships.eu^",
"||wotblitz.com^",
},
}, {
ID: "wechat",
Name: "WeChat",

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
@@ -141,7 +142,7 @@ func (clients *clientsContainer) handleHostsUpdates() {
}
}
// webHandlersRegistered prevents a [clientsContainer] from regisering its web
// webHandlersRegistered prevents a [clientsContainer] from registering its web
// handlers more than once.
//
// TODO(a.garipov): Refactor HTTP handler registration logic.
@@ -743,11 +744,9 @@ func (clients *clientsContainer) Update(prev, c *Client) (err error) {
return nil
}
// setWHOISInfo sets the WHOIS information for a client.
// setWHOISInfo sets the WHOIS information for a client. clients.lock is
// expected to be locked.
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findLocked(ip.String())
if ok {
log.Debug("clients: client for %s is already created, ignore whois info", ip)
@@ -774,9 +773,11 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
rc.WHOIS = wi
}
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
// addHost adds a new IP-hostname pairing. The priorities of the sources are
// taken into account. ok is true if the pairing was added.
func (clients *clientsContainer) AddHost(
//
// TODO(a.garipov): Only used in internal tests. Consider removing.
func (clients *clientsContainer) addHost(
ip netip.Addr,
host string,
src clientSource,
@@ -787,6 +788,32 @@ func (clients *clientsContainer) AddHost(
return clients.addHostLocked(ip, host, src)
}
// type check
var _ client.AddressUpdater = (*clientsContainer)(nil)
// UpdateAddress implements the [client.AddressUpdater] interface for
// *clientsContainer
func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
// Common fast path optimization.
if host == "" && info == nil {
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
if host != "" {
ok := clients.addHostLocked(ip, host, ClientSourceRDNS)
if !ok {
log.Debug("clients: host for client %q already set with higher priority source", ip)
}
}
if info != nil {
clients.setWHOISInfo(ip, info)
}
}
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be
// locked.
func (clients *clientsContainer) addHostLocked(

View File

@@ -168,13 +168,13 @@ func TestClients(t *testing.T) {
t.Run("addhost_success", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.AddHost(ip, "host", ClientSourceARP)
ok := clients.addHost(ip, "host", ClientSourceARP)
assert.True(t, ok)
ok = clients.AddHost(ip, "host2", ClientSourceARP)
ok = clients.addHost(ip, "host2", ClientSourceARP)
assert.True(t, ok)
ok = clients.AddHost(ip, "host3", ClientSourceHostsFile)
ok = clients.addHost(ip, "host3", ClientSourceHostsFile)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), ClientSourceHostsFile)
@@ -182,18 +182,18 @@ func TestClients(t *testing.T) {
t.Run("dhcp_replaces_arp", func(t *testing.T) {
ip := netip.MustParseAddr("1.2.3.4")
ok := clients.AddHost(ip, "from_arp", ClientSourceARP)
ok := clients.addHost(ip, "from_arp", ClientSourceARP)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), ClientSourceARP)
ok = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
ok = clients.addHost(ip, "from_dhcp", ClientSourceDHCP)
assert.True(t, ok)
assert.Equal(t, clients.clientSource(ip), ClientSourceDHCP)
})
t.Run("addhost_fail", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.AddHost(ip, "host1", ClientSourceRDNS)
ok := clients.addHost(ip, "host1", ClientSourceRDNS)
assert.False(t, ok)
})
}
@@ -216,7 +216,7 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("existing_auto-client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.AddHost(ip, "host", ClientSourceRDNS)
ok := clients.addHost(ip, "host", ClientSourceRDNS)
assert.True(t, ok)
clients.setWHOISInfo(ip, whois)
@@ -259,7 +259,7 @@ func TestClientsAddExisting(t *testing.T) {
assert.True(t, ok)
// Now add an auto-client with the same IP.
ok = clients.AddHost(ip, "test", ClientSourceRDNS)
ok = clients.addHost(ip, "test", ClientSourceRDNS)
assert.True(t, ok)
})

View File

@@ -20,7 +20,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/maybe"
"github.com/google/renameio/v2/maybe"
"golang.org/x/exp/slices"
yaml "gopkg.in/yaml.v3"
)
@@ -590,7 +590,13 @@ func (c *configuration) write() (err error) {
s.WriteDiskConfig(&c)
dns := &config.DNS
dns.FilteringConfig = c
dns.LocalPTRResolvers, config.Clients.Sources.RDNS, dns.UsePrivateRDNS = s.RDNSSettings()
dns.LocalPTRResolvers = s.LocalPTRResolvers()
addrProcConf := s.AddrProcConfig()
config.Clients.Sources.RDNS = addrProcConf.UseRDNS
config.Clients.Sources.WHOIS = addrProcConf.UseWHOIS
dns.UsePrivateRDNS = addrProcConf.UsePrivateRDNS
}
if Context.dhcpServer != nil {

View File

@@ -176,12 +176,16 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
// ------------------------
// registration of handlers
// ------------------------
func registerControlHandlers() {
func registerControlHandlers(web *webAPI) {
Context.mux.HandleFunc(
"/control/version.json",
postInstall(optionalAuth(web.handleVersionJSON)),
)
httpRegister(http.MethodPost, "/control/update", web.handleUpdate)
httpRegister(http.MethodGet, "/control/status", handleStatus)
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
Context.mux.HandleFunc("/control/version.json", postInstall(optionalAuth(handleVersionJSON)))
httpRegister(http.MethodPost, "/control/update", handleUpdate)
httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)

View File

@@ -448,7 +448,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
web.conf.BindHost = req.Web.IP
web.conf.BindPort = req.Web.Port
registerControlHandlers()
registerControlHandlers(web)
aghhttp.OK(w)
if f, ok := w.(http.Flusher); ok {

View File

@@ -29,9 +29,9 @@ type temporaryError interface {
// handleVersionJSON is the handler for the POST /control/version.json HTTP API.
//
// TODO(a.garipov): Find out if this API used with a GET method by anyone.
func handleVersionJSON(w http.ResponseWriter, r *http.Request) {
func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
resp := &versionResponse{}
if Context.disableUpdate {
if web.conf.disableUpdate {
resp.Disabled = true
_ = aghhttp.WriteJSONResponse(w, r, resp)
@@ -52,7 +52,7 @@ func handleVersionJSON(w http.ResponseWriter, r *http.Request) {
}
}
err = requestVersionInfo(resp, req.Recheck)
err = web.requestVersionInfo(resp, req.Recheck)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.Error(r, w, http.StatusBadGateway, "%s", err)
@@ -73,9 +73,10 @@ func handleVersionJSON(w http.ResponseWriter, r *http.Request) {
// requestVersionInfo sets the VersionInfo field of resp if it can reach the
// update server.
func requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
func (web *webAPI) requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
updater := web.conf.updater
for i := 0; i != 3; i++ {
resp.VersionInfo, err = Context.updater.VersionInfo(recheck)
resp.VersionInfo, err = updater.VersionInfo(recheck)
if err != nil {
var terr temporaryError
if errors.As(err, &terr) && terr.Temporary() {
@@ -95,7 +96,7 @@ func requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
}
if err != nil {
vcu := Context.updater.VersionCheckURL()
vcu := updater.VersionCheckURL()
return fmt.Errorf("getting version info from %s: %w", vcu, err)
}
@@ -104,8 +105,9 @@ func requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
}
// handleUpdate performs an update to the latest available version procedure.
func handleUpdate(w http.ResponseWriter, r *http.Request) {
if Context.updater.NewVersion() == "" {
func (web *webAPI) handleUpdate(w http.ResponseWriter, r *http.Request) {
updater := web.conf.updater
if updater.NewVersion() == "" {
aghhttp.Error(r, w, http.StatusBadRequest, "/update request isn't allowed now")
return
@@ -122,7 +124,7 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
return
}
err = Context.updater.Update(false)
err = updater.Update(false)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
@@ -137,7 +139,7 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
// The background context is used because the underlying functions wrap it
// with timeout and shut down the server, which handles current request. It
// also should be done in a separate goroutine for the same reason.
go finishUpdate(context.Background(), execPath)
go finishUpdate(context.Background(), execPath, web.conf.runningAsService)
}
// versionResponse is the response for /control/version.json endpoint.
@@ -178,7 +180,7 @@ func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
}
// finishUpdate completes an update procedure.
func finishUpdate(ctx context.Context, execPath string) {
func finishUpdate(ctx context.Context, execPath string, runningAsService bool) {
var err error
log.Info("stopping all tasks")
@@ -187,7 +189,7 @@ func finishUpdate(ctx context.Context, execPath string) {
cleanupAlways()
if runtime.GOOS == "windows" {
if Context.runningAsService {
if runningAsService {
// NOTE: We can't restart the service via "kardianos/service"
// package, because it kills the process first we can't start a new
// instance, because Windows doesn't allow it.

View File

@@ -13,14 +13,12 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
@@ -135,7 +133,7 @@ func initDNSServer(
return fmt.Errorf("preparing set of private subnets: %w", err)
}
p := dnsforward.DNSCreateParams{
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
DNSFilter: filters,
Stats: sts,
QueryLog: qlog,
@@ -143,9 +141,7 @@ func initDNSServer(
Anonymizer: anonymizer,
LocalDomain: config.DHCP.LocalDomainName,
DHCPServer: dhcpSrv,
}
Context.dnsServer, err = dnsforward.NewServer(p)
})
if err != nil {
closeDNSServer()
@@ -154,134 +150,23 @@ func initDNSServer(
Context.clients.dnsServer = Context.dnsServer
dnsConf, err := generateServerConfig(tlsConf, httpReg)
dnsConf, err := newServerConfig(tlsConf, httpReg)
if err != nil {
closeDNSServer()
return fmt.Errorf("generateServerConfig: %w", err)
return fmt.Errorf("newServerConfig: %w", err)
}
err = Context.dnsServer.Prepare(&dnsConf)
err = Context.dnsServer.Prepare(dnsConf)
if err != nil {
closeDNSServer()
return fmt.Errorf("dnsServer.Prepare: %w", err)
}
initRDNS()
initWHOIS()
return nil
}
const (
// defaultQueueSize is the size of queue of IPs for rDNS and WHOIS
// processing.
defaultQueueSize = 255
// defaultCacheSize is the maximum size of the cache for rDNS and WHOIS
// processing. It must be greater than zero.
defaultCacheSize = 10_000
// defaultIPTTL is the Time to Live duration for IP addresses cached by
// rDNS and WHOIS.
defaultIPTTL = 1 * time.Hour
)
// initRDNS initializes the rDNS.
func initRDNS() {
Context.rdnsCh = make(chan netip.Addr, defaultQueueSize)
// TODO(s.chzhen): Add ability to disable it on dns server configuration
// update in [dnsforward] package.
r := rdns.New(&rdns.Config{
Exchanger: Context.dnsServer,
CacheSize: defaultCacheSize,
CacheTTL: defaultIPTTL,
})
go processRDNS(r)
}
// processRDNS processes reverse DNS lookup queries. It is intended to be used
// as a goroutine.
func processRDNS(r rdns.Interface) {
defer log.OnPanic("rdns")
for ip := range Context.rdnsCh {
ok := Context.dnsServer.ShouldResolveClient(ip)
if !ok {
continue
}
host, changed := r.Process(ip)
if host == "" || !changed {
continue
}
ok = Context.clients.AddHost(ip, host, ClientSourceRDNS)
if ok {
continue
}
log.Debug(
"dns: can't set rdns info for client %q: already set with higher priority source",
ip,
)
}
}
// initWHOIS initializes the WHOIS.
//
// TODO(s.chzhen): Consider making configurable.
func initWHOIS() {
const (
// defaultTimeout is the timeout for WHOIS requests.
defaultTimeout = 5 * time.Second
// defaultMaxConnReadSize is an upper limit in bytes for reading from
// net.Conn.
defaultMaxConnReadSize = 64 * 1024
// defaultMaxRedirects is the maximum redirects count.
defaultMaxRedirects = 5
// defaultMaxInfoLen is the maximum length of whois.Info fields.
defaultMaxInfoLen = 250
)
Context.whoisCh = make(chan netip.Addr, defaultQueueSize)
var w whois.Interface
if config.Clients.Sources.WHOIS {
w = whois.New(&whois.Config{
DialContext: customDialContext,
ServerAddr: whois.DefaultServer,
Port: whois.DefaultPort,
Timeout: defaultTimeout,
CacheSize: defaultCacheSize,
MaxConnReadSize: defaultMaxConnReadSize,
MaxRedirects: defaultMaxRedirects,
MaxInfoLen: defaultMaxInfoLen,
CacheTTL: defaultIPTTL,
})
} else {
w = whois.Empty{}
}
go func() {
defer log.OnPanic("whois")
for ip := range Context.whoisCh {
info, changed := w.Process(context.Background(), ip)
if info != nil && changed {
Context.clients.setWHOISInfo(ip, info)
}
}
}()
}
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
// a subnet set that matches all locally served networks, see
// [netutil.IsLocallyServed].
@@ -312,17 +197,6 @@ func isRunning() bool {
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
}
func onDNSRequest(pctx *proxy.DNSContext) {
ip := netutil.NetAddrToAddrPort(pctx.Addr).Addr()
if ip == (netip.Addr{}) {
// This would be quite weird if we get here.
return
}
Context.rdnsCh <- ip
Context.whoisCh <- ip
}
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
if ips == nil {
return nil
@@ -349,23 +223,41 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
return udpAddrs
}
func generateServerConfig(
func newServerConfig(
tlsConf *tlsConfigSettings,
httpReg aghhttp.RegisterFunc,
) (newConf dnsforward.ServerConfig, err error) {
) (newConf *dnsforward.ServerConfig, err error) {
dnsConf := config.DNS
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
newConf = dnsforward.ServerConfig{
newConf = &dnsforward.ServerConfig{
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
FilteringConfig: dnsConf.FilteringConfig,
ConfigModified: onConfigModified,
HTTPRegister: httpReg,
OnDNSRequest: onDNSRequest,
UseDNS64: config.DNS.UseDNS64,
DNS64Prefixes: config.DNS.DNS64Prefixes,
}
var initialAddresses []netip.Addr
// Context.stats may be nil here if initDNSServer is called from
// [cmdlineUpdate].
if sts := Context.stats; sts != nil {
const initialClientsNum = 100
initialAddresses = Context.stats.TopClientsIP(initialClientsNum)
}
// Do not set DialContext, PrivateSubnets, and UsePrivateRDNS, because they
// are set by [dnsforward.Server.Prepare].
newConf.AddrProcConf = &client.DefaultAddrProcConfig{
Exchanger: Context.dnsServer,
AddressUpdater: &Context.clients,
InitialAddresses: initialAddresses,
UseRDNS: config.Clients.Sources.RDNS,
UseWHOIS: config.Clients.Sources.WHOIS,
}
if tlsConf.Enabled {
newConf.TLSConfig = tlsConf.TLSConfig
newConf.TLSConfig.ServerName = tlsConf.ServerName
@@ -385,9 +277,9 @@ func generateServerConfig(
if tlsConf.PortDNSCrypt != 0 {
newConf.DNSCryptConfig, err = newDNSCrypt(hosts, *tlsConf)
if err != nil {
// Don't wrap the error, because it's already
// wrapped by newDNSCrypt.
return dnsforward.ServerConfig{}, err
// Don't wrap the error, because it's already wrapped by
// newDNSCrypt.
return nil, err
}
}
}
@@ -401,7 +293,6 @@ func generateServerConfig(
newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers
newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration
newConf.ResolveClients = config.Clients.Sources.RDNS
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
newConf.ServeHTTP3 = dnsConf.ServeHTTP3
newConf.UseHTTP3Upstreams = dnsConf.UseHTTP3Upstreams
@@ -556,27 +447,19 @@ func startDNSServer() error {
Context.stats.Start()
Context.queryLog.Start()
const topClientsNumber = 100 // the number of clients to get
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
Context.rdnsCh <- ip
Context.whoisCh <- ip
}
return nil
}
func reconfigureDNSServer() (err error) {
var newConf dnsforward.ServerConfig
tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
newConf, err = generateServerConfig(tlsConf, httpRegister)
newConf, err := newServerConfig(tlsConf, httpRegister)
if err != nil {
return fmt.Errorf("generating forwarding dns server config: %w", err)
}
err = Context.dnsServer.Reconfigure(&newConf)
err = Context.dnsServer.Reconfigure(newConf)
if err != nil {
return fmt.Errorf("starting forwarding dns server: %w", err)
}

View File

@@ -3,14 +3,12 @@ package home
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/fs"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"os/signal"
"path/filepath"
@@ -66,40 +64,24 @@ type homeContext struct {
// configuration files, for example /etc/hosts.
etcHosts *aghnet.HostsContainer
updater *updater.Updater
// mux is our custom http.ServeMux.
mux *http.ServeMux
// Runtime properties
// --
configFilename string // Config filename (can be overridden via the command line arguments)
workDir string // Location of our directory, used to protect against CWD being somewhere else
pidFileName string // PID file name. Empty if no PID file was created.
controlLock sync.Mutex
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
client *http.Client
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
// rdnsCh is the channel for receiving IPs for rDNS processing.
rdnsCh chan netip.Addr
// whoisCh is the channel for receiving IPs for WHOIS processing.
whoisCh chan netip.Addr
configFilename string // Config filename (can be overridden via the command line arguments)
workDir string // Location of our directory, used to protect against CWD being somewhere else
pidFileName string // PID file name. Empty if no PID file was created.
controlLock sync.Mutex
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
tlsCipherIDs []uint16
// disableUpdate, if true, tells AdGuard Home to not check for updates.
disableUpdate bool
// firstRun, if true, tells AdGuard Home to only start the web interface
// service, and only serve the first-run APIs.
firstRun bool
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
}
// getDataDir returns path to the directory where we store databases and filters
@@ -122,11 +104,11 @@ func Main(clientBuildFS fs.FS) {
// package flag.
opts := loadCmdLineOpts()
Context.appSignalChannel = make(chan os.Signal)
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
for {
sig := <-Context.appSignalChannel
sig := <-signals
log.Info("Received signal %q", sig)
switch sig {
case syscall.SIGHUP:
@@ -141,7 +123,7 @@ func Main(clientBuildFS fs.FS) {
}()
if opts.serviceControlAction != "" {
handleServiceControlAction(opts, clientBuildFS)
handleServiceControlAction(opts, clientBuildFS, signals)
return
}
@@ -153,74 +135,48 @@ func Main(clientBuildFS fs.FS) {
// setupContext initializes [Context] fields. It also reads and upgrades
// config file if necessary.
func setupContext(opts options) (err error) {
setupContextFlags(opts)
Context.firstRun = detectFirstRun()
Context.tlsRoots = aghtls.SystemRootCAs()
Context.client = &http.Client{
Timeout: time.Minute * 5,
Transport: &http.Transport{
DialContext: customDialContext,
Proxy: getHTTPProxy,
TLSClientConfig: &tls.Config{
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCipherIDs,
MinVersion: tls.VersionTLS12,
},
},
}
Context.mux = http.NewServeMux()
if !Context.firstRun {
// Do the upgrade if necessary.
err = upgradeConfig()
if Context.firstRun {
log.Info("This is the first time AdGuard Home is launched")
checkPermissions()
return nil
}
// Do the upgrade if necessary.
err = upgradeConfig()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
if err = parseConfig(); err != nil {
log.Error("parsing configuration file: %s", err)
os.Exit(1)
}
if opts.checkConfig {
log.Info("configuration file is ok")
os.Exit(0)
}
if !opts.noEtcHosts && config.Clients.Sources.HostsFile {
err = setupHostsContainer()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
if err = parseConfig(); err != nil {
log.Error("parsing configuration file: %s", err)
os.Exit(1)
}
if opts.checkConfig {
log.Info("configuration file is ok")
os.Exit(0)
}
if !opts.noEtcHosts && config.Clients.Sources.HostsFile {
err = setupHostsContainer()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
}
}
return nil
}
// setupContextFlags sets global flags and prints their status to the log.
func setupContextFlags(opts options) {
Context.firstRun = detectFirstRun()
if Context.firstRun {
log.Info("This is the first time AdGuard Home is launched")
checkPermissions()
}
Context.runningAsService = opts.runningAsService
// Don't print the runningAsService flag, since that has already been done
// in [run].
Context.disableUpdate = opts.disableUpdate || version.Channel() == version.ChannelDevelopment
if Context.disableUpdate {
log.Info("AdGuard Home updates are disabled")
}
}
// logIfUnsupported logs a formatted warning if the error is one of the
// unsupported errors and returns nil. If err is nil, logIfUnsupported returns
// nil. Otherwise, it returns err.
@@ -325,7 +281,7 @@ func initContextClients() (err error) {
return err
}
//lint:ignore SA1019 Migration is not over.
//lint:ignore SA1019 Migration is not over.
config.DHCP.WorkDir = Context.workDir
config.DHCP.DataDir = Context.getDataDir()
config.DHCP.HTTPRegister = httpRegister
@@ -340,18 +296,6 @@ func initContextClients() (err error) {
return fmt.Errorf("initing dhcp: %w", err)
}
Context.updater = updater.NewUpdater(&updater.Config{
Client: Context.client,
Version: version.Version(),
Channel: version.Channel(),
GOARCH: runtime.GOARCH,
GOOS: runtime.GOOS,
GOARM: version.GOARM(),
GOMIPS: version.GOMIPS(),
WorkDir: Context.workDir,
ConfName: config.getConfigFilename(),
})
var arpdb aghnet.ARPDB
if config.Clients.Sources.ARP {
arpdb = aghnet.NewARPDB()
@@ -433,7 +377,7 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
conf.Filters = slices.Clone(config.Filters)
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
conf.UserRules = slices.Clone(config.UserRules)
conf.HTTPClient = Context.client
conf.HTTPClient = httpClient()
cacheTime := time.Duration(conf.CacheTime) * time.Minute
@@ -515,7 +459,7 @@ func checkPorts() (err error) {
return nil
}
func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
func initWeb(opts options, clientBuildFS fs.FS, upd *updater.Updater) (web *webAPI, err error) {
var clientFS fs.FS
if opts.localFrontend {
log.Info("warning: using local frontend files")
@@ -528,8 +472,16 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
}
}
webConf := webConfig{
firstRun: Context.firstRun,
disableUpdate := opts.disableUpdate || version.Channel() == version.ChannelDevelopment
if disableUpdate {
log.Info("AdGuard Home updates are disabled")
}
webConf := &webConfig{
updater: upd,
clientFS: clientFS,
BindHost: config.HTTPConfig.Address.Addr(),
BindPort: int(config.HTTPConfig.Address.Port()),
@@ -537,12 +489,13 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
ReadHeaderTimeout: readHdrTimeout,
WriteTimeout: writeTimeout,
clientFS: clientFS,
serveHTTP3: config.DNS.ServeHTTP3,
firstRun: Context.firstRun,
disableUpdate: disableUpdate,
runningAsService: opts.runningAsService,
serveHTTP3: config.DNS.ServeHTTP3,
}
web = newWebAPI(&webConf)
web = newWebAPI(webConf)
if web == nil {
return nil, fmt.Errorf("initializing web: %w", err)
}
@@ -593,9 +546,21 @@ func run(opts options, clientBuildFS fs.FS) {
err = setupOpts(opts)
fatalOnError(err)
upd := updater.NewUpdater(&updater.Config{
Client: config.DNS.DnsfilterConf.HTTPClient,
Version: version.Version(),
Channel: version.Channel(),
GOARCH: runtime.GOARCH,
GOOS: runtime.GOOS,
GOARM: version.GOARM(),
GOMIPS: version.GOMIPS(),
WorkDir: Context.workDir,
ConfName: config.getConfigFilename(),
})
// TODO(e.burkov): This could be made earlier, probably as the option's
// effect.
cmdlineUpdate(opts)
cmdlineUpdate(opts, upd)
if !Context.firstRun {
// Save the updated config.
@@ -624,7 +589,7 @@ func run(opts options, clientBuildFS fs.FS) {
onConfigModified()
}
Context.web, err = initWeb(opts, clientBuildFS)
Context.web, err = initWeb(opts, clientBuildFS, upd)
fatalOnError(err)
if !Context.firstRun {
@@ -634,10 +599,10 @@ func run(opts options, clientBuildFS fs.FS) {
Context.tls.start()
go func() {
sErr := startDNSServer()
if sErr != nil {
startErr := startDNSServer()
if startErr != nil {
closeDNSServer()
fatalOnError(sErr)
fatalOnError(startErr)
}
}()
@@ -996,62 +961,6 @@ func detectFirstRun() bool {
return errors.Is(err, os.ErrNotExist)
}
// Connect to a remote server resolving hostname using our own DNS server.
//
// TODO(e.burkov): This messy logic should be decomposed and clarified.
//
// TODO(a.garipov): Support network.
func customDialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
log.Debug("home: customdial: dialing addr %q for network %s", addr, network)
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
dialer := &net.Dialer{
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
return dialer.DialContext(ctx, network, addr)
}
addrs, err := Context.dnsServer.Resolve(host)
if err != nil {
return nil, fmt.Errorf("resolving %q: %w", host, err)
}
log.Debug("dnsServer.Resolve: %q: %v", host, addrs)
if len(addrs) == 0 {
return nil, fmt.Errorf("couldn't lookup host: %q", host)
}
var dialErrs []error
for _, a := range addrs {
addr = net.JoinHostPort(a.String(), port)
conn, err = dialer.DialContext(ctx, network, addr)
if err != nil {
dialErrs = append(dialErrs, err)
continue
}
return conn, err
}
return nil, errors.List(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
}
func getHTTPProxy(_ *http.Request) (*url.URL, error) {
if config.ProxyURL == "" {
return nil, nil
}
return url.Parse(config.ProxyURL)
}
// jsonError is a generic JSON error response.
//
// TODO(a.garipov): Merge together with the implementations in [dhcpd] and other
@@ -1062,7 +971,7 @@ type jsonError struct {
}
// cmdlineUpdate updates current application and exits.
func cmdlineUpdate(opts options) {
func cmdlineUpdate(opts options, upd *updater.Updater) {
if !opts.performUpdate {
return
}
@@ -1077,10 +986,9 @@ func cmdlineUpdate(opts options) {
log.Info("cmdline update: performing update")
updater := Context.updater
info, err := updater.VersionInfo(true)
info, err := upd.VersionInfo(true)
if err != nil {
vcu := updater.VersionCheckURL()
vcu := upd.VersionCheckURL()
log.Error("getting version info from %s: %s", vcu, err)
os.Exit(1)
@@ -1092,7 +1000,7 @@ func cmdlineUpdate(opts options) {
os.Exit(0)
}
err = updater.Update(Context.firstRun)
err = upd.Update(Context.firstRun)
fatalOnError(err)
err = restartService()

View File

@@ -0,0 +1,47 @@
package home
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"time"
)
// httpClient returns a new HTTP client that uses the AdGuard Home's own DNS
// server for resolving hostnames. The resulting client should not be used
// until [Context.dnsServer] is initialized.
//
// TODO(a.garipov, e.burkov): This is rather messy. Refactor.
func httpClient() (c *http.Client) {
// Do not use Context.dnsServer.DialContext directly in the struct literal
// below, since Context.dnsServer may be nil when this function is called.
dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
return Context.dnsServer.DialContext(ctx, network, addr)
}
return &http.Client{
// TODO(a.garipov): Make configurable.
Timeout: time.Minute * 5,
Transport: &http.Transport{
DialContext: dialContext,
Proxy: httpProxy,
TLSClientConfig: &tls.Config{
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCipherIDs,
MinVersion: tls.VersionTLS12,
},
},
}
}
// httpProxy returns parses and returns an HTTP proxy URL from the config, if
// any.
func httpProxy(_ *http.Request) (u *url.URL, err error) {
if config.ProxyURL == "" {
return nil, nil
}
return url.Parse(config.ProxyURL)
}

View File

@@ -1,39 +0,0 @@
package home
import (
"net/http"
"net/http/pprof"
"runtime"
"github.com/AdguardTeam/golibs/log"
)
// startPprof launches the debug and profiling server on addr.
func startPprof(addr string) {
runtime.SetBlockProfileRate(1)
runtime.SetMutexProfileFraction(1)
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
// See profileSupportsDelta in src/net/http/pprof/pprof.go.
mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs"))
mux.Handle("/debug/pprof/block", pprof.Handler("block"))
mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine"))
mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))
go func() {
defer log.OnPanic("pprof server")
log.Info("pprof: listening on %q", addr)
err := http.ListenAndServe(addr, mux)
log.Info("pprof server errors: %v", err)
}()
}

View File

@@ -33,9 +33,13 @@ const (
// daemon.
type program struct {
clientBuildFS fs.FS
signals chan os.Signal
opts options
}
// type check
var _ service.Interface = (*program)(nil)
// Start implements service.Interface interface for *program.
func (p *program) Start(_ service.Service) (err error) {
// Start should not block. Do the actual work async.
@@ -48,14 +52,14 @@ func (p *program) Start(_ service.Service) (err error) {
}
// Stop implements service.Interface interface for *program.
func (p *program) Stop(_ service.Service) error {
// Stop should not block. Return with a few seconds.
if Context.appSignalChannel == nil {
os.Exit(0)
func (p *program) Stop(_ service.Service) (err error) {
select {
case p.signals <- syscall.SIGINT:
// Go on.
default:
// Stop should not block.
}
Context.appSignalChannel <- syscall.SIGINT
return nil
}
@@ -194,7 +198,7 @@ func restartService() (err error) {
// - run: This is a special command that is not supposed to be used directly
// it is specified when we register a service, and it indicates to the app
// that it is being run as a service/daemon.
func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
func handleServiceControlAction(opts options, clientBuildFS fs.FS, signals chan os.Signal) {
// Call chooseSystem explicitly to introduce OpenBSD support for service
// package. It's a noop for other GOOS values.
chooseSystem()
@@ -226,7 +230,11 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
}
configureService(svcConfig)
s, err := service.New(&program{clientBuildFS: clientBuildFS, opts: runOpts}, svcConfig)
s, err := service.New(&program{
clientBuildFS: clientBuildFS,
signals: signals,
opts: runOpts,
}, svcConfig)
if err != nil {
log.Fatalf("service: initializing service: %s", err)
}

View File

@@ -17,7 +17,7 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/maybe"
"github.com/google/renameio/v2/maybe"
"golang.org/x/crypto/bcrypt"
yaml "gopkg.in/yaml.v3"
)

View File

@@ -6,16 +6,18 @@ import (
"io/fs"
"net/http"
"net/netip"
"runtime"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/pprofutil"
"github.com/NYTimes/gziphandler"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
@@ -33,6 +35,8 @@ const (
)
type webConfig struct {
updater *updater.Updater
clientFS fs.FS
BindHost netip.Addr
@@ -52,6 +56,13 @@ type webConfig struct {
firstRun bool
// disableUpdate, if true, tells AdGuard Home to not check for updates.
disableUpdate bool
// runningAsService flag is set to true when options are passed from the
// service runner.
runningAsService bool
serveHTTP3 bool
}
@@ -102,7 +113,7 @@ func newWebAPI(conf *webConfig) (w *webAPI) {
Context.mux.Handle("/install.html", preInstallHandler(clientFS))
w.registerInstallHandlers()
} else {
registerControlHandlers()
registerControlHandlers(w)
}
w.httpsServer.cond = sync.NewCond(&w.httpsServer.condLock)
@@ -295,8 +306,27 @@ func (web *webAPI) mustStartHTTP3(address string) {
log.Debug("web: starting http/3 server")
err := web.httpsServer.server3.ListenAndServe()
if !errors.Is(err, quic.ErrServerClosed) {
if !errors.Is(err, http.ErrServerClosed) {
cleanupAlways()
log.Fatalf("web: http3: %s", err)
}
}
// startPprof launches the debug and profiling server on addr.
func startPprof(addr string) {
runtime.SetBlockProfileRate(1)
runtime.SetMutexProfileFraction(1)
mux := http.NewServeMux()
pprofutil.RoutePprof(mux)
go func() {
defer log.OnPanic("pprof server")
log.Info("pprof: listening on %q", addr)
err := http.ListenAndServe(addr, mux)
if !errors.Is(err, http.ErrServerClosed) {
log.Error("pprof: shutting down: %s", err)
}
}()
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/bluele/gcache"
)
@@ -17,7 +18,7 @@ type Interface interface {
Process(ip netip.Addr) (host string, changed bool)
}
// Empty is an empty [Inteface] implementation which does nothing.
// Empty is an empty [Interface] implementation which does nothing.
type Empty struct{}
// type check
@@ -32,7 +33,7 @@ func (Empty) Process(_ netip.Addr) (host string, changed bool) {
type Exchanger interface {
// Exchange tries to resolve the ip in a suitable way, i.e. either as local
// or as external.
Exchange(ip netip.Addr) (host string, err error)
Exchange(ip netip.Addr) (host string, ttl time.Duration, err error)
}
// Config is the configuration structure for Default.
@@ -82,13 +83,16 @@ func (r *Default) Process(ip netip.Addr) (host string, changed bool) {
return fromCache, false
}
host, err := r.exchanger.Exchange(ip)
host, ttl, err := r.exchanger.Exchange(ip)
if err != nil {
log.Debug("rdns: resolving %q: %s", ip, err)
}
// TODO(s.chzhen): Use built-in function max in Go 1.21.
ttl = mathutil.Max(ttl, r.cacheTTL)
item := &cacheItem{
expiry: time.Now().Add(r.cacheTTL),
expiry: time.Now().Add(ttl),
host: host,
}

View File

@@ -5,25 +5,13 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/golibs/netutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// fakeRDNSExchanger is a mock [rdns.Exchanger] implementation for tests.
type fakeRDNSExchanger struct {
OnExchange func(ip netip.Addr) (host string, err error)
}
// type check
var _ rdns.Exchanger = (*fakeRDNSExchanger)(nil)
// Exchange implements [rdns.Exchanger] interface for *fakeRDNSExchanger.
func (e *fakeRDNSExchanger) Exchange(ip netip.Addr) (host string, err error) {
return e.OnExchange(ip)
}
func TestDefault_Process(t *testing.T) {
ip1 := netip.MustParseAddr("1.2.3.4")
revAddr1, err := netutil.IPToReversedAddr(ip1.AsSlice())
@@ -67,21 +55,21 @@ func TestDefault_Process(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
hit := 0
onExchange := func(ip netip.Addr) (host string, err error) {
onExchange := func(ip netip.Addr) (host string, ttl time.Duration, err error) {
hit++
switch ip {
case ip1:
return revAddr1, nil
return revAddr1, 0, nil
case ip2:
return revAddr2, nil
return revAddr2, 0, nil
case localIP:
return localRevAddr1, nil
return localRevAddr1, 0, nil
default:
return "", nil
return "", 0, nil
}
}
exchanger := &fakeRDNSExchanger{
exchanger := &aghtest.Exchanger{
OnExchange: onExchange,
}

View File

@@ -9,9 +9,9 @@ require (
github.com/kisielk/errcheck v1.6.3
github.com/kyoh86/looppointer v0.2.1
github.com/securego/gosec/v2 v2.16.0
github.com/uudashr/gocognit v1.0.6
github.com/uudashr/gocognit v1.0.7
golang.org/x/tools v0.11.0
golang.org/x/vuln v0.2.0
golang.org/x/vuln v1.0.0
// TODO(a.garipov): Return to tagged releases once a new one appears.
honnef.co/go/tools v0.5.0-0.dev.0.20230709092525-bc759185c5ee
mvdan.cc/gofumpt v0.5.0
@@ -22,12 +22,12 @@ require (
github.com/BurntSushi/toml v1.3.2 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gookit/color v1.5.3 // indirect
github.com/gookit/color v1.5.4 // indirect
github.com/kyoh86/nolint v0.0.1 // indirect
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
golang.org/x/exp/typeparams v0.0.0-20230711023510-fffb14384f22 // indirect
golang.org/x/exp/typeparams v0.0.0-20230725093048-515e97ebf090 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.10.0 // indirect

View File

@@ -16,8 +16,8 @@ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE
github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gookit/color v1.5.3 h1:twfIhZs4QLCtimkP7MOxlF3A0U/5cDPseRT9M/+2SCE=
github.com/gookit/color v1.5.3/go.mod h1:NUzwzeehUfl7GIb36pqId+UGmRfQcU/WiiyTTeNjHtE=
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601 h1:mrEEilTAUmaAORhssPPkxj84TsHrPMLBGW2Z4SoTxm8=
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0=
github.com/kisielk/errcheck v1.6.3 h1:dEKh+GLHcWm2oN34nMvDzn1sqI0i0WxPvrgiJA5JuM8=
@@ -38,9 +38,9 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR
github.com/securego/gosec/v2 v2.16.0 h1:Pi0JKoasQQ3NnoRao/ww/N/XdynIB9NRYYZT5CyOs5U=
github.com/securego/gosec/v2 v2.16.0/go.mod h1:xvLcVZqUfo4aAQu56TNv7/Ltz6emAOQAEsrZrt7uGlI=
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/uudashr/gocognit v1.0.6 h1:2Cgi6MweCsdB6kpcVQp7EW4U23iBFQWfTXiWlyp842Y=
github.com/uudashr/gocognit v1.0.6/go.mod h1:nAIUuVBnYU7pcninia3BHOvQkpQCeO76Uscky5BOwcY=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/uudashr/gocognit v1.0.7 h1:e9aFXgKgUJrQ5+bs61zBigmj7bFJ/5cC6HmMahVzuDo=
github.com/uudashr/gocognit v1.0.7/go.mod h1:nAIUuVBnYU7pcninia3BHOvQkpQCeO76Uscky5BOwcY=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -52,8 +52,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/exp/typeparams v0.0.0-20230711023510-fffb14384f22 h1:e8iSCQYXZ4EB6q3kIfy2fgPFTvDbozqzRe4OuIOyrL4=
golang.org/x/exp/typeparams v0.0.0-20230711023510-fffb14384f22/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/exp/typeparams v0.0.0-20230725093048-515e97ebf090 h1:qOYhjyK9OeXREdh7Zrta8JRvnmnFIzhkosQpp+852Ag=
golang.org/x/exp/typeparams v0.0.0-20230725093048-515e97ebf090/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
@@ -98,8 +98,8 @@ golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E
golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4=
golang.org/x/tools v0.11.0 h1:EMCa6U9S2LtZXLAMoWiR/R8dAQFRqbAitmbJ2UKhoi8=
golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8=
golang.org/x/vuln v0.2.0 h1:Dlz47lW0pvPHU7tnb10S8vbMn9GnV2B6eyT7Tem5XBI=
golang.org/x/vuln v0.2.0/go.mod h1:V0eyhHwaAaHrt42J9bgrN6rd12f6GU4T0Lu0ex2wDg4=
golang.org/x/vuln v1.0.0 h1:tYLAU3jD9LQr98Y+3el06lWyGMCnvzw06PIWP3LIy7g=
golang.org/x/vuln v1.0.0/go.mod h1:V0eyhHwaAaHrt42J9bgrN6rd12f6GU4T0Lu0ex2wDg4=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
@@ -48,9 +49,8 @@ func (Empty) Process(_ context.Context, _ netip.Addr) (info *Info, changed bool)
// Config is the configuration structure for Default.
type Config struct {
// DialContext specifies the dial function for creating unencrypted TCP
// connections.
DialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// DialContext is used to create TCP connections to WHOIS servers.
DialContext aghnet.DialContextFunc
// ServerAddr is the address of the WHOIS server.
ServerAddr string
@@ -86,9 +86,8 @@ type Default struct {
// resolve the same IP.
cache gcache.Cache
// dialContext connects to a remote server resolving hostname using our own
// DNS server and unecrypted TCP connection.
dialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// dialContext is used to create TCP connections to WHOIS servers.
dialContext aghnet.DialContextFunc
// serverAddr is the address of the WHOIS server.
serverAddr string
@@ -215,7 +214,7 @@ func (w *Default) query(ctx context.Context, target, serverAddr string) (data []
return nil, err
}
_ = conn.SetReadDeadline(time.Now().Add(w.timeout))
_ = conn.SetDeadline(time.Now().Add(w.timeout))
_, err = io.WriteString(conn, target+"\r\n")
if err != nil {
// Don't wrap the error since it's informative enough as is.
@@ -310,7 +309,7 @@ func (w *Default) requestInfo(
kv, err := w.queryAll(ctx, ip.String())
if err != nil {
log.Debug("whois: quering about %q: %s", ip, err)
log.Debug("whois: querying %q: %s", ip, err)
return nil, true
}

View File

@@ -113,20 +113,14 @@ func TestDefault_Process(t *testing.T) {
return copy(b, tc.data), io.EOF
},
OnWrite: func(b []byte) (n int, err error) {
return len(b), nil
},
OnClose: func() (err error) {
return nil
},
OnSetReadDeadline: func(t time.Time) (err error) {
return nil
},
OnWrite: func(b []byte) (n int, err error) { return len(b), nil },
OnClose: func() (err error) { return nil },
OnSetDeadline: func(t time.Time) (err error) { return nil },
}
w := whois.New(&whois.Config{
Timeout: 5 * time.Second,
DialContext: func(_ context.Context, _, addr string) (_ net.Conn, _ error) {
DialContext: func(_ context.Context, _, _ string) (_ net.Conn, _ error) {
hit = 0
return fakeConn, nil