all: sync with master; upd chlog
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
package aghio
|
||||
package aghio_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -31,7 +32,7 @@ func TestLimitReader(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := LimitReader(nil, tc.n)
|
||||
_, err := aghio.LimitReader(nil, tc.n)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
@@ -57,7 +58,7 @@ func TestLimitedReader_Read(t *testing.T) {
|
||||
limit: 3,
|
||||
want: 0,
|
||||
}, {
|
||||
err: &LimitReachedError{
|
||||
err: &aghio.LimitReachedError{
|
||||
Limit: 0,
|
||||
},
|
||||
name: "limit_reached",
|
||||
@@ -74,7 +75,7 @@ func TestLimitedReader_Read(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
|
||||
lreader, err := LimitReader(readCloser, tc.limit)
|
||||
lreader, err := aghio.LimitReader(readCloser, tc.limit)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, lreader)
|
||||
|
||||
@@ -89,7 +90,7 @@ func TestLimitedReader_Read(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLimitedReader_LimitReachedError(t *testing.T) {
|
||||
testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &LimitReachedError{
|
||||
testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &aghio.LimitReachedError{
|
||||
Limit: 0,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -141,9 +141,9 @@ type HostsRecord struct {
|
||||
Canonical string
|
||||
}
|
||||
|
||||
// equal returns true if all fields of rec are equal to field in other or they
|
||||
// Equal returns true if all fields of rec are equal to field in other or they
|
||||
// both are nil.
|
||||
func (rec *HostsRecord) equal(other *HostsRecord) (ok bool) {
|
||||
func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) {
|
||||
if rec == nil {
|
||||
return other == nil
|
||||
} else if other == nil {
|
||||
@@ -495,7 +495,7 @@ func (hc *HostsContainer) refresh() (err error) {
|
||||
}
|
||||
|
||||
// hc.last is nil on the first refresh, so let that one through.
|
||||
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) {
|
||||
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).Equal) {
|
||||
log.Debug("%s: no changes detected", hostsContainerPrefix)
|
||||
|
||||
return nil
|
||||
|
||||
144
internal/aghnet/hostscontainer_internal_test.go
Normal file
144
internal/aghnet/hostscontainer_internal_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"path"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil/fakefs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const nl = "\n"
|
||||
|
||||
func TestHostsContainer_PathsToPatterns(t *testing.T) {
|
||||
gsfs := fstest.MapFS{
|
||||
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
|
||||
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
|
||||
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
paths []string
|
||||
want []string
|
||||
}{{
|
||||
name: "no_paths",
|
||||
paths: nil,
|
||||
want: nil,
|
||||
}, {
|
||||
name: "single_file",
|
||||
paths: []string{"dir_0/file_1"},
|
||||
want: []string{"dir_0/file_1"},
|
||||
}, {
|
||||
name: "several_files",
|
||||
paths: []string{"dir_0/file_1", "dir_0/file_2"},
|
||||
want: []string{"dir_0/file_1", "dir_0/file_2"},
|
||||
}, {
|
||||
name: "whole_dir",
|
||||
paths: []string{"dir_0"},
|
||||
want: []string{"dir_0/*"},
|
||||
}, {
|
||||
name: "file_and_dir",
|
||||
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
|
||||
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
|
||||
}, {
|
||||
name: "non-existing",
|
||||
paths: []string{path.Join("dir_0", "file_3")},
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
patterns, err := pathsToPatterns(gsfs, tc.paths)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, patterns)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("bad_file", func(t *testing.T) {
|
||||
const errStat errors.Error = "bad file"
|
||||
|
||||
badFS := &fakefs.StatFS{
|
||||
OnOpen: func(_ string) (f fs.File, err error) { panic("not implemented") },
|
||||
OnStat: func(name string) (fi fs.FileInfo, err error) {
|
||||
return nil, errStat
|
||||
},
|
||||
}
|
||||
|
||||
_, err := pathsToPatterns(badFS, []string{""})
|
||||
assert.ErrorIs(t, err, errStat)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUniqueRules_ParseLine(t *testing.T) {
|
||||
ip := netutil.IPv4Localhost()
|
||||
ipStr := ip.String()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
line string
|
||||
wantIP netip.Addr
|
||||
wantHosts []string
|
||||
}{{
|
||||
name: "simple",
|
||||
line: ipStr + ` hostname`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"hostname"},
|
||||
}, {
|
||||
name: "aliases",
|
||||
line: ipStr + ` hostname alias`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"hostname", "alias"},
|
||||
}, {
|
||||
name: "invalid_line",
|
||||
line: ipStr,
|
||||
wantIP: netip.Addr{},
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "invalid_line_hostname",
|
||||
line: ipStr + ` # hostname`,
|
||||
wantIP: ip,
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "commented_aliases",
|
||||
line: ipStr + ` hostname # alias`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"hostname"},
|
||||
}, {
|
||||
name: "whole_comment",
|
||||
line: `# ` + ipStr + ` hostname`,
|
||||
wantIP: netip.Addr{},
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "partial_comment",
|
||||
line: ipStr + ` host#name`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"host"},
|
||||
}, {
|
||||
name: "empty",
|
||||
line: ``,
|
||||
wantIP: netip.Addr{},
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "bad_hosts",
|
||||
line: ipStr + ` bad..host bad._tld empty.tld. ok.host`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"ok.host"},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
hp := hostsParser{}
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, hosts := hp.parseLine(tc.line)
|
||||
assert.Equal(t, tc.wantIP, got)
|
||||
assert.Equal(t, tc.wantHosts, hosts)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,7 @@
|
||||
package aghnet
|
||||
package aghnet_test
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"path"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -12,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -24,10 +23,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
nl = "\n"
|
||||
sp = " "
|
||||
)
|
||||
const nl = "\n"
|
||||
|
||||
func TestNewHostsContainer(t *testing.T) {
|
||||
const dirname = "dir"
|
||||
@@ -48,11 +44,11 @@ func TestNewHostsContainer(t *testing.T) {
|
||||
name: "one_file",
|
||||
paths: []string{p},
|
||||
}, {
|
||||
wantErr: ErrNoHostsPaths,
|
||||
wantErr: aghnet.ErrNoHostsPaths,
|
||||
name: "no_files",
|
||||
paths: []string{},
|
||||
}, {
|
||||
wantErr: ErrNoHostsPaths,
|
||||
wantErr: aghnet.ErrNoHostsPaths,
|
||||
name: "non-existent_file",
|
||||
paths: []string{path.Join(dirname, filename+"2")},
|
||||
}, {
|
||||
@@ -77,7 +73,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||
return eventsCh
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(0, testFS, &aghtest.FSWatcher{
|
||||
hc, err := aghnet.NewHostsContainer(0, testFS, &aghtest.FSWatcher{
|
||||
OnEvents: onEvents,
|
||||
OnAdd: onAdd,
|
||||
OnClose: func() (err error) { return nil },
|
||||
@@ -103,7 +99,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||
|
||||
t.Run("nil_fs", func(t *testing.T) {
|
||||
require.Panics(t, func() {
|
||||
_, _ = NewHostsContainer(0, nil, &aghtest.FSWatcher{
|
||||
_, _ = aghnet.NewHostsContainer(0, nil, &aghtest.FSWatcher{
|
||||
// Those shouldn't panic.
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
OnAdd: func(name string) (err error) { return nil },
|
||||
@@ -114,7 +110,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||
|
||||
t.Run("nil_watcher", func(t *testing.T) {
|
||||
require.Panics(t, func() {
|
||||
_, _ = NewHostsContainer(0, testFS, nil, p)
|
||||
_, _ = aghnet.NewHostsContainer(0, testFS, nil, p)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -127,7 +123,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||
OnClose: func() (err error) { return nil },
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(0, testFS, errWatcher, p)
|
||||
hc, err := aghnet.NewHostsContainer(0, testFS, errWatcher, p)
|
||||
require.ErrorIs(t, err, errOnAdd)
|
||||
|
||||
assert.Nil(t, hc)
|
||||
@@ -158,11 +154,11 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||
OnClose: func() (err error) { return nil },
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(0, testFS, w, "dir")
|
||||
hc, err := aghnet.NewHostsContainer(0, testFS, w, "dir")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||
|
||||
checkRefresh := func(t *testing.T, want *HostsRecord) {
|
||||
checkRefresh := func(t *testing.T, want *aghnet.HostsRecord) {
|
||||
t.Helper()
|
||||
|
||||
upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
|
||||
@@ -175,11 +171,11 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, rec)
|
||||
|
||||
assert.Truef(t, rec.equal(want), "%+v != %+v", rec, want)
|
||||
assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want)
|
||||
}
|
||||
|
||||
t.Run("initial_refresh", func(t *testing.T) {
|
||||
checkRefresh(t, &HostsRecord{
|
||||
checkRefresh(t, &aghnet.HostsRecord{
|
||||
Aliases: stringutil.NewSet(),
|
||||
Canonical: "hostname",
|
||||
})
|
||||
@@ -189,7 +185,7 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
|
||||
eventsCh <- event{}
|
||||
|
||||
checkRefresh(t, &HostsRecord{
|
||||
checkRefresh(t, &aghnet.HostsRecord{
|
||||
Aliases: stringutil.NewSet("alias"),
|
||||
Canonical: "hostname",
|
||||
})
|
||||
@@ -228,66 +224,6 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestHostsContainer_PathsToPatterns(t *testing.T) {
|
||||
gsfs := fstest.MapFS{
|
||||
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
|
||||
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
|
||||
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
paths []string
|
||||
want []string
|
||||
}{{
|
||||
name: "no_paths",
|
||||
paths: nil,
|
||||
want: nil,
|
||||
}, {
|
||||
name: "single_file",
|
||||
paths: []string{"dir_0/file_1"},
|
||||
want: []string{"dir_0/file_1"},
|
||||
}, {
|
||||
name: "several_files",
|
||||
paths: []string{"dir_0/file_1", "dir_0/file_2"},
|
||||
want: []string{"dir_0/file_1", "dir_0/file_2"},
|
||||
}, {
|
||||
name: "whole_dir",
|
||||
paths: []string{"dir_0"},
|
||||
want: []string{"dir_0/*"},
|
||||
}, {
|
||||
name: "file_and_dir",
|
||||
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
|
||||
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
|
||||
}, {
|
||||
name: "non-existing",
|
||||
paths: []string{path.Join("dir_0", "file_3")},
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
patterns, err := pathsToPatterns(gsfs, tc.paths)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, patterns)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("bad_file", func(t *testing.T) {
|
||||
const errStat errors.Error = "bad file"
|
||||
|
||||
badFS := &aghtest.StatFS{
|
||||
OnStat: func(name string) (fs.FileInfo, error) {
|
||||
return nil, errStat
|
||||
},
|
||||
}
|
||||
|
||||
_, err := pathsToPatterns(badFS, []string{""})
|
||||
assert.ErrorIs(t, err, errStat)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHostsContainer_Translate(t *testing.T) {
|
||||
stubWatcher := aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
@@ -297,7 +233,7 @@ func TestHostsContainer_Translate(t *testing.T) {
|
||||
|
||||
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
|
||||
|
||||
hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
|
||||
hc, err := aghnet.NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||
|
||||
@@ -527,7 +463,7 @@ func TestHostsContainer(t *testing.T) {
|
||||
OnClose: func() (err error) { return nil },
|
||||
}
|
||||
|
||||
hc, err := NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts")
|
||||
hc, err := aghnet.NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||
|
||||
@@ -558,69 +494,3 @@ func TestHostsContainer(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUniqueRules_ParseLine(t *testing.T) {
|
||||
ip := netutil.IPv4Localhost()
|
||||
ipStr := ip.String()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
line string
|
||||
wantIP netip.Addr
|
||||
wantHosts []string
|
||||
}{{
|
||||
name: "simple",
|
||||
line: ipStr + ` hostname`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"hostname"},
|
||||
}, {
|
||||
name: "aliases",
|
||||
line: ipStr + ` hostname alias`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"hostname", "alias"},
|
||||
}, {
|
||||
name: "invalid_line",
|
||||
line: ipStr,
|
||||
wantIP: netip.Addr{},
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "invalid_line_hostname",
|
||||
line: ipStr + ` # hostname`,
|
||||
wantIP: ip,
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "commented_aliases",
|
||||
line: ipStr + ` hostname # alias`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"hostname"},
|
||||
}, {
|
||||
name: "whole_comment",
|
||||
line: `# ` + ipStr + ` hostname`,
|
||||
wantIP: netip.Addr{},
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "partial_comment",
|
||||
line: ipStr + ` host#name`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"host"},
|
||||
}, {
|
||||
name: "empty",
|
||||
line: ``,
|
||||
wantIP: netip.Addr{},
|
||||
wantHosts: nil,
|
||||
}, {
|
||||
name: "bad_hosts",
|
||||
line: ipStr + ` bad..host bad._tld empty.tld. ok.host`,
|
||||
wantIP: ip,
|
||||
wantHosts: []string{"ok.host"},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
hp := hostsParser{}
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, hosts := hp.parseLine(tc.line)
|
||||
assert.Equal(t, tc.wantIP, got)
|
||||
assert.Equal(t, tc.wantHosts, hosts)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package aghnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -15,6 +16,10 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// DialContextFunc is the semantic alias for dialing functions, such as
|
||||
// [http.Transport.DialContext].
|
||||
type DialContextFunc = func(ctx context.Context, network, addr string) (conn net.Conn, err error)
|
||||
|
||||
// Variables and functions to substitute in tests.
|
||||
var (
|
||||
// aghosRunCommand is the function to run shell commands.
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/testutil/fakefs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -118,7 +118,7 @@ func TestIfaceSetStaticIP(t *testing.T) {
|
||||
Data: []byte(`nameserver 1.1.1.1`),
|
||||
},
|
||||
}
|
||||
panicFsys := &aghtest.FS{
|
||||
panicFsys := &fakefs.FS{
|
||||
OnOpen: func(name string) (fs.File, error) { panic("not implemented") },
|
||||
}
|
||||
|
||||
|
||||
334
internal/aghnet/net_internal_test.go
Normal file
334
internal/aghnet/net_internal_test.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testdata is the filesystem containing data for testing the package.
|
||||
var testdata fs.FS = os.DirFS("./testdata")
|
||||
|
||||
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
|
||||
// package with fsys for tests ran under t.
|
||||
func substRootDirFS(t testing.TB, fsys fs.FS) {
|
||||
t.Helper()
|
||||
|
||||
prev := rootDirFS
|
||||
t.Cleanup(func() { rootDirFS = prev })
|
||||
rootDirFS = fsys
|
||||
}
|
||||
|
||||
// RunCmdFunc is the signature of aghos.RunCommand function.
|
||||
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
|
||||
|
||||
// substShell replaces the the aghos.RunCommand function used throughout the
|
||||
// package with rc for tests ran under t.
|
||||
func substShell(t testing.TB, rc RunCmdFunc) {
|
||||
t.Helper()
|
||||
|
||||
prev := aghosRunCommand
|
||||
t.Cleanup(func() { aghosRunCommand = prev })
|
||||
aghosRunCommand = rc
|
||||
}
|
||||
|
||||
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
|
||||
// execution result. It's only needed to simplify testing.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
|
||||
type mapShell map[string]struct {
|
||||
err error
|
||||
out string
|
||||
code int
|
||||
}
|
||||
|
||||
// theOnlyCmd returns mapShell that only handles a single command and arguments
|
||||
// combination from cmd.
|
||||
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
|
||||
return mapShell{cmd: {code: code, out: out, err: err}}
|
||||
}
|
||||
|
||||
// RunCmd is a RunCmdFunc handled by s.
|
||||
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
|
||||
key := strings.Join(append([]string{cmd}, args...), " ")
|
||||
ret, ok := s[key]
|
||||
if !ok {
|
||||
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
|
||||
}
|
||||
|
||||
return ret.code, []byte(ret.out), ret.err
|
||||
}
|
||||
|
||||
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
|
||||
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
|
||||
|
||||
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
|
||||
// throughout the package with f for tests ran under t.
|
||||
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
|
||||
t.Helper()
|
||||
|
||||
prev := netInterfaceAddrs
|
||||
t.Cleanup(func() { netInterfaceAddrs = prev })
|
||||
netInterfaceAddrs = f
|
||||
}
|
||||
|
||||
func TestGatewayIP(t *testing.T) {
|
||||
const ifaceName = "ifaceName"
|
||||
const cmd = "ip route show dev " + ifaceName
|
||||
|
||||
testCases := []struct {
|
||||
shell mapShell
|
||||
want netip.Addr
|
||||
name string
|
||||
}{{
|
||||
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
|
||||
want: netip.MustParseAddr("1.2.3.4"),
|
||||
name: "success_v4",
|
||||
}, {
|
||||
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
|
||||
want: netip.MustParseAddr("::ffff"),
|
||||
name: "success_v6",
|
||||
}, {
|
||||
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
|
||||
want: netip.Addr{},
|
||||
name: "bad_output",
|
||||
}, {
|
||||
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
|
||||
want: netip.Addr{},
|
||||
name: "err_runcmd",
|
||||
}, {
|
||||
shell: theOnlyCmd(cmd, 1, "", nil),
|
||||
want: netip.Addr{},
|
||||
name: "bad_code",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
substShell(t, tc.shell.RunCmd)
|
||||
|
||||
assert.Equal(t, tc.want, GatewayIP(ifaceName))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaceByIP(t *testing.T) {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ifaces)
|
||||
|
||||
for _, iface := range ifaces {
|
||||
t.Run(iface.Name, func(t *testing.T) {
|
||||
require.NotEmpty(t, iface.Addresses)
|
||||
|
||||
for _, ip := range iface.Addresses {
|
||||
ifaceName := InterfaceByIP(ip)
|
||||
require.Equal(t, iface.Name, ifaceName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastFromIPNet(t *testing.T) {
|
||||
known4 := netip.MustParseAddr("192.168.0.1")
|
||||
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
|
||||
|
||||
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
|
||||
|
||||
testCases := []struct {
|
||||
pref netip.Prefix
|
||||
want netip.Addr
|
||||
name string
|
||||
}{{
|
||||
pref: netip.PrefixFrom(known4, 0),
|
||||
want: fullBroadcast4,
|
||||
name: "full",
|
||||
}, {
|
||||
pref: netip.PrefixFrom(known4, 20),
|
||||
want: netip.MustParseAddr("192.168.15.255"),
|
||||
name: "full",
|
||||
}, {
|
||||
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
|
||||
want: known6,
|
||||
name: "ipv6_no_mask",
|
||||
}, {
|
||||
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
|
||||
want: known4,
|
||||
name: "ipv4_no_mask",
|
||||
}, {
|
||||
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||
want: fullBroadcast4,
|
||||
name: "unspecified",
|
||||
}, {
|
||||
pref: netip.Prefix{},
|
||||
want: netip.Addr{},
|
||||
name: "invalid",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPort(t *testing.T) {
|
||||
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
|
||||
|
||||
t.Run("tcp_bound", func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", laddr.String())
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
|
||||
require.Equal(t, laddr.Addr(), ipp.Addr())
|
||||
require.NotZero(t, ipp.Port())
|
||||
|
||||
err = CheckPort("tcp", ipp)
|
||||
target := &net.OpError{}
|
||||
require.ErrorAs(t, err, &target)
|
||||
|
||||
assert.Equal(t, "listen", target.Op)
|
||||
})
|
||||
|
||||
t.Run("udp_bound", func(t *testing.T) {
|
||||
conn, err := net.ListenPacket("udp", laddr.String())
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, conn.Close)
|
||||
|
||||
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
|
||||
require.Equal(t, laddr.Addr(), ipp.Addr())
|
||||
require.NotZero(t, ipp.Port())
|
||||
|
||||
err = CheckPort("udp", ipp)
|
||||
target := &net.OpError{}
|
||||
require.ErrorAs(t, err, &target)
|
||||
|
||||
assert.Equal(t, "listen", target.Op)
|
||||
})
|
||||
|
||||
t.Run("bad_network", func(t *testing.T) {
|
||||
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("can_bind", func(t *testing.T) {
|
||||
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollectAllIfacesAddrs(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
addrs []net.Addr
|
||||
wantAddrs []string
|
||||
}{{
|
||||
name: "success",
|
||||
wantErrMsg: ``,
|
||||
addrs: []net.Addr{&net.IPNet{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
|
||||
}, &net.IPNet{
|
||||
IP: net.IP{4, 3, 2, 1},
|
||||
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
|
||||
}},
|
||||
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
|
||||
}, {
|
||||
name: "not_cidr",
|
||||
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
|
||||
addrs: []net.Addr{&net.IPAddr{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
}},
|
||||
wantAddrs: nil,
|
||||
}, {
|
||||
name: "empty",
|
||||
wantErrMsg: ``,
|
||||
addrs: []net.Addr{},
|
||||
wantAddrs: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
|
||||
|
||||
addrs, err := CollectAllIfacesAddrs()
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.wantAddrs, addrs)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("internal_error", func(t *testing.T) {
|
||||
const errAddrs errors.Error = "can't get addresses"
|
||||
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
|
||||
|
||||
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
|
||||
|
||||
_, err := CollectAllIfacesAddrs()
|
||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsAddrInUse(t *testing.T) {
|
||||
t.Run("addr_in_use", func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "0.0.0.0:0")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
|
||||
assert.True(t, IsAddrInUse(err))
|
||||
})
|
||||
|
||||
t.Run("another", func(t *testing.T) {
|
||||
const anotherErr errors.Error = "not addr in use"
|
||||
|
||||
assert.False(t, IsAddrInUse(anotherErr))
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetInterface_MarshalJSON(t *testing.T) {
|
||||
const want = `{` +
|
||||
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
|
||||
`"flags":"up|multicast",` +
|
||||
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
|
||||
`"name":"iface0",` +
|
||||
`"mtu":1500` +
|
||||
`}` + "\n"
|
||||
|
||||
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
|
||||
require.True(t, ok)
|
||||
|
||||
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
|
||||
require.True(t, ok)
|
||||
|
||||
net4 := netip.PrefixFrom(ip4, 24)
|
||||
net6 := netip.PrefixFrom(ip6, 8)
|
||||
|
||||
iface := &NetInterface{
|
||||
Addresses: []netip.Addr{ip4, ip6},
|
||||
Subnets: []netip.Prefix{net4, net6},
|
||||
Name: "iface0",
|
||||
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
|
||||
Flags: net.FlagUp | net.FlagMulticast,
|
||||
MTU: 1500,
|
||||
}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err := json.NewEncoder(b).Encode(iface)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, want, b.String())
|
||||
}
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,21 +1,11 @@
|
||||
package aghnet
|
||||
package aghnet_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -24,315 +14,3 @@ func TestMain(m *testing.M) {
|
||||
|
||||
// testdata is the filesystem containing data for testing the package.
|
||||
var testdata fs.FS = os.DirFS("./testdata")
|
||||
|
||||
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
|
||||
// package with fsys for tests ran under t.
|
||||
func substRootDirFS(t testing.TB, fsys fs.FS) {
|
||||
t.Helper()
|
||||
|
||||
prev := rootDirFS
|
||||
t.Cleanup(func() { rootDirFS = prev })
|
||||
rootDirFS = fsys
|
||||
}
|
||||
|
||||
// RunCmdFunc is the signature of aghos.RunCommand function.
|
||||
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
|
||||
|
||||
// substShell replaces the the aghos.RunCommand function used throughout the
|
||||
// package with rc for tests ran under t.
|
||||
func substShell(t testing.TB, rc RunCmdFunc) {
|
||||
t.Helper()
|
||||
|
||||
prev := aghosRunCommand
|
||||
t.Cleanup(func() { aghosRunCommand = prev })
|
||||
aghosRunCommand = rc
|
||||
}
|
||||
|
||||
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
|
||||
// execution result. It's only needed to simplify testing.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
|
||||
type mapShell map[string]struct {
|
||||
err error
|
||||
out string
|
||||
code int
|
||||
}
|
||||
|
||||
// theOnlyCmd returns mapShell that only handles a single command and arguments
|
||||
// combination from cmd.
|
||||
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
|
||||
return mapShell{cmd: {code: code, out: out, err: err}}
|
||||
}
|
||||
|
||||
// RunCmd is a RunCmdFunc handled by s.
|
||||
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
|
||||
key := strings.Join(append([]string{cmd}, args...), " ")
|
||||
ret, ok := s[key]
|
||||
if !ok {
|
||||
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
|
||||
}
|
||||
|
||||
return ret.code, []byte(ret.out), ret.err
|
||||
}
|
||||
|
||||
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
|
||||
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
|
||||
|
||||
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
|
||||
// throughout the package with f for tests ran under t.
|
||||
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
|
||||
t.Helper()
|
||||
|
||||
prev := netInterfaceAddrs
|
||||
t.Cleanup(func() { netInterfaceAddrs = prev })
|
||||
netInterfaceAddrs = f
|
||||
}
|
||||
|
||||
func TestGatewayIP(t *testing.T) {
|
||||
const ifaceName = "ifaceName"
|
||||
const cmd = "ip route show dev " + ifaceName
|
||||
|
||||
testCases := []struct {
|
||||
shell mapShell
|
||||
want netip.Addr
|
||||
name string
|
||||
}{{
|
||||
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
|
||||
want: netip.MustParseAddr("1.2.3.4"),
|
||||
name: "success_v4",
|
||||
}, {
|
||||
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
|
||||
want: netip.MustParseAddr("::ffff"),
|
||||
name: "success_v6",
|
||||
}, {
|
||||
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
|
||||
want: netip.Addr{},
|
||||
name: "bad_output",
|
||||
}, {
|
||||
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
|
||||
want: netip.Addr{},
|
||||
name: "err_runcmd",
|
||||
}, {
|
||||
shell: theOnlyCmd(cmd, 1, "", nil),
|
||||
want: netip.Addr{},
|
||||
name: "bad_code",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
substShell(t, tc.shell.RunCmd)
|
||||
|
||||
assert.Equal(t, tc.want, GatewayIP(ifaceName))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaceByIP(t *testing.T) {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ifaces)
|
||||
|
||||
for _, iface := range ifaces {
|
||||
t.Run(iface.Name, func(t *testing.T) {
|
||||
require.NotEmpty(t, iface.Addresses)
|
||||
|
||||
for _, ip := range iface.Addresses {
|
||||
ifaceName := InterfaceByIP(ip)
|
||||
require.Equal(t, iface.Name, ifaceName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastFromIPNet(t *testing.T) {
|
||||
known4 := netip.MustParseAddr("192.168.0.1")
|
||||
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
|
||||
|
||||
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
|
||||
|
||||
testCases := []struct {
|
||||
pref netip.Prefix
|
||||
want netip.Addr
|
||||
name string
|
||||
}{{
|
||||
pref: netip.PrefixFrom(known4, 0),
|
||||
want: fullBroadcast4,
|
||||
name: "full",
|
||||
}, {
|
||||
pref: netip.PrefixFrom(known4, 20),
|
||||
want: netip.MustParseAddr("192.168.15.255"),
|
||||
name: "full",
|
||||
}, {
|
||||
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
|
||||
want: known6,
|
||||
name: "ipv6_no_mask",
|
||||
}, {
|
||||
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
|
||||
want: known4,
|
||||
name: "ipv4_no_mask",
|
||||
}, {
|
||||
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||
want: fullBroadcast4,
|
||||
name: "unspecified",
|
||||
}, {
|
||||
pref: netip.Prefix{},
|
||||
want: netip.Addr{},
|
||||
name: "invalid",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPort(t *testing.T) {
|
||||
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
|
||||
|
||||
t.Run("tcp_bound", func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", laddr.String())
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
|
||||
require.Equal(t, laddr.Addr(), ipp.Addr())
|
||||
require.NotZero(t, ipp.Port())
|
||||
|
||||
err = CheckPort("tcp", ipp)
|
||||
target := &net.OpError{}
|
||||
require.ErrorAs(t, err, &target)
|
||||
|
||||
assert.Equal(t, "listen", target.Op)
|
||||
})
|
||||
|
||||
t.Run("udp_bound", func(t *testing.T) {
|
||||
conn, err := net.ListenPacket("udp", laddr.String())
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, conn.Close)
|
||||
|
||||
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
|
||||
require.Equal(t, laddr.Addr(), ipp.Addr())
|
||||
require.NotZero(t, ipp.Port())
|
||||
|
||||
err = CheckPort("udp", ipp)
|
||||
target := &net.OpError{}
|
||||
require.ErrorAs(t, err, &target)
|
||||
|
||||
assert.Equal(t, "listen", target.Op)
|
||||
})
|
||||
|
||||
t.Run("bad_network", func(t *testing.T) {
|
||||
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("can_bind", func(t *testing.T) {
|
||||
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollectAllIfacesAddrs(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
addrs []net.Addr
|
||||
wantAddrs []string
|
||||
}{{
|
||||
name: "success",
|
||||
wantErrMsg: ``,
|
||||
addrs: []net.Addr{&net.IPNet{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
|
||||
}, &net.IPNet{
|
||||
IP: net.IP{4, 3, 2, 1},
|
||||
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
|
||||
}},
|
||||
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
|
||||
}, {
|
||||
name: "not_cidr",
|
||||
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
|
||||
addrs: []net.Addr{&net.IPAddr{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
}},
|
||||
wantAddrs: nil,
|
||||
}, {
|
||||
name: "empty",
|
||||
wantErrMsg: ``,
|
||||
addrs: []net.Addr{},
|
||||
wantAddrs: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
|
||||
|
||||
addrs, err := CollectAllIfacesAddrs()
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.wantAddrs, addrs)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("internal_error", func(t *testing.T) {
|
||||
const errAddrs errors.Error = "can't get addresses"
|
||||
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
|
||||
|
||||
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
|
||||
|
||||
_, err := CollectAllIfacesAddrs()
|
||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsAddrInUse(t *testing.T) {
|
||||
t.Run("addr_in_use", func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "0.0.0.0:0")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
|
||||
assert.True(t, IsAddrInUse(err))
|
||||
})
|
||||
|
||||
t.Run("another", func(t *testing.T) {
|
||||
const anotherErr errors.Error = "not addr in use"
|
||||
|
||||
assert.False(t, IsAddrInUse(anotherErr))
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetInterface_MarshalJSON(t *testing.T) {
|
||||
const want = `{` +
|
||||
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
|
||||
`"flags":"up|multicast",` +
|
||||
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
|
||||
`"name":"iface0",` +
|
||||
`"mtu":1500` +
|
||||
`}` + "\n"
|
||||
|
||||
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
|
||||
require.True(t, ok)
|
||||
|
||||
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
|
||||
require.True(t, ok)
|
||||
|
||||
net4 := netip.PrefixFrom(ip4, 24)
|
||||
net6 := netip.PrefixFrom(ip6, 8)
|
||||
|
||||
iface := &NetInterface{
|
||||
Addresses: []netip.Addr{ip4, ip6},
|
||||
Subnets: []netip.Prefix{net4, net6},
|
||||
Name: "iface0",
|
||||
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
|
||||
Flags: net.FlagUp | net.FlagMulticast,
|
||||
MTU: 1500,
|
||||
}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
err := json.NewEncoder(b).Encode(iface)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, want, b.String())
|
||||
}
|
||||
|
||||
52
internal/aghrenameio/renameio.go
Normal file
52
internal/aghrenameio/renameio.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Package aghrenameio is a wrapper around package github.com/google/renameio/v2
|
||||
// that provides a similar stream-based API for both Unix and Windows systems.
|
||||
// While the Windows API is not technically atomic, it still provides a
|
||||
// consistent stream-based interface, and atomic renames of files do not seem to
|
||||
// be possible in all cases anyway.
|
||||
//
|
||||
// See https://github.com/google/renameio/issues/1.
|
||||
//
|
||||
// TODO(a.garipov): Consider moving to golibs/renameioutil once tried and
|
||||
// tested.
|
||||
package aghrenameio
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// PendingFile is the interface for pending temporary files.
|
||||
type PendingFile interface {
|
||||
// Cleanup closes the file, and removes it without performing the renaming.
|
||||
// To close and rename the file, use CloseReplace.
|
||||
Cleanup() (err error)
|
||||
|
||||
// CloseReplace closes the temporary file and replaces the destination file
|
||||
// with it, possibly atomically.
|
||||
//
|
||||
// This method is not safe for concurrent use by multiple goroutines.
|
||||
CloseReplace() (err error)
|
||||
|
||||
// Write writes len(b) bytes from b to the File. It returns the number of
|
||||
// bytes written and an error, if any. Write returns a non-nil error when n
|
||||
// != len(b).
|
||||
Write(b []byte) (n int, err error)
|
||||
}
|
||||
|
||||
// NewPendingFile is a wrapper around [renameio.NewPendingFile] on Unix systems
|
||||
// and [os.CreateTemp] on Windows.
|
||||
func NewPendingFile(filePath string, mode fs.FileMode) (f PendingFile, err error) {
|
||||
return newPendingFile(filePath, mode)
|
||||
}
|
||||
|
||||
// WithDeferredCleanup is a helper that performs the necessary cleanups and
|
||||
// finalizations of the temporary files based on the returned error.
|
||||
func WithDeferredCleanup(returned error, file PendingFile) (err error) {
|
||||
// Make sure that any error returned from here is marked as a deferred one.
|
||||
if returned != nil {
|
||||
return errors.WithDeferred(returned, file.Cleanup())
|
||||
}
|
||||
|
||||
return errors.WithDeferred(nil, file.CloseReplace())
|
||||
}
|
||||
101
internal/aghrenameio/renameio_test.go
Normal file
101
internal/aghrenameio/renameio_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package aghrenameio_test
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghrenameio"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testPerm is the common permission mode for tests.
|
||||
const testPerm fs.FileMode = 0o644
|
||||
|
||||
// Common file data for tests.
|
||||
var (
|
||||
initialData = []byte("initial data\n")
|
||||
newData = []byte("new data\n")
|
||||
)
|
||||
|
||||
func TestPendingFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
targetPath := newInitialFile(t)
|
||||
f, err := aghrenameio.NewPendingFile(targetPath, testPerm)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = f.Write(newData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = f.CloseReplace()
|
||||
require.NoError(t, err)
|
||||
|
||||
gotData, err := os.ReadFile(targetPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, newData, gotData)
|
||||
}
|
||||
|
||||
// newInitialFile is a test helper that returns the path to the file containing
|
||||
// [initialData].
|
||||
func newInitialFile(t *testing.T) (targetPath string) {
|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
targetPath = filepath.Join(dir, "target")
|
||||
|
||||
err := os.WriteFile(targetPath, initialData, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
return targetPath
|
||||
}
|
||||
|
||||
func TestWithDeferredCleanup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const testError errors.Error = "test error"
|
||||
|
||||
testCases := []struct {
|
||||
error error
|
||||
name string
|
||||
wantErrMsg string
|
||||
wantData []byte
|
||||
}{{
|
||||
name: "success",
|
||||
error: nil,
|
||||
wantErrMsg: "",
|
||||
wantData: newData,
|
||||
}, {
|
||||
name: "error",
|
||||
error: testError,
|
||||
wantErrMsg: testError.Error(),
|
||||
wantData: initialData,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
targetPath := newInitialFile(t)
|
||||
f, err := aghrenameio.NewPendingFile(targetPath, testPerm)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = f.Write(newData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = aghrenameio.WithDeferredCleanup(tc.error, f)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
gotData, err := os.ReadFile(targetPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantData, gotData)
|
||||
})
|
||||
}
|
||||
}
|
||||
48
internal/aghrenameio/renameio_unix.go
Normal file
48
internal/aghrenameio/renameio_unix.go
Normal file
@@ -0,0 +1,48 @@
|
||||
//go:build unix
|
||||
|
||||
package aghrenameio
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
|
||||
"github.com/google/renameio/v2"
|
||||
)
|
||||
|
||||
// pendingFile is a wrapper around [*renameio.PendingFile] making it an
|
||||
// [io.WriteCloser].
|
||||
type pendingFile struct {
|
||||
file *renameio.PendingFile
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ PendingFile = pendingFile{}
|
||||
|
||||
// Cleanup implements the [PendingFile] interface for pendingFile.
|
||||
func (f pendingFile) Cleanup() (err error) {
|
||||
return f.file.Cleanup()
|
||||
}
|
||||
|
||||
// CloseReplace implements the [PendingFile] interface for pendingFile.
|
||||
func (f pendingFile) CloseReplace() (err error) {
|
||||
return f.file.CloseAtomicallyReplace()
|
||||
}
|
||||
|
||||
// Write implements the [PendingFile] interface for pendingFile.
|
||||
func (f pendingFile) Write(b []byte) (n int, err error) {
|
||||
return f.file.Write(b)
|
||||
}
|
||||
|
||||
// NewPendingFile is a wrapper around [renameio.NewPendingFile].
|
||||
//
|
||||
// f.Close must be called to finish the renaming.
|
||||
func newPendingFile(filePath string, mode fs.FileMode) (f PendingFile, err error) {
|
||||
file, err := renameio.NewPendingFile(filePath, renameio.WithPermissions(mode))
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pendingFile{
|
||||
file: file,
|
||||
}, nil
|
||||
}
|
||||
74
internal/aghrenameio/renameio_windows.go
Normal file
74
internal/aghrenameio/renameio_windows.go
Normal file
@@ -0,0 +1,74 @@
|
||||
//go:build windows
|
||||
|
||||
package aghrenameio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// pendingFile is a wrapper around [*os.File] calling [os.Rename] in its Close
|
||||
// method.
|
||||
type pendingFile struct {
|
||||
file *os.File
|
||||
targetPath string
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ PendingFile = (*pendingFile)(nil)
|
||||
|
||||
// Cleanup implements the [PendingFile] interface for *pendingFile.
|
||||
func (f *pendingFile) Cleanup() (err error) {
|
||||
closeErr := f.file.Close()
|
||||
err = os.Remove(f.file.Name())
|
||||
|
||||
// Put closeErr into the deferred error because that's where it is usually
|
||||
// expected.
|
||||
return errors.WithDeferred(err, closeErr)
|
||||
}
|
||||
|
||||
// CloseReplace implements the [PendingFile] interface for *pendingFile.
|
||||
func (f *pendingFile) CloseReplace() (err error) {
|
||||
err = f.file.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing: %w", err)
|
||||
}
|
||||
|
||||
err = os.Rename(f.file.Name(), f.targetPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("renaming: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write implements the [PendingFile] interface for *pendingFile.
|
||||
func (f *pendingFile) Write(b []byte) (n int, err error) {
|
||||
return f.file.Write(b)
|
||||
}
|
||||
|
||||
// NewPendingFile is a wrapper around [os.CreateTemp].
|
||||
//
|
||||
// f.Close must be called to finish the renaming.
|
||||
func newPendingFile(filePath string, mode fs.FileMode) (f PendingFile, err error) {
|
||||
// Use the same directory as the file itself, because moves across
|
||||
// filesystems can be especially problematic.
|
||||
file, err := os.CreateTemp(filepath.Dir(filePath), "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening pending file: %w", err)
|
||||
}
|
||||
|
||||
err = file.Chmod(mode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing pending file: %w", err)
|
||||
}
|
||||
|
||||
return &pendingFile{
|
||||
file: file,
|
||||
targetPath: filePath,
|
||||
}, nil
|
||||
}
|
||||
@@ -2,7 +2,9 @@
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -34,3 +36,10 @@ func ReplaceLogLevel(t testing.TB, l log.Level) {
|
||||
t.Cleanup(func() { log.SetLevel(prev) })
|
||||
log.SetLevel(l)
|
||||
}
|
||||
|
||||
// HostToIPs is a helper that generates one IPv4 and one IPv6 address from host.
|
||||
func HostToIPs(host string) (ipv4, ipv6 net.IP) {
|
||||
hash := sha256.Sum256([]byte(host))
|
||||
|
||||
return net.IP(hash[:4]), net.IP(hash[4:20])
|
||||
}
|
||||
|
||||
@@ -2,11 +2,15 @@ package aghtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -15,67 +19,6 @@ import (
|
||||
//
|
||||
// Keep entities in this file in alphabetic order.
|
||||
|
||||
// Standard Library
|
||||
|
||||
// Package fs
|
||||
|
||||
// FS is a fake [fs.FS] implementation for tests.
|
||||
type FS struct {
|
||||
OnOpen func(name string) (fs.File, error)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.FS = (*FS)(nil)
|
||||
|
||||
// Open implements the [fs.FS] interface for *FS.
|
||||
func (fsys *FS) Open(name string) (fs.File, error) {
|
||||
return fsys.OnOpen(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.GlobFS = (*GlobFS)(nil)
|
||||
|
||||
// GlobFS is a fake [fs.GlobFS] implementation for tests.
|
||||
type GlobFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnGlob func(pattern string) ([]string, error)
|
||||
}
|
||||
|
||||
// Glob implements the [fs.GlobFS] interface for *GlobFS.
|
||||
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
|
||||
return fsys.OnGlob(pattern)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.StatFS = (*StatFS)(nil)
|
||||
|
||||
// StatFS is a fake [fs.StatFS] implementation for tests.
|
||||
type StatFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnStat func(name string) (fs.FileInfo, error)
|
||||
}
|
||||
|
||||
// Stat implements the [fs.StatFS] interface for *StatFS.
|
||||
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
|
||||
return fsys.OnStat(name)
|
||||
}
|
||||
|
||||
// Package io
|
||||
|
||||
// Writer is a fake [io.Writer] implementation for tests.
|
||||
type Writer struct {
|
||||
OnWrite func(b []byte) (n int, err error)
|
||||
}
|
||||
|
||||
var _ io.Writer = (*Writer)(nil)
|
||||
|
||||
// Write implements the [io.Writer] interface for *Writer.
|
||||
func (w *Writer) Write(b []byte) (n int, err error) {
|
||||
return w.OnWrite(b)
|
||||
}
|
||||
|
||||
// Module adguard-home
|
||||
|
||||
// Package aghos
|
||||
@@ -135,6 +78,71 @@ func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) {
|
||||
return s.OnConfig()
|
||||
}
|
||||
|
||||
// Package client
|
||||
|
||||
// AddressProcessor is a fake [client.AddressProcessor] implementation for
|
||||
// tests.
|
||||
type AddressProcessor struct {
|
||||
OnProcess func(ip netip.Addr)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ client.AddressProcessor = (*AddressProcessor)(nil)
|
||||
|
||||
// Process implements the [client.AddressProcessor] interface for
|
||||
// *AddressProcessor.
|
||||
func (p *AddressProcessor) Process(ip netip.Addr) {
|
||||
p.OnProcess(ip)
|
||||
}
|
||||
|
||||
// Close implements the [client.AddressProcessor] interface for
|
||||
// *AddressProcessor.
|
||||
func (p *AddressProcessor) Close() (err error) {
|
||||
return p.OnClose()
|
||||
}
|
||||
|
||||
// AddressUpdater is a fake [client.AddressUpdater] implementation for tests.
|
||||
type AddressUpdater struct {
|
||||
OnUpdateAddress func(ip netip.Addr, host string, info *whois.Info)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ client.AddressUpdater = (*AddressUpdater)(nil)
|
||||
|
||||
// UpdateAddress implements the [client.AddressUpdater] interface for
|
||||
// *AddressUpdater.
|
||||
func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
|
||||
p.OnUpdateAddress(ip, host, info)
|
||||
}
|
||||
|
||||
// Package filtering
|
||||
|
||||
// Resolver is a fake [filtering.Resolver] implementation for tests.
|
||||
type Resolver struct {
|
||||
OnLookupIP func(ctx context.Context, network, host string) (ips []net.IP, err error)
|
||||
}
|
||||
|
||||
// LookupIP implements the [filtering.Resolver] interface for *Resolver.
|
||||
func (r *Resolver) LookupIP(ctx context.Context, network, host string) (ips []net.IP, err error) {
|
||||
return r.OnLookupIP(ctx, network, host)
|
||||
}
|
||||
|
||||
// Package rdns
|
||||
|
||||
// Exchanger is a fake [rdns.Exchanger] implementation for tests.
|
||||
type Exchanger struct {
|
||||
OnExchange func(ip netip.Addr) (host string, ttl time.Duration, err error)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ rdns.Exchanger = (*Exchanger)(nil)
|
||||
|
||||
// Exchange implements [rdns.Exchanger] interface for *Exchanger.
|
||||
func (e *Exchanger) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) {
|
||||
return e.OnExchange(ip)
|
||||
}
|
||||
|
||||
// Module dnsproxy
|
||||
|
||||
// Package upstream
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
package aghtest_test
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
)
|
||||
|
||||
// Put interface checks that cause import cycles here.
|
||||
|
||||
// type check
|
||||
var _ filtering.Resolver = (*aghtest.Resolver)(nil)
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// TestResolver is a Resolver for tests.
|
||||
type TestResolver struct {
|
||||
counter int
|
||||
counterLock sync.Mutex
|
||||
}
|
||||
|
||||
// HostToIPs generates IPv4 and IPv6 from host.
|
||||
func (r *TestResolver) HostToIPs(host string) (ipv4, ipv6 net.IP) {
|
||||
hash := sha256.Sum256([]byte(host))
|
||||
|
||||
return net.IP(hash[:4]), net.IP(hash[4:20])
|
||||
}
|
||||
|
||||
// LookupIP implements Resolver interface for *testResolver. It returns the
|
||||
// slice of net.IP with IPv4 and IPv6 instances.
|
||||
func (r *TestResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
ipv4, ipv6 := r.HostToIPs(host)
|
||||
addrs := []net.IP{ipv4, ipv6}
|
||||
|
||||
r.counterLock.Lock()
|
||||
defer r.counterLock.Unlock()
|
||||
r.counter++
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver interface for *testResolver. It returns the
|
||||
// slice of IPv4 and IPv6 instances converted to strings.
|
||||
func (r *TestResolver) LookupHost(host string) (addrs []string, err error) {
|
||||
ipv4, ipv6 := r.HostToIPs(host)
|
||||
|
||||
r.counterLock.Lock()
|
||||
defer r.counterLock.Unlock()
|
||||
r.counter++
|
||||
|
||||
return []string{
|
||||
ipv4.String(),
|
||||
ipv6.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Counter returns the number of requests handled.
|
||||
func (r *TestResolver) Counter() int {
|
||||
r.counterLock.Lock()
|
||||
defer r.counterLock.Unlock()
|
||||
|
||||
return r.counter
|
||||
}
|
||||
294
internal/client/addrproc.go
Normal file
294
internal/client/addrproc.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
// ErrClosed is returned from [AddressProcessor.Close] if it's closed more than
|
||||
// once.
|
||||
const ErrClosed errors.Error = "use of closed address processor"
|
||||
|
||||
// AddressProcessor is the interface for types that can process clients.
|
||||
type AddressProcessor interface {
|
||||
Process(ip netip.Addr)
|
||||
Close() (err error)
|
||||
}
|
||||
|
||||
// EmptyAddrProc is an [AddressProcessor] that does nothing.
|
||||
type EmptyAddrProc struct{}
|
||||
|
||||
// type check
|
||||
var _ AddressProcessor = EmptyAddrProc{}
|
||||
|
||||
// Process implements the [AddressProcessor] interface for EmptyAddrProc.
|
||||
func (EmptyAddrProc) Process(_ netip.Addr) {}
|
||||
|
||||
// Close implements the [AddressProcessor] interface for EmptyAddrProc.
|
||||
func (EmptyAddrProc) Close() (_ error) { return nil }
|
||||
|
||||
// DefaultAddrProcConfig is the configuration structure for address processors.
|
||||
type DefaultAddrProcConfig struct {
|
||||
// DialContext is used to create TCP connections to WHOIS servers.
|
||||
// DialContext must not be nil if [DefaultAddrProcConfig.UseWHOIS] is true.
|
||||
DialContext aghnet.DialContextFunc
|
||||
|
||||
// Exchanger is used to perform rDNS queries. Exchanger must not be nil if
|
||||
// [DefaultAddrProcConfig.UseRDNS] is true.
|
||||
Exchanger rdns.Exchanger
|
||||
|
||||
// PrivateSubnets are used to determine if an incoming IP address is
|
||||
// private. It must not be nil.
|
||||
PrivateSubnets netutil.SubnetSet
|
||||
|
||||
// AddressUpdater is used to update the information about a client's IP
|
||||
// address. It must not be nil.
|
||||
AddressUpdater AddressUpdater
|
||||
|
||||
// InitialAddresses are the addresses that are queued for processing
|
||||
// immediately by [NewDefaultAddrProc].
|
||||
InitialAddresses []netip.Addr
|
||||
|
||||
// UseRDNS, if true, enables resolving of client IP addresses using reverse
|
||||
// DNS.
|
||||
UseRDNS bool
|
||||
|
||||
// UsePrivateRDNS, if true, enables resolving of private client IP addresses
|
||||
// using reverse DNS. See [DefaultAddrProcConfig.PrivateSubnets].
|
||||
UsePrivateRDNS bool
|
||||
|
||||
// UseWHOIS, if true, enables resolving of client IP addresses using WHOIS.
|
||||
UseWHOIS bool
|
||||
}
|
||||
|
||||
// AddressUpdater is the interface for storages of DNS clients that can update
|
||||
// information about them.
|
||||
//
|
||||
// TODO(a.garipov): Consider using the actual client storage once it is moved
|
||||
// into this package.
|
||||
type AddressUpdater interface {
|
||||
// UpdateAddress updates information about an IP address, setting host (if
|
||||
// not empty) and WHOIS information (if not nil).
|
||||
UpdateAddress(ip netip.Addr, host string, info *whois.Info)
|
||||
}
|
||||
|
||||
// DefaultAddrProc processes incoming client addresses with rDNS and WHOIS, if
|
||||
// configured, and updates that information in a client storage.
|
||||
type DefaultAddrProc struct {
|
||||
// clientIPsMu serializes closure of clientIPs and access to isClosed.
|
||||
clientIPsMu *sync.Mutex
|
||||
|
||||
// clientIPs is the channel queueing client processing tasks.
|
||||
clientIPs chan netip.Addr
|
||||
|
||||
// rdns is used to perform rDNS lookups of clients' IP addresses.
|
||||
rdns rdns.Interface
|
||||
|
||||
// whois is used to perform WHOIS lookups of clients' IP addresses.
|
||||
whois whois.Interface
|
||||
|
||||
// addrUpdater is used to update the information about a client's IP
|
||||
// address.
|
||||
addrUpdater AddressUpdater
|
||||
|
||||
// privateSubnets are used to determine if an incoming IP address is
|
||||
// private.
|
||||
privateSubnets netutil.SubnetSet
|
||||
|
||||
// isClosed is set to true once the address processor is closed.
|
||||
isClosed bool
|
||||
|
||||
// usePrivateRDNS, if true, enables resolving of private client IP addresses
|
||||
// using reverse DNS.
|
||||
usePrivateRDNS bool
|
||||
}
|
||||
|
||||
const (
|
||||
// defaultQueueSize is the size of queue of IPs for rDNS and WHOIS
|
||||
// processing.
|
||||
defaultQueueSize = 255
|
||||
|
||||
// defaultCacheSize is the maximum size of the cache for rDNS and WHOIS
|
||||
// processing. It must be greater than zero.
|
||||
defaultCacheSize = 10_000
|
||||
|
||||
// defaultIPTTL is the Time to Live duration for IP addresses cached by
|
||||
// rDNS and WHOIS.
|
||||
defaultIPTTL = 1 * time.Hour
|
||||
)
|
||||
|
||||
// NewDefaultAddrProc returns a new running client address processor. c must
|
||||
// not be nil.
|
||||
func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
|
||||
p = &DefaultAddrProc{
|
||||
clientIPsMu: &sync.Mutex{},
|
||||
clientIPs: make(chan netip.Addr, defaultQueueSize),
|
||||
rdns: &rdns.Empty{},
|
||||
addrUpdater: c.AddressUpdater,
|
||||
whois: &whois.Empty{},
|
||||
privateSubnets: c.PrivateSubnets,
|
||||
usePrivateRDNS: c.UsePrivateRDNS,
|
||||
}
|
||||
|
||||
if c.UseRDNS {
|
||||
p.rdns = rdns.New(&rdns.Config{
|
||||
Exchanger: c.Exchanger,
|
||||
CacheSize: defaultCacheSize,
|
||||
CacheTTL: defaultIPTTL,
|
||||
})
|
||||
}
|
||||
|
||||
if c.UseWHOIS {
|
||||
p.whois = newWHOIS(c.DialContext)
|
||||
}
|
||||
|
||||
go p.process()
|
||||
|
||||
for _, ip := range c.InitialAddresses {
|
||||
p.Process(ip)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// newWHOIS returns a whois.Interface instance using the given function for
|
||||
// dialing.
|
||||
func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) {
|
||||
// TODO(s.chzhen): Consider making configurable.
|
||||
const (
|
||||
// defaultTimeout is the timeout for WHOIS requests.
|
||||
defaultTimeout = 5 * time.Second
|
||||
|
||||
// defaultMaxConnReadSize is an upper limit in bytes for reading from a
|
||||
// net.Conn.
|
||||
defaultMaxConnReadSize = 64 * 1024
|
||||
|
||||
// defaultMaxRedirects is the maximum redirects count.
|
||||
defaultMaxRedirects = 5
|
||||
|
||||
// defaultMaxInfoLen is the maximum length of whois.Info fields.
|
||||
defaultMaxInfoLen = 250
|
||||
)
|
||||
|
||||
return whois.New(&whois.Config{
|
||||
DialContext: dialFunc,
|
||||
ServerAddr: whois.DefaultServer,
|
||||
Port: whois.DefaultPort,
|
||||
Timeout: defaultTimeout,
|
||||
CacheSize: defaultCacheSize,
|
||||
MaxConnReadSize: defaultMaxConnReadSize,
|
||||
MaxRedirects: defaultMaxRedirects,
|
||||
MaxInfoLen: defaultMaxInfoLen,
|
||||
CacheTTL: defaultIPTTL,
|
||||
})
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ AddressProcessor = (*DefaultAddrProc)(nil)
|
||||
|
||||
// Process implements the [AddressProcessor] interface for *DefaultAddrProc.
|
||||
func (p *DefaultAddrProc) Process(ip netip.Addr) {
|
||||
p.clientIPsMu.Lock()
|
||||
defer p.clientIPsMu.Unlock()
|
||||
|
||||
if p.isClosed {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case p.clientIPs <- ip:
|
||||
// Go on.
|
||||
default:
|
||||
log.Debug("clients: ip channel is full; len: %d", len(p.clientIPs))
|
||||
}
|
||||
}
|
||||
|
||||
// process processes the incoming client IP-address information. It is intended
|
||||
// to be used as a goroutine. Once clientIPs is closed, process exits.
|
||||
func (p *DefaultAddrProc) process() {
|
||||
defer log.OnPanic("addrProcessor.process")
|
||||
|
||||
log.Info("clients: processing addresses")
|
||||
|
||||
for ip := range p.clientIPs {
|
||||
host := p.processRDNS(ip)
|
||||
info := p.processWHOIS(ip)
|
||||
|
||||
p.addrUpdater.UpdateAddress(ip, host, info)
|
||||
}
|
||||
|
||||
log.Info("clients: finished processing addresses")
|
||||
}
|
||||
|
||||
// processRDNS resolves the clients' IP addresses using reverse DNS. host is
|
||||
// empty if there were errors or if the information hasn't changed.
|
||||
func (p *DefaultAddrProc) processRDNS(ip netip.Addr) (host string) {
|
||||
start := time.Now()
|
||||
log.Debug("clients: processing %s with rdns", ip)
|
||||
defer func() {
|
||||
log.Debug("clients: finished processing %s with rdns in %s", ip, time.Since(start))
|
||||
}()
|
||||
|
||||
ok := p.shouldResolve(ip)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
host, changed := p.rdns.Process(ip)
|
||||
if !changed {
|
||||
host = ""
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
// shouldResolve returns false if ip is a loopback address, or ip is private and
|
||||
// resolving of private addresses is disabled.
|
||||
func (p *DefaultAddrProc) shouldResolve(ip netip.Addr) (ok bool) {
|
||||
return !ip.IsLoopback() &&
|
||||
(p.usePrivateRDNS || !p.privateSubnets.Contains(ip.AsSlice()))
|
||||
}
|
||||
|
||||
// processWHOIS looks up the information about clients' IP addresses in the
|
||||
// WHOIS databases. info is nil if there were errors or if the information
|
||||
// hasn't changed.
|
||||
func (p *DefaultAddrProc) processWHOIS(ip netip.Addr) (info *whois.Info) {
|
||||
start := time.Now()
|
||||
log.Debug("clients: processing %s with whois", ip)
|
||||
defer func() {
|
||||
log.Debug("clients: finished processing %s with whois in %s", ip, time.Since(start))
|
||||
}()
|
||||
|
||||
// TODO(s.chzhen): Move the timeout logic from WHOIS configuration to the
|
||||
// context.
|
||||
info, changed := p.whois.Process(context.Background(), ip)
|
||||
if !changed {
|
||||
info = nil
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// Close implements the [AddressProcessor] interface for *DefaultAddrProc.
|
||||
func (p *DefaultAddrProc) Close() (err error) {
|
||||
p.clientIPsMu.Lock()
|
||||
defer p.clientIPsMu.Unlock()
|
||||
|
||||
if p.isClosed {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
close(p.clientIPs)
|
||||
p.isClosed = true
|
||||
|
||||
return nil
|
||||
}
|
||||
259
internal/client/addrproc_test.go
Normal file
259
internal/client/addrproc_test.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/testutil/fakenet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEmptyAddrProc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := client.EmptyAddrProc{}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
p.Process(testIP)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
err := p.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
privateIP := netip.MustParseAddr("192.168.0.1")
|
||||
|
||||
testCases := []struct {
|
||||
rdnsErr error
|
||||
ip netip.Addr
|
||||
name string
|
||||
host string
|
||||
usePrivate bool
|
||||
wantUpd bool
|
||||
}{{
|
||||
rdnsErr: nil,
|
||||
ip: testIP,
|
||||
name: "success",
|
||||
host: testHost,
|
||||
usePrivate: false,
|
||||
wantUpd: true,
|
||||
}, {
|
||||
rdnsErr: nil,
|
||||
ip: testIP,
|
||||
name: "no_host",
|
||||
host: "",
|
||||
usePrivate: false,
|
||||
wantUpd: false,
|
||||
}, {
|
||||
rdnsErr: nil,
|
||||
ip: netip.MustParseAddr("127.0.0.1"),
|
||||
name: "localhost",
|
||||
host: "",
|
||||
usePrivate: false,
|
||||
wantUpd: false,
|
||||
}, {
|
||||
rdnsErr: nil,
|
||||
ip: privateIP,
|
||||
name: "private_ignored",
|
||||
host: "",
|
||||
usePrivate: false,
|
||||
wantUpd: false,
|
||||
}, {
|
||||
rdnsErr: nil,
|
||||
ip: privateIP,
|
||||
name: "private_processed",
|
||||
host: "private.example",
|
||||
usePrivate: true,
|
||||
wantUpd: true,
|
||||
}, {
|
||||
rdnsErr: errors.Error("rdns error"),
|
||||
ip: testIP,
|
||||
name: "rdns_error",
|
||||
host: "",
|
||||
usePrivate: false,
|
||||
wantUpd: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
updIPCh := make(chan netip.Addr, 1)
|
||||
updHostCh := make(chan string, 1)
|
||||
updInfoCh := make(chan *whois.Info, 1)
|
||||
|
||||
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{
|
||||
DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
Exchanger: &aghtest.Exchanger{
|
||||
OnExchange: func(ip netip.Addr) (host string, ttl time.Duration, err error) {
|
||||
return tc.host, 0, tc.rdnsErr
|
||||
},
|
||||
},
|
||||
PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
AddressUpdater: &aghtest.AddressUpdater{
|
||||
OnUpdateAddress: newOnUpdateAddress(tc.wantUpd, updIPCh, updHostCh, updInfoCh),
|
||||
},
|
||||
UseRDNS: true,
|
||||
UsePrivateRDNS: tc.usePrivate,
|
||||
UseWHOIS: false,
|
||||
})
|
||||
testutil.CleanupAndRequireSuccess(t, p.Close)
|
||||
|
||||
p.Process(tc.ip)
|
||||
|
||||
if !tc.wantUpd {
|
||||
return
|
||||
}
|
||||
|
||||
gotIP, _ := testutil.RequireReceive(t, updIPCh, testTimeout)
|
||||
assert.Equal(t, tc.ip, gotIP)
|
||||
|
||||
gotHost, _ := testutil.RequireReceive(t, updHostCh, testTimeout)
|
||||
assert.Equal(t, tc.host, gotHost)
|
||||
|
||||
gotInfo, _ := testutil.RequireReceive(t, updInfoCh, testTimeout)
|
||||
assert.Nil(t, gotInfo)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newOnUpdateAddress is a test helper that returns a new OnUpdateAddress
|
||||
// callback using the provided channels if an update is expected and panicking
|
||||
// otherwise.
|
||||
func newOnUpdateAddress(
|
||||
want bool,
|
||||
ips chan<- netip.Addr,
|
||||
hosts chan<- string,
|
||||
infos chan<- *whois.Info,
|
||||
) (f func(ip netip.Addr, host string, info *whois.Info)) {
|
||||
return func(ip netip.Addr, host string, info *whois.Info) {
|
||||
if !want {
|
||||
panic("got unexpected update")
|
||||
}
|
||||
|
||||
ips <- ip
|
||||
hosts <- host
|
||||
infos <- info
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
wantInfo *whois.Info
|
||||
exchErr error
|
||||
name string
|
||||
wantUpd bool
|
||||
}{{
|
||||
wantInfo: &whois.Info{
|
||||
City: testWHOISCity,
|
||||
},
|
||||
exchErr: nil,
|
||||
name: "success",
|
||||
wantUpd: true,
|
||||
}, {
|
||||
wantInfo: nil,
|
||||
exchErr: nil,
|
||||
name: "no_info",
|
||||
wantUpd: false,
|
||||
}, {
|
||||
wantInfo: nil,
|
||||
exchErr: errors.Error("whois error"),
|
||||
name: "whois_error",
|
||||
wantUpd: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
whoisConn := &fakenet.Conn{
|
||||
OnClose: func() (err error) { return nil },
|
||||
OnRead: func(b []byte) (n int, err error) {
|
||||
if tc.wantInfo == nil {
|
||||
return 0, tc.exchErr
|
||||
}
|
||||
|
||||
data := "city: " + tc.wantInfo.City + "\n"
|
||||
copy(b, data)
|
||||
|
||||
return len(data), io.EOF
|
||||
},
|
||||
OnSetDeadline: func(_ time.Time) (err error) { return nil },
|
||||
OnWrite: func(b []byte) (n int, err error) { return len(b), nil },
|
||||
}
|
||||
|
||||
updIPCh := make(chan netip.Addr, 1)
|
||||
updHostCh := make(chan string, 1)
|
||||
updInfoCh := make(chan *whois.Info, 1)
|
||||
|
||||
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{
|
||||
DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) {
|
||||
return whoisConn, nil
|
||||
},
|
||||
Exchanger: &aghtest.Exchanger{
|
||||
OnExchange: func(_ netip.Addr) (_ string, _ time.Duration, _ error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
},
|
||||
PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
AddressUpdater: &aghtest.AddressUpdater{
|
||||
OnUpdateAddress: newOnUpdateAddress(tc.wantUpd, updIPCh, updHostCh, updInfoCh),
|
||||
},
|
||||
UseRDNS: false,
|
||||
UsePrivateRDNS: false,
|
||||
UseWHOIS: true,
|
||||
})
|
||||
testutil.CleanupAndRequireSuccess(t, p.Close)
|
||||
|
||||
p.Process(testIP)
|
||||
|
||||
if !tc.wantUpd {
|
||||
return
|
||||
}
|
||||
|
||||
gotIP, _ := testutil.RequireReceive(t, updIPCh, testTimeout)
|
||||
assert.Equal(t, testIP, gotIP)
|
||||
|
||||
gotHost, _ := testutil.RequireReceive(t, updHostCh, testTimeout)
|
||||
assert.Empty(t, gotHost)
|
||||
|
||||
gotInfo, _ := testutil.RequireReceive(t, updInfoCh, testTimeout)
|
||||
assert.Equal(t, tc.wantInfo, gotInfo)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAddrProc_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{})
|
||||
|
||||
err := p.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = p.Close()
|
||||
assert.ErrorIs(t, err, client.ErrClosed)
|
||||
}
|
||||
5
internal/client/client.go
Normal file
5
internal/client/client.go
Normal file
@@ -0,0 +1,5 @@
|
||||
// Package client contains types and logic dealing with AdGuard Home's DNS
|
||||
// clients.
|
||||
//
|
||||
// TODO(a.garipov): Expand.
|
||||
package client
|
||||
25
internal/client/client_test.go
Normal file
25
internal/client/client_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testHost is the common hostname for tests.
|
||||
const testHost = "client.example"
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
// testWHOISCity is the common city for tests.
|
||||
const testWHOISCity = "Brussels"
|
||||
|
||||
// testIP is the common IP address for tests.
|
||||
var testIP = netip.MustParseAddr("1.2.3.4")
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/google/renameio/maybe"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -270,7 +271,13 @@ type ServerConfig struct {
|
||||
UDPListenAddrs []*net.UDPAddr // UDP listen address
|
||||
TCPListenAddrs []*net.TCPAddr // TCP listen address
|
||||
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
|
||||
OnDNSRequest func(d *proxy.DNSContext)
|
||||
|
||||
// AddrProcConf defines the configuration for the client IP processor.
|
||||
// If nil, [client.EmptyAddrProc] is used.
|
||||
//
|
||||
// TODO(a.garipov): The use of [client.EmptyAddrProc] is a crutch for tests.
|
||||
// Remove that.
|
||||
AddrProcConf *client.DefaultAddrProcConfig
|
||||
|
||||
FilteringConfig
|
||||
TLSConfig
|
||||
@@ -298,9 +305,6 @@ type ServerConfig struct {
|
||||
// DNS64Prefixes is a slice of NAT64 prefixes to be used for DNS64.
|
||||
DNS64Prefixes []netip.Prefix
|
||||
|
||||
// ResolveClients signals if the RDNS should resolve clients' addresses.
|
||||
ResolveClients bool
|
||||
|
||||
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
|
||||
// locally-served networks should be resolved via private PTR resolvers.
|
||||
UsePrivateRDNS bool
|
||||
@@ -340,6 +344,7 @@ func (s *Server) createProxyConfig() (conf proxy.Config, err error) {
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
BeforeRequestHandler: s.beforeRequestHandler,
|
||||
RequestHandler: s.handleDNSRequest,
|
||||
HTTPSServerName: aghhttp.UserAgent(),
|
||||
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
|
||||
MaxGoroutines: int(srvConf.MaxGoroutines),
|
||||
UseDNS64: srvConf.UseDNS64,
|
||||
|
||||
57
internal/dnsforward/dialcontext.go
Normal file
57
internal/dnsforward/dialcontext.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames.
|
||||
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
log.Debug("dnsforward: dialing %q for network %q", addr, network)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
// TODO(a.garipov): Consider making configurable.
|
||||
Timeout: time.Minute * 5,
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
addrs, err := s.Resolve(host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving %q: %w", host, err)
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: resolving %q: %v", host, addrs)
|
||||
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("no addresses for host %q", host)
|
||||
}
|
||||
|
||||
var dialErrs []error
|
||||
for _, a := range addrs {
|
||||
addr = net.JoinHostPort(a.String(), port)
|
||||
conn, err = dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
dialErrs = append(dialErrs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Use errors.Join in Go 1.20.
|
||||
return nil, errors.List(fmt.Sprintf("dialing %q", addr), dialErrs...)
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
@@ -99,8 +100,17 @@ type Server struct {
|
||||
// must be a valid domain name plus dots on each side.
|
||||
localDomainSuffix string
|
||||
|
||||
ipset ipsetCtx
|
||||
privateNets netutil.SubnetSet
|
||||
ipset ipsetCtx
|
||||
privateNets netutil.SubnetSet
|
||||
|
||||
// addrProc, if not nil, is used to process clients' IP addresses with rDNS,
|
||||
// WHOIS, etc.
|
||||
addrProc client.AddressProcessor
|
||||
|
||||
// localResolvers is a DNS proxy instance used to resolve PTR records for
|
||||
// addresses considered private as per the [privateNets].
|
||||
//
|
||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||
localResolvers *proxy.Proxy
|
||||
sysResolvers aghnet.SystemResolvers
|
||||
|
||||
@@ -170,6 +180,9 @@ const (
|
||||
|
||||
// NewServer creates a new instance of the dnsforward.Server
|
||||
// Note: this function must be called only once
|
||||
//
|
||||
// TODO(a.garipov): How many constructors and initializers does this thing have?
|
||||
// Refactor!
|
||||
func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
var localDomainSuffix string
|
||||
if p.LocalDomain == "" {
|
||||
@@ -257,14 +270,25 @@ func (s *Server) WriteDiskConfig(c *FilteringConfig) {
|
||||
c.UpstreamDNS = stringutil.CloneSlice(sc.UpstreamDNS)
|
||||
}
|
||||
|
||||
// RDNSSettings returns the copy of actual RDNS configuration.
|
||||
func (s *Server) RDNSSettings() (localPTRResolvers []string, resolveClients, resolvePTR bool) {
|
||||
// LocalPTRResolvers returns the current local PTR resolver configuration.
|
||||
func (s *Server) LocalPTRResolvers() (localPTRResolvers []string) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
return stringutil.CloneSlice(s.conf.LocalPTRResolvers),
|
||||
s.conf.ResolveClients,
|
||||
s.conf.UsePrivateRDNS
|
||||
return stringutil.CloneSlice(s.conf.LocalPTRResolvers)
|
||||
}
|
||||
|
||||
// AddrProcConfig returns the current address processing configuration. Only
|
||||
// fields c.UsePrivateRDNS, c.UseRDNS, and c.UseWHOIS are filled.
|
||||
func (s *Server) AddrProcConfig() (c *client.DefaultAddrProcConfig) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
return &client.DefaultAddrProcConfig{
|
||||
UsePrivateRDNS: s.conf.UsePrivateRDNS,
|
||||
UseRDNS: s.conf.AddrProcConf.UseRDNS,
|
||||
UseWHOIS: s.conf.AddrProcConf.UseWHOIS,
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve - get IP addresses by host name from an upstream server.
|
||||
@@ -292,17 +316,13 @@ const (
|
||||
var _ rdns.Exchanger = (*Server)(nil)
|
||||
|
||||
// Exchange implements the [rdns.Exchanger] interface for *Server.
|
||||
func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
|
||||
func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if !s.conf.ResolveClients {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reversing ip: %w", err)
|
||||
return "", 0, fmt.Errorf("reversing ip: %w", err)
|
||||
}
|
||||
|
||||
arpa = dns.Fqdn(arpa)
|
||||
@@ -318,16 +338,17 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
ctx := &proxy.DNSContext{
|
||||
|
||||
dctx := &proxy.DNSContext{
|
||||
Proto: "udp",
|
||||
Req: req,
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
var resolver *proxy.Proxy
|
||||
if s.isPrivateIP(ip) {
|
||||
if s.privateNets.Contains(ip.AsSlice()) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
return "", nil
|
||||
return "", 0, nil
|
||||
}
|
||||
|
||||
resolver = s.localResolvers
|
||||
@@ -336,53 +357,48 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
|
||||
resolver = s.internalProxy
|
||||
}
|
||||
|
||||
if err = resolver.Resolve(ctx); err != nil {
|
||||
return "", err
|
||||
if err = resolver.Resolve(dctx); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
return hostFromPTR(ctx.Res)
|
||||
return hostFromPTR(dctx.Res)
|
||||
}
|
||||
|
||||
// hostFromPTR returns domain name from the PTR response or error.
|
||||
func hostFromPTR(resp *dns.Msg) (host string, err error) {
|
||||
func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) {
|
||||
// Distinguish between NODATA response and a failed request.
|
||||
if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError {
|
||||
return "", fmt.Errorf(
|
||||
return "", 0, fmt.Errorf(
|
||||
"received %s response: %w",
|
||||
dns.RcodeToString[resp.Rcode],
|
||||
ErrRDNSFailed,
|
||||
)
|
||||
}
|
||||
|
||||
var ttlSec uint32
|
||||
|
||||
for _, ans := range resp.Answer {
|
||||
ptr, ok := ans.(*dns.PTR)
|
||||
if ok {
|
||||
return strings.TrimSuffix(ptr.Ptr, "."), nil
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if ptr.Hdr.Ttl > ttlSec {
|
||||
host = ptr.Ptr
|
||||
ttlSec = ptr.Hdr.Ttl
|
||||
}
|
||||
}
|
||||
|
||||
return "", ErrRDNSNoData
|
||||
}
|
||||
if host != "" {
|
||||
// NOTE: Don't use [aghnet.NormalizeDomain] to retain original letter
|
||||
// case.
|
||||
host = strings.TrimSuffix(host, ".")
|
||||
ttl = time.Duration(ttlSec) * time.Second
|
||||
|
||||
// isPrivateIP returns true if the ip is private.
|
||||
func (s *Server) isPrivateIP(ip netip.Addr) (ok bool) {
|
||||
return s.privateNets.Contains(ip.AsSlice())
|
||||
}
|
||||
|
||||
// ShouldResolveClient returns false if ip is a loopback address, or ip is
|
||||
// private and resolving of private addresses is disabled.
|
||||
func (s *Server) ShouldResolveClient(ip netip.Addr) (ok bool) {
|
||||
if ip.IsLoopback() {
|
||||
return false
|
||||
return host, ttl, nil
|
||||
}
|
||||
|
||||
isPrivate := s.isPrivateIP(ip)
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
return s.conf.ResolveClients &&
|
||||
(s.conf.UsePrivateRDNS || !isPrivate)
|
||||
return "", 0, ErrRDNSNoData
|
||||
}
|
||||
|
||||
// Start starts the DNS server.
|
||||
@@ -457,23 +473,27 @@ func (s *Server) filterOurDNSAddrs(addrs []string) (filtered []string, err error
|
||||
return stringutil.FilterOut(addrs, ourAddrsSet.Has), nil
|
||||
}
|
||||
|
||||
// setupResolvers initializes the resolvers for local addresses. For internal
|
||||
// use only.
|
||||
func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||
// setupLocalResolvers initializes the resolvers for local addresses. For
|
||||
// internal use only.
|
||||
func (s *Server) setupLocalResolvers() (err error) {
|
||||
bootstraps := s.conf.BootstrapDNS
|
||||
if len(localAddrs) == 0 {
|
||||
localAddrs = s.sysResolvers.Get()
|
||||
resolvers := s.conf.LocalPTRResolvers
|
||||
|
||||
if len(resolvers) == 0 {
|
||||
resolvers = s.sysResolvers.Get()
|
||||
bootstraps = nil
|
||||
} else {
|
||||
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
|
||||
}
|
||||
|
||||
localAddrs, err = s.filterOurDNSAddrs(localAddrs)
|
||||
resolvers, err = s.filterOurDNSAddrs(resolvers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", localAddrs)
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
|
||||
|
||||
upsConfig, err := s.prepareUpstreamConfig(localAddrs, nil, &upstream.Options{
|
||||
uc, err := s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
@@ -486,10 +506,17 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||
|
||||
s.localResolvers = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UpstreamConfig: upsConfig,
|
||||
UpstreamConfig: uc,
|
||||
},
|
||||
}
|
||||
|
||||
if s.conf.UsePrivateRDNS &&
|
||||
// Only set the upstream config if there are any upstreams. It's safe
|
||||
// to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||
len(uc.Upstreams)+len(uc.DomainReservedUpstreams)+len(uc.SpecifiedDomainUpstreams) > 0 {
|
||||
s.dnsProxy.PrivateRDNSUpstreamConfig = uc
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -539,25 +566,48 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
return fmt.Errorf("preparing access: %w", err)
|
||||
}
|
||||
|
||||
s.registerHandlers()
|
||||
|
||||
// Set the proxy here because [setupLocalResolvers] sets its values.
|
||||
//
|
||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||
err = s.setupResolvers(s.conf.LocalPTRResolvers)
|
||||
s.dnsProxy = &proxy.Proxy{Config: proxyConfig}
|
||||
|
||||
err = s.setupLocalResolvers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up resolvers: %w", err)
|
||||
}
|
||||
|
||||
if s.conf.UsePrivateRDNS {
|
||||
proxyConfig.PrivateRDNSUpstreamConfig = s.localResolvers.UpstreamConfig
|
||||
}
|
||||
|
||||
s.dnsProxy = &proxy.Proxy{Config: proxyConfig}
|
||||
|
||||
s.recDetector.clear()
|
||||
|
||||
s.setupAddrProc()
|
||||
|
||||
s.registerHandlers()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupAddrProc initializes the address processor. For internal use only.
|
||||
func (s *Server) setupAddrProc() {
|
||||
// TODO(a.garipov): This is a crutch for tests; remove.
|
||||
if s.conf.AddrProcConf == nil {
|
||||
s.conf.AddrProcConf = &client.DefaultAddrProcConfig{}
|
||||
}
|
||||
if s.conf.AddrProcConf.AddressUpdater == nil {
|
||||
s.addrProc = client.EmptyAddrProc{}
|
||||
} else {
|
||||
c := s.conf.AddrProcConf
|
||||
c.DialContext = s.DialContext
|
||||
c.PrivateSubnets = s.privateNets
|
||||
c.UsePrivateRDNS = s.conf.UsePrivateRDNS
|
||||
s.addrProc = client.NewDefaultAddrProc(s.conf.AddrProcConf)
|
||||
|
||||
// Clear the initial addresses to not resolve them again.
|
||||
//
|
||||
// TODO(a.garipov): Consider ways of removing this once more client
|
||||
// logic is moved to package client.
|
||||
c.InitialAddresses = nil
|
||||
}
|
||||
}
|
||||
|
||||
// validateBlockingMode returns an error if the blocking mode data aren't valid.
|
||||
func validateBlockingMode(mode BlockingMode, blockingIPv4, blockingIPv6 net.IP) (err error) {
|
||||
switch mode {
|
||||
@@ -696,6 +746,11 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
|
||||
// TODO(a.garipov): This whole piece of API is weird and needs to be remade.
|
||||
if conf == nil {
|
||||
conf = &s.conf
|
||||
} else {
|
||||
closeErr := s.addrProc.Close()
|
||||
if closeErr != nil {
|
||||
log.Error("dnsforward: closing address processor: %s", closeErr)
|
||||
}
|
||||
}
|
||||
|
||||
err = s.Prepare(conf)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -39,11 +40,29 @@ func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
//
|
||||
// TODO(a.garipov): Use more.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
// testQuestionTarget is the common question target for tests.
|
||||
//
|
||||
// TODO(a.garipov): Use more.
|
||||
const testQuestionTarget = "target.example"
|
||||
|
||||
const (
|
||||
tlsServerName = "testdns.adguard.com"
|
||||
testMessagesCount = 10
|
||||
)
|
||||
|
||||
// testClientAddr is the common net.Addr for tests.
|
||||
//
|
||||
// TODO(a.garipov): Use more.
|
||||
var testClientAddr net.Addr = &net.TCPAddr{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
Port: 12345,
|
||||
}
|
||||
|
||||
func startDeferStop(t *testing.T, s *Server) {
|
||||
t.Helper()
|
||||
|
||||
@@ -53,6 +72,13 @@ func startDeferStop(t *testing.T, s *Server) {
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
}
|
||||
|
||||
// packageUpstreamVariableMu is used to serialize access to the package-level
|
||||
// variables of package upstream.
|
||||
//
|
||||
// TODO(s.chzhen): Move these parameters to upstream options and remove this
|
||||
// crutch.
|
||||
var packageUpstreamVariableMu = &sync.Mutex{}
|
||||
|
||||
func createTestServer(
|
||||
t *testing.T,
|
||||
filterConf *filtering.Config,
|
||||
@@ -61,6 +87,9 @@ func createTestServer(
|
||||
) (s *Server) {
|
||||
t.Helper()
|
||||
|
||||
packageUpstreamVariableMu.Lock()
|
||||
defer packageUpstreamVariableMu.Unlock()
|
||||
|
||||
rules := `||nxdomain.example.org
|
||||
||NULL.example.org^
|
||||
127.0.0.1 host.example.org
|
||||
@@ -307,11 +336,9 @@ func TestServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_timeout(t *testing.T) {
|
||||
const timeout time.Duration = time.Second
|
||||
|
||||
t.Run("custom", func(t *testing.T) {
|
||||
srvConf := &ServerConfig{
|
||||
UpstreamTimeout: timeout,
|
||||
UpstreamTimeout: testTimeout,
|
||||
FilteringConfig: FilteringConfig{
|
||||
BlockingMode: BlockingModeDefault,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
@@ -324,7 +351,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
err = s.Prepare(srvConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, timeout, s.conf.UpstreamTimeout)
|
||||
assert.Equal(t, testTimeout, s.conf.UpstreamTimeout)
|
||||
})
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
@@ -441,7 +468,14 @@ func TestServerRace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
resolver := &aghtest.Resolver{
|
||||
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
ip4, ip6 := aghtest.HostToIPs(host)
|
||||
|
||||
return []net.IP{ip4, ip6}, nil
|
||||
},
|
||||
}
|
||||
|
||||
safeSearchConf := filtering.SafeSearchConfig{
|
||||
Enabled: true,
|
||||
Google: true,
|
||||
@@ -480,7 +514,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
client := &dns.Client{}
|
||||
|
||||
yandexIP := net.IP{213, 180, 193, 56}
|
||||
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
|
||||
|
||||
testCases := []struct {
|
||||
host string
|
||||
@@ -545,7 +579,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||
|
||||
// Send a DNS request without question.
|
||||
_, _, err := (&dns.Client{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
Timeout: testTimeout,
|
||||
}).Exchange(&req, addr)
|
||||
|
||||
assert.NoErrorf(t, err, "got a response to an invalid query")
|
||||
@@ -928,7 +962,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
Upstream: aghtest.NewBlockUpstream(hostname, true),
|
||||
})
|
||||
|
||||
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
||||
ans4, _ := aghtest.HostToIPs(hostname)
|
||||
|
||||
filterConf := &filtering.Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
@@ -1266,25 +1300,57 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// doubleTTL is a helper function that returns a clone of DNS PTR with appended
|
||||
// copy of first answer record with doubled TTL.
|
||||
func doubleTTL(msg *dns.Msg) (resp *dns.Msg) {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(msg.Answer) == 0 {
|
||||
return msg
|
||||
}
|
||||
|
||||
rec := msg.Answer[0]
|
||||
ptr, ok := rec.(*dns.PTR)
|
||||
if !ok {
|
||||
return msg
|
||||
}
|
||||
|
||||
clone := *ptr
|
||||
clone.Hdr.Ttl *= 2
|
||||
msg.Answer = append(msg.Answer, &clone)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func TestServer_Exchange(t *testing.T) {
|
||||
const (
|
||||
onesHost = "one.one.one.one"
|
||||
twosHost = "two.two.two.two"
|
||||
localDomainHost = "local.domain"
|
||||
|
||||
defaultTTL = time.Second * 60
|
||||
)
|
||||
|
||||
var (
|
||||
onesIP = netip.MustParseAddr("1.1.1.1")
|
||||
twosIP = netip.MustParseAddr("2.2.2.2")
|
||||
localIP = netip.MustParseAddr("192.168.1.1")
|
||||
)
|
||||
|
||||
revExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
|
||||
onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
extUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "external.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, revExtIPv4, onesHost),
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost),
|
||||
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
},
|
||||
@@ -1320,53 +1386,65 @@ func TestServer_Exchange(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
srv.conf.ResolveClients = true
|
||||
srv.conf.UsePrivateRDNS = true
|
||||
|
||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
req netip.Addr
|
||||
wantErr error
|
||||
locUpstream upstream.Upstream
|
||||
req netip.Addr
|
||||
name string
|
||||
want string
|
||||
wantTTL time.Duration
|
||||
}{{
|
||||
name: "external_good",
|
||||
want: onesHost,
|
||||
wantErr: nil,
|
||||
locUpstream: nil,
|
||||
req: onesIP,
|
||||
wantTTL: defaultTTL,
|
||||
}, {
|
||||
name: "local_good",
|
||||
want: localDomainHost,
|
||||
wantErr: nil,
|
||||
locUpstream: locUpstream,
|
||||
req: localIP,
|
||||
wantTTL: defaultTTL,
|
||||
}, {
|
||||
name: "upstream_error",
|
||||
want: "",
|
||||
wantErr: aghtest.ErrUpstream,
|
||||
locUpstream: errUpstream,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "empty_answer_error",
|
||||
want: "",
|
||||
wantErr: ErrRDNSNoData,
|
||||
locUpstream: locUpstream,
|
||||
req: netip.MustParseAddr("192.168.1.2"),
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "invalid_answer",
|
||||
want: "",
|
||||
wantErr: ErrRDNSNoData,
|
||||
locUpstream: nonPtrUpstream,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "refused",
|
||||
want: "",
|
||||
wantErr: ErrRDNSFailed,
|
||||
locUpstream: refusingUpstream,
|
||||
req: localIP,
|
||||
wantTTL: 0,
|
||||
}, {
|
||||
name: "longest_ttl",
|
||||
want: twosHost,
|
||||
wantErr: nil,
|
||||
locUpstream: nil,
|
||||
req: twosIP,
|
||||
wantTTL: defaultTTL * 2,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -1380,73 +1458,20 @@ func TestServer_Exchange(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
host, eerr := srv.Exchange(tc.req)
|
||||
host, ttl, eerr := srv.Exchange(tc.req)
|
||||
|
||||
require.ErrorIs(t, eerr, tc.wantErr)
|
||||
assert.Equal(t, tc.want, host)
|
||||
assert.Equal(t, tc.wantTTL, ttl)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("resolving_disabled", func(t *testing.T) {
|
||||
srv.conf.UsePrivateRDNS = false
|
||||
|
||||
host, eerr := srv.Exchange(localIP)
|
||||
host, _, eerr := srv.Exchange(localIP)
|
||||
|
||||
require.NoError(t, eerr)
|
||||
assert.Empty(t, host)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_ShouldResolveClient(t *testing.T) {
|
||||
srv := &Server{
|
||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
ip netip.Addr
|
||||
want require.BoolAssertionFunc
|
||||
name string
|
||||
resolve bool
|
||||
usePrivate bool
|
||||
}{{
|
||||
name: "default",
|
||||
ip: netip.MustParseAddr("1.1.1.1"),
|
||||
want: require.True,
|
||||
resolve: true,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "no_rdns",
|
||||
ip: netip.MustParseAddr("1.1.1.1"),
|
||||
want: require.False,
|
||||
resolve: false,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "loopback",
|
||||
ip: netip.MustParseAddr("127.0.0.1"),
|
||||
want: require.False,
|
||||
resolve: true,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "private_resolve",
|
||||
ip: netip.MustParseAddr("192.168.0.1"),
|
||||
want: require.True,
|
||||
resolve: true,
|
||||
usePrivate: true,
|
||||
}, {
|
||||
name: "private_no_resolve",
|
||||
ip: netip.MustParseAddr("192.168.0.1"),
|
||||
want: require.False,
|
||||
resolve: true,
|
||||
usePrivate: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv.conf.ResolveClients = tc.resolve
|
||||
srv.conf.UsePrivateRDNS = tc.usePrivate
|
||||
|
||||
ok := srv.ShouldResolveClient(tc.ip)
|
||||
tc.want(t, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,10 +50,10 @@ func (s *Server) beforeRequestHandler(
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// getClientRequestFilteringSettings looks up client filtering settings using
|
||||
// the client's IP address and ID, if any, from dctx.
|
||||
func (s *Server) getClientRequestFilteringSettings(dctx *dnsContext) *filtering.Settings {
|
||||
setts := s.dnsFilter.Settings()
|
||||
// clientRequestFilteringSettings looks up client filtering settings using the
|
||||
// client's IP address and ID, if any, from dctx.
|
||||
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
|
||||
setts = s.dnsFilter.Settings()
|
||||
setts.ProtectionEnabled = dctx.protectionEnabled
|
||||
if s.conf.FilterHandler != nil {
|
||||
ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr)
|
||||
|
||||
@@ -124,7 +124,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
cacheMinTTL := s.conf.CacheMinTTL
|
||||
cacheMaxTTL := s.conf.CacheMaxTTL
|
||||
cacheOptimistic := s.conf.CacheOptimistic
|
||||
resolveClients := s.conf.ResolveClients
|
||||
resolveClients := s.conf.AddrProcConf.UseRDNS
|
||||
usePrivateRDNS := s.conf.UsePrivateRDNS
|
||||
localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers)
|
||||
|
||||
@@ -314,8 +314,6 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
setIfNotNil(&s.conf.ProtectionEnabled, dc.ProtectionEnabled)
|
||||
setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled)
|
||||
setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6)
|
||||
setIfNotNil(&s.conf.ResolveClients, dc.ResolveClients)
|
||||
setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS)
|
||||
|
||||
return s.setConfigRestartable(dc)
|
||||
}
|
||||
@@ -335,6 +333,9 @@ func setIfNotNil[T any](currentPtr, newPtr *T) (hasSet bool) {
|
||||
// setConfigRestartable sets the parameters which trigger a restart.
|
||||
// shouldRestart is true if the server should be restarted to apply changes.
|
||||
// s.serverLock is expected to be locked.
|
||||
//
|
||||
// TODO(a.garipov): Some of these could probably be updated without a restart.
|
||||
// Inspect and consider refactoring.
|
||||
func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
for _, hasSet := range []bool{
|
||||
setIfNotNil(&s.conf.UpstreamDNS, dc.Upstreams),
|
||||
@@ -347,6 +348,8 @@ func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) {
|
||||
setIfNotNil(&s.conf.CacheMinTTL, dc.CacheMinTTL),
|
||||
setIfNotNil(&s.conf.CacheMaxTTL, dc.CacheMaxTTL),
|
||||
setIfNotNil(&s.conf.CacheOptimistic, dc.CacheOptimistic),
|
||||
setIfNotNil(&s.conf.AddrProcConf.UseRDNS, dc.ResolveClients),
|
||||
setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS),
|
||||
} {
|
||||
shouldRestart = shouldRestart || hasSet
|
||||
if shouldRestart {
|
||||
|
||||
@@ -30,6 +30,7 @@ type dnsContext struct {
|
||||
setts *filtering.Settings
|
||||
|
||||
result *filtering.Result
|
||||
|
||||
// origResp is the response received from upstream. It is set when the
|
||||
// response is modified by filters.
|
||||
origResp *dns.Msg
|
||||
@@ -48,13 +49,13 @@ type dnsContext struct {
|
||||
// clientID is the ClientID from DoH, DoQ, or DoT, if provided.
|
||||
clientID string
|
||||
|
||||
// startTime is the time at which the processing of the request has started.
|
||||
startTime time.Time
|
||||
|
||||
// origQuestion is the question received from the client. It is set
|
||||
// when the request is modified by rewrites.
|
||||
origQuestion dns.Question
|
||||
|
||||
// startTime is the time at which the processing of the request has started.
|
||||
startTime time.Time
|
||||
|
||||
// protectionEnabled shows if the filtering is enabled, and if the
|
||||
// server's DNS filter is ready.
|
||||
protectionEnabled bool
|
||||
@@ -160,6 +161,22 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// mozillaFQDN is the domain used to signal the Firefox browser to not use its
|
||||
// own DoH server.
|
||||
//
|
||||
// See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet.
|
||||
const mozillaFQDN = "use-application-dns.net."
|
||||
|
||||
// healthcheckFQDN is a reserved domain-name used for healthchecking.
|
||||
//
|
||||
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
|
||||
// grant requests to register test names in the normal way to any person or
|
||||
// entity, making domain names under the .test TLD free to use in internal
|
||||
// purposes.
|
||||
//
|
||||
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
|
||||
const healthcheckFQDN = "healthcheck.adguardhome.test."
|
||||
|
||||
// processInitial terminates the following processing for some requests if
|
||||
// needed and enriches dctx with some client-specific information.
|
||||
//
|
||||
@@ -169,6 +186,8 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
defer log.Debug("dnsforward: finished processing initial")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
s.processClientIP(pctx.Addr)
|
||||
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
if s.conf.AAAADisabled && qt == dns.TypeAAAA {
|
||||
@@ -177,28 +196,13 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
if s.conf.OnDNSRequest != nil {
|
||||
s.conf.OnDNSRequest(pctx)
|
||||
}
|
||||
|
||||
// Disable Mozilla DoH.
|
||||
//
|
||||
// See https://support.mozilla.org/en-US/kb/canary-domain-use-application-dnsnet.
|
||||
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == "use-application-dns.net." {
|
||||
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
// Handle a reserved domain healthcheck.adguardhome.test.
|
||||
//
|
||||
// [Section 6.2 of RFC 6761] states that DNS Registries/Registrars must not
|
||||
// grant requests to register test names in the normal way to any person or
|
||||
// entity, making domain names under test. TLD free to use in internal
|
||||
// purposes.
|
||||
//
|
||||
// [Section 6.2 of RFC 6761]: https://www.rfc-editor.org/rfc/rfc6761.html#section-6.2
|
||||
if q.Name == "healthcheck.adguardhome.test." {
|
||||
if q.Name == healthcheckFQDN {
|
||||
// Generate a NODATA negative response to make nslookup exit with 0.
|
||||
pctx.Res = s.makeResponse(pctx.Req)
|
||||
|
||||
@@ -213,11 +217,28 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
// Get the client-specific filtering settings.
|
||||
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus()
|
||||
dctx.setts = s.getClientRequestFilteringSettings(dctx)
|
||||
dctx.setts = s.clientRequestFilteringSettings(dctx)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// processClientIP sends the client IP address to s.addrProc, if needed.
|
||||
func (s *Server) processClientIP(addr net.Addr) {
|
||||
clientIP := netutil.NetAddrToAddrPort(addr).Addr()
|
||||
if clientIP == (netip.Addr{}) {
|
||||
log.Info("dnsforward: warning: bad client addr %q", addr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Do not assign s.addrProc to a local variable to then use, since this lock
|
||||
// also serializes the closure of s.addrProc.
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
s.addrProc.Process(clientIP)
|
||||
}
|
||||
|
||||
func (s *Server) setTableHostToIP(t hostToIPTable) {
|
||||
s.tableHostToIPLock.Lock()
|
||||
defer s.tableHostToIPLock.Unlock()
|
||||
@@ -698,6 +719,18 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
|
||||
if s.conf.UsePrivateRDNS {
|
||||
s.recDetector.add(*pctx.Req)
|
||||
if err := s.localResolvers.Resolve(pctx); err != nil {
|
||||
// Generate the server failure if the private upstream configuration
|
||||
// is empty.
|
||||
//
|
||||
// TODO(e.burkov): Get rid of this crutch once the local resolvers
|
||||
// logic is moved to the dnsproxy completely.
|
||||
if errors.Is(err, upstream.ErrNoUpstreams) {
|
||||
pctx.Res = s.genServerFailure(pctx.Req)
|
||||
|
||||
// Do not even put into query log.
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
dctx.err = err
|
||||
|
||||
return resultCodeError
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -22,6 +23,96 @@ const (
|
||||
ddrTestFQDN = ddrTestDomainName + "."
|
||||
)
|
||||
|
||||
func TestServer_ProcessInitial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
target string
|
||||
wantRCode rules.RCode
|
||||
qType rules.RRType
|
||||
aaaaDisabled bool
|
||||
wantRC resultCode
|
||||
}{{
|
||||
name: "success",
|
||||
target: testQuestionTarget,
|
||||
wantRCode: -1,
|
||||
qType: dns.TypeA,
|
||||
aaaaDisabled: false,
|
||||
wantRC: resultCodeSuccess,
|
||||
}, {
|
||||
name: "aaaa_disabled",
|
||||
target: testQuestionTarget,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
qType: dns.TypeAAAA,
|
||||
aaaaDisabled: true,
|
||||
wantRC: resultCodeFinish,
|
||||
}, {
|
||||
name: "aaaa_disabled_a",
|
||||
target: testQuestionTarget,
|
||||
wantRCode: -1,
|
||||
qType: dns.TypeA,
|
||||
aaaaDisabled: true,
|
||||
wantRC: resultCodeSuccess,
|
||||
}, {
|
||||
name: "mozilla_canary",
|
||||
target: mozillaFQDN,
|
||||
wantRCode: dns.RcodeNameError,
|
||||
qType: dns.TypeA,
|
||||
aaaaDisabled: false,
|
||||
wantRC: resultCodeFinish,
|
||||
}, {
|
||||
name: "adguardhome_healthcheck",
|
||||
target: healthcheckFQDN,
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
qType: dns.TypeA,
|
||||
aaaaDisabled: false,
|
||||
wantRC: resultCodeFinish,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := ServerConfig{
|
||||
FilteringConfig: FilteringConfig{
|
||||
AAAADisabled: tc.aaaaDisabled,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
},
|
||||
}
|
||||
|
||||
s := createTestServer(t, &filtering.Config{}, c, nil)
|
||||
|
||||
var gotAddr netip.Addr
|
||||
s.addrProc = &aghtest.AddressProcessor{
|
||||
OnProcess: func(ip netip.Addr) { gotAddr = ip },
|
||||
OnClose: func() (err error) { panic("not implemented") },
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: createTestMessageWithType(tc.target, tc.qType),
|
||||
Addr: testClientAddr,
|
||||
RequestID: 1234,
|
||||
},
|
||||
}
|
||||
|
||||
gotRC := s.processInitial(dctx)
|
||||
assert.Equal(t, tc.wantRC, gotRC)
|
||||
assert.Equal(t, netutil.NetAddrToAddrPort(testClientAddr).Addr(), gotAddr)
|
||||
|
||||
if tc.wantRCode > 0 {
|
||||
gotResp := dctx.proxyCtx.Res
|
||||
require.NotNil(t, gotResp)
|
||||
|
||||
assert.Equal(t, tc.wantRCode, gotResp.Rcode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
dohSVCB := &dns.SVCB{
|
||||
Priority: 1,
|
||||
@@ -64,7 +155,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
}{{
|
||||
name: "pass_host",
|
||||
wantRes: resultCodeSuccess,
|
||||
host: "example.net.",
|
||||
host: testQuestionTarget,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8043,
|
||||
@@ -234,33 +325,33 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||
testCases := []struct {
|
||||
wantIP netip.Addr
|
||||
name string
|
||||
host string
|
||||
wantIP netip.Addr
|
||||
wantRes resultCode
|
||||
isLocalCli bool
|
||||
}{{
|
||||
wantIP: knownIP,
|
||||
name: "local_client_success",
|
||||
host: "example.lan",
|
||||
wantIP: knownIP,
|
||||
wantRes: resultCodeSuccess,
|
||||
isLocalCli: true,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "local_client_unknown_host",
|
||||
host: "wronghost.lan",
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
isLocalCli: true,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "external_client_known_host",
|
||||
host: "example.lan",
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeFinish,
|
||||
isLocalCli: false,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "external_client_unknown_host",
|
||||
host: "wronghost.lan",
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeFinish,
|
||||
isLocalCli: false,
|
||||
}}
|
||||
@@ -332,52 +423,52 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
|
||||
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||
testCases := []struct {
|
||||
wantIP netip.Addr
|
||||
name string
|
||||
host string
|
||||
suffix string
|
||||
wantIP netip.Addr
|
||||
wantRes resultCode
|
||||
qtyp uint16
|
||||
}{{
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_external",
|
||||
host: examplecom,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_external_non_a",
|
||||
host: examplecom,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeCNAME,
|
||||
}, {
|
||||
wantIP: knownIP,
|
||||
name: "success_internal",
|
||||
host: examplelan,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: knownIP,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_internal_unknown",
|
||||
host: "example-new.lan",
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}, {
|
||||
wantIP: netip.Addr{},
|
||||
name: "success_internal_aaaa",
|
||||
host: examplelan,
|
||||
suffix: defaultLocalDomainSuffix,
|
||||
wantIP: netip.Addr{},
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
wantIP: knownIP,
|
||||
name: "success_custom_suffix",
|
||||
host: "example.custom",
|
||||
suffix: "custom",
|
||||
wantIP: knownIP,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}}
|
||||
@@ -560,10 +651,8 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
var dnsCtx *dnsContext
|
||||
setup := func(use bool) {
|
||||
proxyCtx = &proxy.DNSContext{
|
||||
Addr: &net.TCPAddr{
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
},
|
||||
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
|
||||
Addr: testClientAddr,
|
||||
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
|
||||
}
|
||||
dnsCtx = &dnsContext{
|
||||
proxyCtx: proxyCtx,
|
||||
@@ -42,11 +42,13 @@ func (s *Server) loadUpstreams() (upstreams []string, err error) {
|
||||
|
||||
// prepareUpstreamSettings sets upstream DNS server settings.
|
||||
func (s *Server) prepareUpstreamSettings() (err error) {
|
||||
// We're setting a customized set of RootCAs. The reason is that Go default
|
||||
// mechanism of loading TLS roots does not always work properly on some
|
||||
// routers so we're loading roots manually and pass it here.
|
||||
// Use a customized set of RootCAs, because Go's default mechanism of
|
||||
// loading TLS roots does not always work properly on some routers so we're
|
||||
// loading roots manually and pass it here.
|
||||
//
|
||||
// See [aghtls.SystemRootCAs].
|
||||
//
|
||||
// TODO(a.garipov): Investigate if that's true.
|
||||
upstream.RootCAs = s.conf.TLSv12Roots
|
||||
upstream.CipherSuites = s.conf.TLSCiphers
|
||||
|
||||
@@ -190,7 +192,7 @@ func (s *Server) resolveUpstreamsWithHosts(
|
||||
|
||||
// extractUpstreamHost returns the hostname of addr without port with an
|
||||
// assumption that any address passed here has already been successfully parsed
|
||||
// by [upstream.AddressToUpstream]. This function eesentially mirrors the logic
|
||||
// by [upstream.AddressToUpstream]. This function essentially mirrors the logic
|
||||
// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts].
|
||||
func extractUpstreamHost(addr string) (host string) {
|
||||
var err error
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghrenameio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -83,53 +84,53 @@ func (d *DNSFilter) filterSetProperties(
|
||||
filters = d.WhitelistFilters
|
||||
}
|
||||
|
||||
i := slices.IndexFunc(filters, func(filt FilterYAML) bool { return filt.URL == listURL })
|
||||
i := slices.IndexFunc(filters, func(flt FilterYAML) bool { return flt.URL == listURL })
|
||||
if i == -1 {
|
||||
return false, errFilterNotExist
|
||||
}
|
||||
|
||||
filt := &filters[i]
|
||||
flt := &filters[i]
|
||||
log.Debug(
|
||||
"filtering: set name to %q, url to %s, enabled to %t for filter %s",
|
||||
newList.Name,
|
||||
newList.URL,
|
||||
newList.Enabled,
|
||||
filt.URL,
|
||||
flt.URL,
|
||||
)
|
||||
|
||||
defer func(oldURL, oldName string, oldEnabled bool, oldUpdated time.Time, oldRulesCount int) {
|
||||
if err != nil {
|
||||
filt.URL = oldURL
|
||||
filt.Name = oldName
|
||||
filt.Enabled = oldEnabled
|
||||
filt.LastUpdated = oldUpdated
|
||||
filt.RulesCount = oldRulesCount
|
||||
flt.URL = oldURL
|
||||
flt.Name = oldName
|
||||
flt.Enabled = oldEnabled
|
||||
flt.LastUpdated = oldUpdated
|
||||
flt.RulesCount = oldRulesCount
|
||||
}
|
||||
}(filt.URL, filt.Name, filt.Enabled, filt.LastUpdated, filt.RulesCount)
|
||||
}(flt.URL, flt.Name, flt.Enabled, flt.LastUpdated, flt.RulesCount)
|
||||
|
||||
filt.Name = newList.Name
|
||||
flt.Name = newList.Name
|
||||
|
||||
if filt.URL != newList.URL {
|
||||
if flt.URL != newList.URL {
|
||||
if d.filterExistsLocked(newList.URL) {
|
||||
return false, errFilterExists
|
||||
}
|
||||
|
||||
shouldRestart = true
|
||||
|
||||
filt.URL = newList.URL
|
||||
filt.LastUpdated = time.Time{}
|
||||
filt.unload()
|
||||
flt.URL = newList.URL
|
||||
flt.LastUpdated = time.Time{}
|
||||
flt.unload()
|
||||
}
|
||||
|
||||
if filt.Enabled != newList.Enabled {
|
||||
filt.Enabled = newList.Enabled
|
||||
if flt.Enabled != newList.Enabled {
|
||||
flt.Enabled = newList.Enabled
|
||||
shouldRestart = true
|
||||
}
|
||||
|
||||
if filt.Enabled {
|
||||
if flt.Enabled {
|
||||
if shouldRestart {
|
||||
// Download the filter contents.
|
||||
shouldRestart, err = d.update(filt)
|
||||
shouldRestart, err = d.update(flt)
|
||||
}
|
||||
} else {
|
||||
// TODO(e.burkov): The validation of the contents of the new URL is
|
||||
@@ -137,7 +138,7 @@ func (d *DNSFilter) filterSetProperties(
|
||||
// possible to set a bad rules source, but the validation should still
|
||||
// kick in when the filter is enabled. Consider changing this behavior
|
||||
// to be stricter.
|
||||
filt.unload()
|
||||
flt.unload()
|
||||
}
|
||||
|
||||
return shouldRestart, err
|
||||
@@ -250,24 +251,24 @@ func assignUniqueFilterID() int64 {
|
||||
// Sets up a timer that will be checking for filters updates periodically
|
||||
func (d *DNSFilter) periodicallyRefreshFilters() {
|
||||
const maxInterval = 1 * 60 * 60
|
||||
intval := 5 // use a dynamically increasing time interval
|
||||
ivl := 5 // use a dynamically increasing time interval
|
||||
for {
|
||||
isNetErr, ok := false, false
|
||||
if d.FiltersUpdateIntervalHours != 0 {
|
||||
_, isNetErr, ok = d.tryRefreshFilters(true, true, false)
|
||||
if ok && !isNetErr {
|
||||
intval = maxInterval
|
||||
ivl = maxInterval
|
||||
}
|
||||
}
|
||||
|
||||
if isNetErr {
|
||||
intval *= 2
|
||||
if intval > maxInterval {
|
||||
intval = maxInterval
|
||||
ivl *= 2
|
||||
if ivl > maxInterval {
|
||||
ivl = maxInterval
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(intval) * time.Second)
|
||||
time.Sleep(time.Duration(ivl) * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -329,20 +330,20 @@ func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int,
|
||||
return 0, nil, nil, false
|
||||
}
|
||||
|
||||
nfail := 0
|
||||
failNum := 0
|
||||
for i := range updateFilters {
|
||||
uf := &updateFilters[i]
|
||||
updated, err := d.update(uf)
|
||||
updateFlags = append(updateFlags, updated)
|
||||
if err != nil {
|
||||
nfail++
|
||||
log.Info("filtering: updating filter from url %q: %s\n", uf.URL, err)
|
||||
failNum++
|
||||
log.Error("filtering: updating filter from url %q: %s\n", uf.URL, err)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if nfail == len(updateFilters) {
|
||||
if failNum == len(updateFilters) {
|
||||
return 0, nil, nil, true
|
||||
}
|
||||
|
||||
@@ -464,48 +465,6 @@ func (d *DNSFilter) update(filter *FilterYAML) (b bool, err error) {
|
||||
return b, err
|
||||
}
|
||||
|
||||
// finalizeUpdate closes and gets rid of temporary file f with filter's content
|
||||
// according to updated. It also saves new values of flt's name, rules number
|
||||
// and checksum if succeeded.
|
||||
func (d *DNSFilter) finalizeUpdate(
|
||||
file *os.File,
|
||||
flt *FilterYAML,
|
||||
updated bool,
|
||||
res *rulelist.ParseResult,
|
||||
) (err error) {
|
||||
tmpFileName := file.Name()
|
||||
|
||||
// Close the file before renaming it because it's required on Windows.
|
||||
//
|
||||
// See https://github.com/adguardTeam/adGuardHome/issues/1553.
|
||||
err = file.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing temporary file: %w", err)
|
||||
}
|
||||
|
||||
if !updated {
|
||||
log.Debug("filtering: filter %d from url %q has no changes, skipping", flt.ID, flt.URL)
|
||||
|
||||
return os.Remove(tmpFileName)
|
||||
}
|
||||
|
||||
fltPath := flt.Path(d.DataDir)
|
||||
|
||||
log.Info("filtering: saving contents of filter %d into %q", flt.ID, fltPath)
|
||||
|
||||
// Don't use renameio or maybe packages, since those will require loading
|
||||
// the whole filter content to the memory on Windows.
|
||||
err = os.Rename(tmpFileName, fltPath)
|
||||
if err != nil {
|
||||
return errors.WithDeferred(err, os.Remove(tmpFileName))
|
||||
}
|
||||
|
||||
flt.Name = aghalg.Coalesce(flt.Name, res.Title)
|
||||
flt.checksum, flt.RulesCount = res.Checksum, res.RulesCount
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateIntl updates the flt rewriting it's actual file. It returns true if
|
||||
// the actual update has been performed.
|
||||
func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
@@ -513,63 +472,22 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
|
||||
var res *rulelist.ParseResult
|
||||
|
||||
var tmpFile *os.File
|
||||
tmpFile, err = os.CreateTemp(filepath.Join(d.DataDir, filterDir), "")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
finErr := d.finalizeUpdate(tmpFile, flt, ok, res)
|
||||
if ok && finErr == nil {
|
||||
log.Info(
|
||||
"filtering: updated filter %d: %d bytes, %d rules",
|
||||
flt.ID,
|
||||
res.BytesWritten,
|
||||
res.RulesCount,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = errors.WithDeferred(err, finErr)
|
||||
}()
|
||||
|
||||
// Change the default 0o600 permission to something more acceptable by end
|
||||
// users.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/3198.
|
||||
if err = tmpFile.Chmod(0o644); err != nil {
|
||||
return false, fmt.Errorf("changing file mode: %w", err)
|
||||
tmpFile, err := aghrenameio.NewPendingFile(flt.Path(d.DataDir), 0o644)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() { err = d.finalizeUpdate(tmpFile, flt, res, err, ok) }()
|
||||
|
||||
var r io.Reader
|
||||
if !filepath.IsAbs(flt.URL) {
|
||||
var resp *http.Response
|
||||
resp, err = d.HTTPClient.Get(flt.URL)
|
||||
if err != nil {
|
||||
log.Info("filtering: requesting filter from %q: %s, skipping", flt.URL, err)
|
||||
|
||||
return false, err
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Info("filtering got status code %d from %q, skipping", resp.StatusCode, flt.URL)
|
||||
|
||||
return false, fmt.Errorf("got status code %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
r = resp.Body
|
||||
} else {
|
||||
var f *os.File
|
||||
f, err = os.Open(flt.URL)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||
|
||||
r = f
|
||||
r, err := d.reader(flt.URL)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return false, err
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, r.Close()) }()
|
||||
|
||||
bufPtr := d.bufPool.Get().(*[]byte)
|
||||
defer d.bufPool.Put(bufPtr)
|
||||
@@ -580,6 +498,78 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
return res.Checksum != flt.checksum && err == nil, err
|
||||
}
|
||||
|
||||
// finalizeUpdate closes and gets rid of temporary file f with filter's content
|
||||
// according to updated. It also saves new values of flt's name, rules number
|
||||
// and checksum if succeeded.
|
||||
func (d *DNSFilter) finalizeUpdate(
|
||||
file aghrenameio.PendingFile,
|
||||
flt *FilterYAML,
|
||||
res *rulelist.ParseResult,
|
||||
returned error,
|
||||
updated bool,
|
||||
) (err error) {
|
||||
id := flt.ID
|
||||
if !updated {
|
||||
if returned == nil {
|
||||
log.Debug("filtering: filter %d from url %q has no changes, skipping", id, flt.URL)
|
||||
}
|
||||
|
||||
return errors.WithDeferred(returned, file.Cleanup())
|
||||
}
|
||||
|
||||
log.Info("filtering: saving contents of filter %d into %q", id, flt.Path(d.DataDir))
|
||||
|
||||
err = file.CloseReplace()
|
||||
if err != nil {
|
||||
return fmt.Errorf("finalizing update: %w", err)
|
||||
}
|
||||
|
||||
rulesCount := res.RulesCount
|
||||
log.Info("filtering: updated filter %d: %d bytes, %d rules", id, res.BytesWritten, rulesCount)
|
||||
|
||||
flt.Name = aghalg.Coalesce(flt.Name, res.Title)
|
||||
flt.checksum = res.Checksum
|
||||
flt.RulesCount = rulesCount
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reader returns an io.ReadCloser reading filtering-rule list data form either
|
||||
// a file on the filesystem or the filter's HTTP URL.
|
||||
func (d *DNSFilter) reader(fltURL string) (r io.ReadCloser, err error) {
|
||||
if !filepath.IsAbs(fltURL) {
|
||||
r, err = d.readerFromURL(fltURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading from url: %w", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
r, err = os.Open(fltURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening file: %w", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// readerFromURL returns an io.ReadCloser reading filtering-rule list data form
|
||||
// the filter's URL.
|
||||
func (d *DNSFilter) readerFromURL(fltURL string) (r io.ReadCloser, err error) {
|
||||
resp, err := d.HTTPClient.Get(fltURL)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("got status code %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
// loads filter contents from the file in dataDir
|
||||
func (d *DNSFilter) load(flt *FilterYAML) (err error) {
|
||||
fileName := flt.Path(d.DataDir)
|
||||
|
||||
@@ -943,7 +943,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
d = &DNSFilter{
|
||||
bufPool: &sync.Pool{
|
||||
New: func() (buf any) {
|
||||
bufVal := make([]byte, rulelist.MaxRuleLen)
|
||||
bufVal := make([]byte, rulelist.DefaultRuleBufSize)
|
||||
|
||||
return &bufVal
|
||||
},
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"unicode"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Parser is a filtering-rule parser that collects data, such as the checksum
|
||||
@@ -48,19 +48,29 @@ type ParseResult struct {
|
||||
// nil.
|
||||
func (p *Parser) Parse(dst io.Writer, src io.Reader, buf []byte) (r *ParseResult, err error) {
|
||||
s := bufio.NewScanner(src)
|
||||
s.Buffer(buf, MaxRuleLen)
|
||||
|
||||
lineIdx := 0
|
||||
// Don't use [DefaultRuleBufSize] as the maximum size, since some
|
||||
// filtering-rule lists compressed by e.g. HostlistsCompiler can have very
|
||||
// large lines. The buffer optimization still works for the more common
|
||||
// case of reasonably-sized lines.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/6003.
|
||||
s.Buffer(buf, bufio.MaxScanTokenSize)
|
||||
|
||||
// Use a one-based index for lines and columns, since these errors end up in
|
||||
// the frontend, and users are more familiar with one-based line and column
|
||||
// indexes.
|
||||
lineNum := 1
|
||||
for s.Scan() {
|
||||
var n int
|
||||
n, err = p.processLine(dst, s.Bytes(), lineIdx)
|
||||
n, err = p.processLine(dst, s.Bytes(), lineNum)
|
||||
p.written += n
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return p.result(), err
|
||||
}
|
||||
|
||||
lineIdx++
|
||||
lineNum++
|
||||
}
|
||||
|
||||
r = p.result()
|
||||
@@ -81,7 +91,7 @@ func (p *Parser) result() (r *ParseResult) {
|
||||
|
||||
// processLine processes a single line. It may write to dst, and if it does, n
|
||||
// is the number of bytes written.
|
||||
func (p *Parser) processLine(dst io.Writer, line []byte, lineIdx int) (n int, err error) {
|
||||
func (p *Parser) processLine(dst io.Writer, line []byte, lineNum int) (n int, err error) {
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if p.written == 0 && isHTMLLine(trimmed) {
|
||||
return 0, ErrHTML
|
||||
@@ -95,9 +105,10 @@ func (p *Parser) processLine(dst io.Writer, line []byte, lineIdx int) (n int, er
|
||||
}
|
||||
if badIdx != -1 {
|
||||
return 0, fmt.Errorf(
|
||||
"line at index %d: character at index %d: non-printable character",
|
||||
lineIdx,
|
||||
badIdx+bytes.Index(line, trimmed),
|
||||
"line %d: character %d: likely binary character %q",
|
||||
lineNum,
|
||||
badIdx+bytes.Index(line, trimmed)+1,
|
||||
trimmed[badIdx],
|
||||
)
|
||||
}
|
||||
|
||||
@@ -130,41 +141,37 @@ func hasPrefixFold(b, prefix []byte) (ok bool) {
|
||||
}
|
||||
|
||||
// parseLine returns true if the parsed line is a filtering rule. line is
|
||||
// assumed to be trimmed of whitespace characters. nonPrintIdx is the index of
|
||||
// the first non-printable character, if any; if there are none, nonPrintIdx is
|
||||
// -1.
|
||||
// assumed to be trimmed of whitespace characters. badIdx is the index of the
|
||||
// first character that may indicate that this is a binary file, or -1 if none.
|
||||
//
|
||||
// A line is considered a rule if it's not empty, not a comment, and contains
|
||||
// only printable characters.
|
||||
func parseLine(line []byte) (nonPrintIdx int, isRule bool) {
|
||||
func parseLine(line []byte) (badIdx int, isRule bool) {
|
||||
if len(line) == 0 || line[0] == '#' || line[0] == '!' {
|
||||
return -1, false
|
||||
}
|
||||
|
||||
nonPrintIdx = bytes.IndexFunc(line, isNotPrintable)
|
||||
badIdx = slices.IndexFunc(line, likelyBinary)
|
||||
|
||||
return nonPrintIdx, nonPrintIdx == -1
|
||||
return badIdx, badIdx == -1
|
||||
}
|
||||
|
||||
// isNotPrintable returns true if r is not a printable character that can be
|
||||
// contained in a filtering rule.
|
||||
func isNotPrintable(r rune) (ok bool) {
|
||||
// Tab isn't included into Unicode's graphic symbols, so include it here
|
||||
// explicitly.
|
||||
return r != '\t' && !unicode.IsGraphic(r)
|
||||
// likelyBinary returns true if b is likely to be a byte from a binary file.
|
||||
func likelyBinary(b byte) (ok bool) {
|
||||
return (b < ' ' || b == 0x7f) && b != '\n' && b != '\r' && b != '\t'
|
||||
}
|
||||
|
||||
// parseLineTitle is like [parseLine] but additionally looks for a title. line
|
||||
// is assumed to be trimmed of whitespace characters.
|
||||
func (p *Parser) parseLineTitle(line []byte) (nonPrintIdx int, isRule bool) {
|
||||
func (p *Parser) parseLineTitle(line []byte) (badIdx int, isRule bool) {
|
||||
if len(line) == 0 || line[0] == '#' {
|
||||
return -1, false
|
||||
}
|
||||
|
||||
if line[0] != '!' {
|
||||
nonPrintIdx = bytes.IndexFunc(line, isNotPrintable)
|
||||
badIdx = slices.IndexFunc(line, likelyBinary)
|
||||
|
||||
return nonPrintIdx, nonPrintIdx == -1
|
||||
return badIdx, badIdx == -1
|
||||
}
|
||||
|
||||
const titlePattern = "! Title: "
|
||||
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/testutil/fakeio"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -17,6 +17,9 @@ import (
|
||||
func TestParser_Parse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
longRule := strings.Repeat("a", rulelist.DefaultRuleBufSize+1) + "\n"
|
||||
tooLongRule := strings.Repeat("a", bufio.MaxScanTokenSize+1) + "\n"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
in string
|
||||
@@ -74,26 +77,42 @@ func TestParser_Parse(t *testing.T) {
|
||||
wantTitle: "Test Title",
|
||||
wantRulesNum: 1,
|
||||
wantWritten: len(testRuleTextBlocked),
|
||||
}, {
|
||||
name: "cosmetic_with_zwnj",
|
||||
in: testRuleTextCosmetic,
|
||||
wantDst: testRuleTextCosmetic,
|
||||
wantErrMsg: "",
|
||||
wantTitle: "",
|
||||
wantRulesNum: 1,
|
||||
wantWritten: len(testRuleTextCosmetic),
|
||||
}, {
|
||||
name: "bad_char",
|
||||
in: "! Title: Test Title \n" +
|
||||
testRuleTextBlocked +
|
||||
">>>\x7F<<<",
|
||||
wantDst: testRuleTextBlocked,
|
||||
wantErrMsg: "line at index 2: " +
|
||||
"character at index 3: " +
|
||||
"non-printable character",
|
||||
wantErrMsg: "line 3: " +
|
||||
"character 4: " +
|
||||
"likely binary character '\\x7f'",
|
||||
wantTitle: "Test Title",
|
||||
wantRulesNum: 1,
|
||||
wantWritten: len(testRuleTextBlocked),
|
||||
}, {
|
||||
name: "too_long",
|
||||
in: strings.Repeat("a", rulelist.MaxRuleLen+1),
|
||||
in: tooLongRule,
|
||||
wantDst: "",
|
||||
wantErrMsg: "scanning filter contents: " + bufio.ErrTooLong.Error(),
|
||||
wantErrMsg: "scanning filter contents: bufio.Scanner: token too long",
|
||||
wantTitle: "",
|
||||
wantRulesNum: 0,
|
||||
wantWritten: 0,
|
||||
}, {
|
||||
name: "longer_than_default",
|
||||
in: longRule,
|
||||
wantDst: longRule,
|
||||
wantErrMsg: "",
|
||||
wantTitle: "",
|
||||
wantRulesNum: 1,
|
||||
wantWritten: len(longRule),
|
||||
}, {
|
||||
name: "bad_tab_and_comment",
|
||||
in: testRuleTextBadTab,
|
||||
@@ -118,7 +137,7 @@ func TestParser_Parse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dst := &bytes.Buffer{}
|
||||
buf := make([]byte, rulelist.MaxRuleLen)
|
||||
buf := make([]byte, rulelist.DefaultRuleBufSize)
|
||||
|
||||
p := rulelist.NewParser()
|
||||
r, err := p.Parse(dst, strings.NewReader(tc.in), buf)
|
||||
@@ -140,12 +159,12 @@ func TestParser_Parse(t *testing.T) {
|
||||
func TestParser_Parse_writeError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dst := &aghtest.Writer{
|
||||
dst := &fakeio.Writer{
|
||||
OnWrite: func(b []byte) (n int, err error) {
|
||||
return 1, errors.Error("test error")
|
||||
},
|
||||
}
|
||||
buf := make([]byte, rulelist.MaxRuleLen)
|
||||
buf := make([]byte, rulelist.DefaultRuleBufSize)
|
||||
|
||||
p := rulelist.NewParser()
|
||||
r, err := p.Parse(dst, strings.NewReader(testRuleTextBlocked), buf)
|
||||
@@ -165,7 +184,7 @@ func TestParser_Parse_checksums(t *testing.T) {
|
||||
"# Another comment.\n"
|
||||
)
|
||||
|
||||
buf := make([]byte, rulelist.MaxRuleLen)
|
||||
buf := make([]byte, rulelist.DefaultRuleBufSize)
|
||||
|
||||
p := rulelist.NewParser()
|
||||
r, err := p.Parse(&bytes.Buffer{}, strings.NewReader(withoutComments), buf)
|
||||
@@ -192,7 +211,7 @@ var (
|
||||
func BenchmarkParser_Parse(b *testing.B) {
|
||||
dst := &bytes.Buffer{}
|
||||
src := strings.NewReader(strings.Repeat(testRuleTextBlocked, 1000))
|
||||
buf := make([]byte, rulelist.MaxRuleLen)
|
||||
buf := make([]byte, rulelist.DefaultRuleBufSize)
|
||||
p := rulelist.NewParser()
|
||||
|
||||
b.ReportAllocs()
|
||||
@@ -204,6 +223,14 @@ func BenchmarkParser_Parse(b *testing.B) {
|
||||
|
||||
require.NoError(b, errSink)
|
||||
require.NotNil(b, resSink)
|
||||
|
||||
// Most recent result, on a ThinkPad X13 with a Ryzen Pro 7 CPU:
|
||||
//
|
||||
// goos: linux
|
||||
// goarch: amd64
|
||||
// pkg: github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist
|
||||
// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
|
||||
// BenchmarkParser_Parse-16 100000000 128.0 ns/op 48 B/op 1 allocs/op
|
||||
}
|
||||
|
||||
func FuzzParser_Parse(f *testing.F) {
|
||||
@@ -215,15 +242,17 @@ func FuzzParser_Parse(f *testing.F) {
|
||||
"! Comment",
|
||||
"! Title ",
|
||||
"! Title XXX",
|
||||
testRuleTextBadTab,
|
||||
testRuleTextBlocked,
|
||||
testRuleTextCosmetic,
|
||||
testRuleTextEtcHostsTab,
|
||||
testRuleTextHTML,
|
||||
testRuleTextBlocked,
|
||||
testRuleTextBadTab,
|
||||
"1.2.3.4",
|
||||
"1.2.3.4 etc-hosts.example",
|
||||
">>>\x00<<<",
|
||||
">>>\x7F<<<",
|
||||
strings.Repeat("a", n+1),
|
||||
strings.Repeat("a", rulelist.DefaultRuleBufSize+1),
|
||||
strings.Repeat("a", bufio.MaxScanTokenSize+1),
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
// TODO(a.garipov): Expand.
|
||||
package rulelist
|
||||
|
||||
// MaxRuleLen is the maximum length of a line with a filtering rule, in bytes.
|
||||
//
|
||||
// TODO(a.garipov): Consider changing this to a rune length, like AdGuardDNS
|
||||
// does.
|
||||
const MaxRuleLen = 1024
|
||||
// DefaultRuleBufSize is the default length of a buffer used to read a line with
|
||||
// a filtering rule, in bytes.
|
||||
const DefaultRuleBufSize = 1024
|
||||
|
||||
@@ -7,8 +7,13 @@ const testTimeout = 1 * time.Second
|
||||
|
||||
// Common texts for tests.
|
||||
const (
|
||||
testRuleTextHTML = "<!DOCTYPE html>\n"
|
||||
testRuleTextBlocked = "||blocked.example^\n"
|
||||
testRuleTextBadTab = "||bad-tab-and-comment.example^\t# A comment.\n"
|
||||
testRuleTextBlocked = "||blocked.example^\n"
|
||||
testRuleTextEtcHostsTab = "0.0.0.0 tab..example^\t# A comment.\n"
|
||||
testRuleTextHTML = "<!DOCTYPE html>\n"
|
||||
|
||||
// testRuleTextCosmetic is a cosmetic rule with a zero-width non-joiner.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/6003.
|
||||
testRuleTextCosmetic = "||cosmetic.example## :has-text(/\u200c/i)\n"
|
||||
)
|
||||
|
||||
@@ -89,37 +89,34 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
assert.False(t, res.IsFiltered)
|
||||
assert.Empty(t, res.Rules)
|
||||
|
||||
resolver := &aghtest.TestResolver{}
|
||||
resolver := &aghtest.Resolver{
|
||||
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
ip4, ip6 := aghtest.HostToIPs(host)
|
||||
|
||||
return []net.IP{ip4, ip6}, nil
|
||||
},
|
||||
}
|
||||
|
||||
ss = newForTest(t, defaultSafeSearchConf)
|
||||
ss.resolver = resolver
|
||||
|
||||
// Lookup for safesearch domain.
|
||||
rewrite := ss.searchHost(domain, testQType)
|
||||
|
||||
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
|
||||
require.NoError(t, err)
|
||||
|
||||
var foundIP net.IP
|
||||
for _, ip := range ips {
|
||||
if ip.To4() != nil {
|
||||
foundIP = ip
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME)
|
||||
|
||||
res, err = ss.CheckHost(domain, testQType)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.True(t, res.Rules[0].IP.Equal(foundIP))
|
||||
assert.True(t, res.Rules[0].IP.Equal(wantIP))
|
||||
|
||||
// Check cache.
|
||||
cachedValue, isFound := ss.getCachedResult(domain, testQType)
|
||||
require.True(t, isFound)
|
||||
require.Len(t, cachedValue.Rules, 1)
|
||||
|
||||
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP))
|
||||
assert.True(t, cachedValue.Rules[0].IP.Equal(wantIP))
|
||||
}
|
||||
|
||||
const googleHost = "www.google.com"
|
||||
|
||||
@@ -92,8 +92,15 @@ func TestDefault_CheckHost_yandexAAAA(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefault_CheckHost_google(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||
resolver := &aghtest.Resolver{
|
||||
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
ip4, ip6 := aghtest.HostToIPs(host)
|
||||
|
||||
return []net.IP{ip4, ip6}, nil
|
||||
},
|
||||
}
|
||||
|
||||
wantIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
|
||||
|
||||
conf := testConf
|
||||
conf.CustomResolver = resolver
|
||||
@@ -119,7 +126,7 @@ func TestDefault_CheckHost_google(t *testing.T) {
|
||||
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Equal(t, ip, res.Rules[0].IP)
|
||||
assert.Equal(t, wantIP, res.Rules[0].IP)
|
||||
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1505,7 +1505,6 @@ var blockedServices = []blockedService{{
|
||||
"||aus.social^",
|
||||
"||awscommunity.social^",
|
||||
"||climatejustice.social^",
|
||||
"||cupoftea.social^",
|
||||
"||cyberplace.social^",
|
||||
"||defcon.social^",
|
||||
"||det.social^",
|
||||
@@ -1589,6 +1588,7 @@ var blockedServices = []blockedService{{
|
||||
"||techhub.social^",
|
||||
"||theblower.au^",
|
||||
"||tkz.one^",
|
||||
"||todon.eu^",
|
||||
"||toot.aquilenet.fr^",
|
||||
"||toot.community^",
|
||||
"||toot.funami.tech^",
|
||||
@@ -1661,6 +1661,7 @@ var blockedServices = []blockedService{{
|
||||
"||nintendo.jp^",
|
||||
"||nintendo.net^",
|
||||
"||nintendo.nl^",
|
||||
"||nintendo.pt^",
|
||||
"||nintendoswitch.cn^",
|
||||
"||nintendowifi.net^",
|
||||
},
|
||||
@@ -2160,6 +2161,20 @@ var blockedServices = []blockedService{{
|
||||
Rules: []string{
|
||||
"||voot.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "wargaming",
|
||||
Name: "Wargaming",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 24 24\"><path d=\"M12 1.998c-5.52 0-10 4.481-10 9.988 0 5.52 4.48 9.996 10 9.996s10-4.476 10-9.996c0-5.507-4.48-9.988-10-9.988zm0 2c4.413 0 8 3.588 8 7.988 0 3.246-1.944 6.04-4.727 7.293.54-1.861.831-3.988.807-6.226l1.414.414a23.648 23.648 0 0 0-2-4.041c-.627 1.347-1.48 2.56-2.52 3.68l1.68-.133c-1.507 2.92-3.134 3.906-5.547 4.013-.386-4.213.12-7.014 2.827-9.04l.386 1.493c.653-.974 1.36-2.12 2.373-2.947-1.506-.6-2.999-.627-4.492-.334.386.16.76.588 1.014.828-3.485 1.662-5.643 4.202-6.744 7.68A7.95 7.95 0 0 1 4 11.986c0-4.4 3.587-7.988 8-7.988z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||wargaming.com^",
|
||||
"||wargaming.net^",
|
||||
"||wgcdn.co^",
|
||||
"||wgcrowd.io^",
|
||||
"||worldoftanks.com^",
|
||||
"||worldofwarplanes.com^",
|
||||
"||worldofwarships.eu^",
|
||||
"||wotblitz.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "wechat",
|
||||
Name: "WeChat",
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
@@ -141,7 +142,7 @@ func (clients *clientsContainer) handleHostsUpdates() {
|
||||
}
|
||||
}
|
||||
|
||||
// webHandlersRegistered prevents a [clientsContainer] from regisering its web
|
||||
// webHandlersRegistered prevents a [clientsContainer] from registering its web
|
||||
// handlers more than once.
|
||||
//
|
||||
// TODO(a.garipov): Refactor HTTP handler registration logic.
|
||||
@@ -743,11 +744,9 @@ func (clients *clientsContainer) Update(prev, c *Client) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setWHOISInfo sets the WHOIS information for a client.
|
||||
// setWHOISInfo sets the WHOIS information for a client. clients.lock is
|
||||
// expected to be locked.
|
||||
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.findLocked(ip.String())
|
||||
if ok {
|
||||
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
||||
@@ -774,9 +773,11 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
rc.WHOIS = wi
|
||||
}
|
||||
|
||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||
// addHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||
// taken into account. ok is true if the pairing was added.
|
||||
func (clients *clientsContainer) AddHost(
|
||||
//
|
||||
// TODO(a.garipov): Only used in internal tests. Consider removing.
|
||||
func (clients *clientsContainer) addHost(
|
||||
ip netip.Addr,
|
||||
host string,
|
||||
src clientSource,
|
||||
@@ -787,6 +788,32 @@ func (clients *clientsContainer) AddHost(
|
||||
return clients.addHostLocked(ip, host, src)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ client.AddressUpdater = (*clientsContainer)(nil)
|
||||
|
||||
// UpdateAddress implements the [client.AddressUpdater] interface for
|
||||
// *clientsContainer
|
||||
func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
|
||||
// Common fast path optimization.
|
||||
if host == "" && info == nil {
|
||||
return
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
if host != "" {
|
||||
ok := clients.addHostLocked(ip, host, ClientSourceRDNS)
|
||||
if !ok {
|
||||
log.Debug("clients: host for client %q already set with higher priority source", ip)
|
||||
}
|
||||
}
|
||||
|
||||
if info != nil {
|
||||
clients.setWHOISInfo(ip, info)
|
||||
}
|
||||
}
|
||||
|
||||
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be
|
||||
// locked.
|
||||
func (clients *clientsContainer) addHostLocked(
|
||||
|
||||
@@ -168,13 +168,13 @@ func TestClients(t *testing.T) {
|
||||
|
||||
t.Run("addhost_success", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.AddHost(ip, "host", ClientSourceARP)
|
||||
ok := clients.addHost(ip, "host", ClientSourceARP)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok = clients.AddHost(ip, "host2", ClientSourceARP)
|
||||
ok = clients.addHost(ip, "host2", ClientSourceARP)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok = clients.AddHost(ip, "host3", ClientSourceHostsFile)
|
||||
ok = clients.addHost(ip, "host3", ClientSourceHostsFile)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.Equal(t, clients.clientSource(ip), ClientSourceHostsFile)
|
||||
@@ -182,18 +182,18 @@ func TestClients(t *testing.T) {
|
||||
|
||||
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
ok := clients.AddHost(ip, "from_arp", ClientSourceARP)
|
||||
ok := clients.addHost(ip, "from_arp", ClientSourceARP)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, clients.clientSource(ip), ClientSourceARP)
|
||||
|
||||
ok = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
|
||||
ok = clients.addHost(ip, "from_dhcp", ClientSourceDHCP)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, clients.clientSource(ip), ClientSourceDHCP)
|
||||
})
|
||||
|
||||
t.Run("addhost_fail", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.AddHost(ip, "host1", ClientSourceRDNS)
|
||||
ok := clients.addHost(ip, "host1", ClientSourceRDNS)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
@@ -216,7 +216,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
|
||||
t.Run("existing_auto-client", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.AddHost(ip, "host", ClientSourceRDNS)
|
||||
ok := clients.addHost(ip, "host", ClientSourceRDNS)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
@@ -259,7 +259,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
ok = clients.AddHost(ip, "test", ClientSourceRDNS)
|
||||
ok = clients.addHost(ip, "test", ClientSourceRDNS)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
"golang.org/x/exp/slices"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -590,7 +590,13 @@ func (c *configuration) write() (err error) {
|
||||
s.WriteDiskConfig(&c)
|
||||
dns := &config.DNS
|
||||
dns.FilteringConfig = c
|
||||
dns.LocalPTRResolvers, config.Clients.Sources.RDNS, dns.UsePrivateRDNS = s.RDNSSettings()
|
||||
|
||||
dns.LocalPTRResolvers = s.LocalPTRResolvers()
|
||||
|
||||
addrProcConf := s.AddrProcConfig()
|
||||
config.Clients.Sources.RDNS = addrProcConf.UseRDNS
|
||||
config.Clients.Sources.WHOIS = addrProcConf.UseWHOIS
|
||||
dns.UsePrivateRDNS = addrProcConf.UsePrivateRDNS
|
||||
}
|
||||
|
||||
if Context.dhcpServer != nil {
|
||||
|
||||
@@ -176,12 +176,16 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
// ------------------------
|
||||
// registration of handlers
|
||||
// ------------------------
|
||||
func registerControlHandlers() {
|
||||
func registerControlHandlers(web *webAPI) {
|
||||
Context.mux.HandleFunc(
|
||||
"/control/version.json",
|
||||
postInstall(optionalAuth(web.handleVersionJSON)),
|
||||
)
|
||||
httpRegister(http.MethodPost, "/control/update", web.handleUpdate)
|
||||
|
||||
httpRegister(http.MethodGet, "/control/status", handleStatus)
|
||||
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
|
||||
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
|
||||
Context.mux.HandleFunc("/control/version.json", postInstall(optionalAuth(handleVersionJSON)))
|
||||
httpRegister(http.MethodPost, "/control/update", handleUpdate)
|
||||
httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
|
||||
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)
|
||||
|
||||
|
||||
@@ -448,7 +448,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
web.conf.BindHost = req.Web.IP
|
||||
web.conf.BindPort = req.Web.Port
|
||||
|
||||
registerControlHandlers()
|
||||
registerControlHandlers(web)
|
||||
|
||||
aghhttp.OK(w)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
|
||||
@@ -29,9 +29,9 @@ type temporaryError interface {
|
||||
// handleVersionJSON is the handler for the POST /control/version.json HTTP API.
|
||||
//
|
||||
// TODO(a.garipov): Find out if this API used with a GET method by anyone.
|
||||
func handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
resp := &versionResponse{}
|
||||
if Context.disableUpdate {
|
||||
if web.conf.disableUpdate {
|
||||
resp.Disabled = true
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
|
||||
@@ -52,7 +52,7 @@ func handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
err = requestVersionInfo(resp, req.Recheck)
|
||||
err = web.requestVersionInfo(resp, req.Recheck)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
aghhttp.Error(r, w, http.StatusBadGateway, "%s", err)
|
||||
@@ -73,9 +73,10 @@ func handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// requestVersionInfo sets the VersionInfo field of resp if it can reach the
|
||||
// update server.
|
||||
func requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
|
||||
func (web *webAPI) requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
|
||||
updater := web.conf.updater
|
||||
for i := 0; i != 3; i++ {
|
||||
resp.VersionInfo, err = Context.updater.VersionInfo(recheck)
|
||||
resp.VersionInfo, err = updater.VersionInfo(recheck)
|
||||
if err != nil {
|
||||
var terr temporaryError
|
||||
if errors.As(err, &terr) && terr.Temporary() {
|
||||
@@ -95,7 +96,7 @@ func requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
vcu := Context.updater.VersionCheckURL()
|
||||
vcu := updater.VersionCheckURL()
|
||||
|
||||
return fmt.Errorf("getting version info from %s: %w", vcu, err)
|
||||
}
|
||||
@@ -104,8 +105,9 @@ func requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
|
||||
}
|
||||
|
||||
// handleUpdate performs an update to the latest available version procedure.
|
||||
func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
if Context.updater.NewVersion() == "" {
|
||||
func (web *webAPI) handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
updater := web.conf.updater
|
||||
if updater.NewVersion() == "" {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "/update request isn't allowed now")
|
||||
|
||||
return
|
||||
@@ -122,7 +124,7 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = Context.updater.Update(false)
|
||||
err = updater.Update(false)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
@@ -137,7 +139,7 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
// The background context is used because the underlying functions wrap it
|
||||
// with timeout and shut down the server, which handles current request. It
|
||||
// also should be done in a separate goroutine for the same reason.
|
||||
go finishUpdate(context.Background(), execPath)
|
||||
go finishUpdate(context.Background(), execPath, web.conf.runningAsService)
|
||||
}
|
||||
|
||||
// versionResponse is the response for /control/version.json endpoint.
|
||||
@@ -178,7 +180,7 @@ func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
|
||||
}
|
||||
|
||||
// finishUpdate completes an update procedure.
|
||||
func finishUpdate(ctx context.Context, execPath string) {
|
||||
func finishUpdate(ctx context.Context, execPath string, runningAsService bool) {
|
||||
var err error
|
||||
|
||||
log.Info("stopping all tasks")
|
||||
@@ -187,7 +189,7 @@ func finishUpdate(ctx context.Context, execPath string) {
|
||||
cleanupAlways()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
if Context.runningAsService {
|
||||
if runningAsService {
|
||||
// NOTE: We can't restart the service via "kardianos/service"
|
||||
// package, because it kills the process first we can't start a new
|
||||
// instance, because Windows doesn't allow it.
|
||||
|
||||
@@ -13,14 +13,12 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -135,7 +133,7 @@ func initDNSServer(
|
||||
return fmt.Errorf("preparing set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
p := dnsforward.DNSCreateParams{
|
||||
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
|
||||
DNSFilter: filters,
|
||||
Stats: sts,
|
||||
QueryLog: qlog,
|
||||
@@ -143,9 +141,7 @@ func initDNSServer(
|
||||
Anonymizer: anonymizer,
|
||||
LocalDomain: config.DHCP.LocalDomainName,
|
||||
DHCPServer: dhcpSrv,
|
||||
}
|
||||
|
||||
Context.dnsServer, err = dnsforward.NewServer(p)
|
||||
})
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
@@ -154,134 +150,23 @@ func initDNSServer(
|
||||
|
||||
Context.clients.dnsServer = Context.dnsServer
|
||||
|
||||
dnsConf, err := generateServerConfig(tlsConf, httpReg)
|
||||
dnsConf, err := newServerConfig(tlsConf, httpReg)
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
return fmt.Errorf("generateServerConfig: %w", err)
|
||||
return fmt.Errorf("newServerConfig: %w", err)
|
||||
}
|
||||
|
||||
err = Context.dnsServer.Prepare(&dnsConf)
|
||||
err = Context.dnsServer.Prepare(dnsConf)
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
return fmt.Errorf("dnsServer.Prepare: %w", err)
|
||||
}
|
||||
|
||||
initRDNS()
|
||||
initWHOIS()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
// defaultQueueSize is the size of queue of IPs for rDNS and WHOIS
|
||||
// processing.
|
||||
defaultQueueSize = 255
|
||||
|
||||
// defaultCacheSize is the maximum size of the cache for rDNS and WHOIS
|
||||
// processing. It must be greater than zero.
|
||||
defaultCacheSize = 10_000
|
||||
|
||||
// defaultIPTTL is the Time to Live duration for IP addresses cached by
|
||||
// rDNS and WHOIS.
|
||||
defaultIPTTL = 1 * time.Hour
|
||||
)
|
||||
|
||||
// initRDNS initializes the rDNS.
|
||||
func initRDNS() {
|
||||
Context.rdnsCh = make(chan netip.Addr, defaultQueueSize)
|
||||
|
||||
// TODO(s.chzhen): Add ability to disable it on dns server configuration
|
||||
// update in [dnsforward] package.
|
||||
r := rdns.New(&rdns.Config{
|
||||
Exchanger: Context.dnsServer,
|
||||
CacheSize: defaultCacheSize,
|
||||
CacheTTL: defaultIPTTL,
|
||||
})
|
||||
|
||||
go processRDNS(r)
|
||||
}
|
||||
|
||||
// processRDNS processes reverse DNS lookup queries. It is intended to be used
|
||||
// as a goroutine.
|
||||
func processRDNS(r rdns.Interface) {
|
||||
defer log.OnPanic("rdns")
|
||||
|
||||
for ip := range Context.rdnsCh {
|
||||
ok := Context.dnsServer.ShouldResolveClient(ip)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
host, changed := r.Process(ip)
|
||||
if host == "" || !changed {
|
||||
continue
|
||||
}
|
||||
|
||||
ok = Context.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug(
|
||||
"dns: can't set rdns info for client %q: already set with higher priority source",
|
||||
ip,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// initWHOIS initializes the WHOIS.
|
||||
//
|
||||
// TODO(s.chzhen): Consider making configurable.
|
||||
func initWHOIS() {
|
||||
const (
|
||||
// defaultTimeout is the timeout for WHOIS requests.
|
||||
defaultTimeout = 5 * time.Second
|
||||
|
||||
// defaultMaxConnReadSize is an upper limit in bytes for reading from
|
||||
// net.Conn.
|
||||
defaultMaxConnReadSize = 64 * 1024
|
||||
|
||||
// defaultMaxRedirects is the maximum redirects count.
|
||||
defaultMaxRedirects = 5
|
||||
|
||||
// defaultMaxInfoLen is the maximum length of whois.Info fields.
|
||||
defaultMaxInfoLen = 250
|
||||
)
|
||||
|
||||
Context.whoisCh = make(chan netip.Addr, defaultQueueSize)
|
||||
|
||||
var w whois.Interface
|
||||
|
||||
if config.Clients.Sources.WHOIS {
|
||||
w = whois.New(&whois.Config{
|
||||
DialContext: customDialContext,
|
||||
ServerAddr: whois.DefaultServer,
|
||||
Port: whois.DefaultPort,
|
||||
Timeout: defaultTimeout,
|
||||
CacheSize: defaultCacheSize,
|
||||
MaxConnReadSize: defaultMaxConnReadSize,
|
||||
MaxRedirects: defaultMaxRedirects,
|
||||
MaxInfoLen: defaultMaxInfoLen,
|
||||
CacheTTL: defaultIPTTL,
|
||||
})
|
||||
} else {
|
||||
w = whois.Empty{}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer log.OnPanic("whois")
|
||||
|
||||
for ip := range Context.whoisCh {
|
||||
info, changed := w.Process(context.Background(), ip)
|
||||
if info != nil && changed {
|
||||
Context.clients.setWHOISInfo(ip, info)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
|
||||
// a subnet set that matches all locally served networks, see
|
||||
// [netutil.IsLocallyServed].
|
||||
@@ -312,17 +197,6 @@ func isRunning() bool {
|
||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||
}
|
||||
|
||||
func onDNSRequest(pctx *proxy.DNSContext) {
|
||||
ip := netutil.NetAddrToAddrPort(pctx.Addr).Addr()
|
||||
if ip == (netip.Addr{}) {
|
||||
// This would be quite weird if we get here.
|
||||
return
|
||||
}
|
||||
|
||||
Context.rdnsCh <- ip
|
||||
Context.whoisCh <- ip
|
||||
}
|
||||
|
||||
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
|
||||
if ips == nil {
|
||||
return nil
|
||||
@@ -349,23 +223,41 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
|
||||
return udpAddrs
|
||||
}
|
||||
|
||||
func generateServerConfig(
|
||||
func newServerConfig(
|
||||
tlsConf *tlsConfigSettings,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
) (newConf dnsforward.ServerConfig, err error) {
|
||||
) (newConf *dnsforward.ServerConfig, err error) {
|
||||
dnsConf := config.DNS
|
||||
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
||||
newConf = dnsforward.ServerConfig{
|
||||
|
||||
newConf = &dnsforward.ServerConfig{
|
||||
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
||||
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
|
||||
FilteringConfig: dnsConf.FilteringConfig,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpReg,
|
||||
OnDNSRequest: onDNSRequest,
|
||||
UseDNS64: config.DNS.UseDNS64,
|
||||
DNS64Prefixes: config.DNS.DNS64Prefixes,
|
||||
}
|
||||
|
||||
var initialAddresses []netip.Addr
|
||||
// Context.stats may be nil here if initDNSServer is called from
|
||||
// [cmdlineUpdate].
|
||||
if sts := Context.stats; sts != nil {
|
||||
const initialClientsNum = 100
|
||||
initialAddresses = Context.stats.TopClientsIP(initialClientsNum)
|
||||
}
|
||||
|
||||
// Do not set DialContext, PrivateSubnets, and UsePrivateRDNS, because they
|
||||
// are set by [dnsforward.Server.Prepare].
|
||||
newConf.AddrProcConf = &client.DefaultAddrProcConfig{
|
||||
Exchanger: Context.dnsServer,
|
||||
AddressUpdater: &Context.clients,
|
||||
InitialAddresses: initialAddresses,
|
||||
UseRDNS: config.Clients.Sources.RDNS,
|
||||
UseWHOIS: config.Clients.Sources.WHOIS,
|
||||
}
|
||||
|
||||
if tlsConf.Enabled {
|
||||
newConf.TLSConfig = tlsConf.TLSConfig
|
||||
newConf.TLSConfig.ServerName = tlsConf.ServerName
|
||||
@@ -385,9 +277,9 @@ func generateServerConfig(
|
||||
if tlsConf.PortDNSCrypt != 0 {
|
||||
newConf.DNSCryptConfig, err = newDNSCrypt(hosts, *tlsConf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's already
|
||||
// wrapped by newDNSCrypt.
|
||||
return dnsforward.ServerConfig{}, err
|
||||
// Don't wrap the error, because it's already wrapped by
|
||||
// newDNSCrypt.
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -401,7 +293,6 @@ func generateServerConfig(
|
||||
newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers
|
||||
newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration
|
||||
|
||||
newConf.ResolveClients = config.Clients.Sources.RDNS
|
||||
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
|
||||
newConf.ServeHTTP3 = dnsConf.ServeHTTP3
|
||||
newConf.UseHTTP3Upstreams = dnsConf.UseHTTP3Upstreams
|
||||
@@ -556,27 +447,19 @@ func startDNSServer() error {
|
||||
Context.stats.Start()
|
||||
Context.queryLog.Start()
|
||||
|
||||
const topClientsNumber = 100 // the number of clients to get
|
||||
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
|
||||
Context.rdnsCh <- ip
|
||||
Context.whoisCh <- ip
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func reconfigureDNSServer() (err error) {
|
||||
var newConf dnsforward.ServerConfig
|
||||
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
|
||||
newConf, err = generateServerConfig(tlsConf, httpRegister)
|
||||
newConf, err := newServerConfig(tlsConf, httpRegister)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating forwarding dns server config: %w", err)
|
||||
}
|
||||
|
||||
err = Context.dnsServer.Reconfigure(&newConf)
|
||||
err = Context.dnsServer.Reconfigure(newConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting forwarding dns server: %w", err)
|
||||
}
|
||||
|
||||
@@ -3,14 +3,12 @@ package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
@@ -66,40 +64,24 @@ type homeContext struct {
|
||||
// configuration files, for example /etc/hosts.
|
||||
etcHosts *aghnet.HostsContainer
|
||||
|
||||
updater *updater.Updater
|
||||
|
||||
// mux is our custom http.ServeMux.
|
||||
mux *http.ServeMux
|
||||
|
||||
// Runtime properties
|
||||
// --
|
||||
|
||||
configFilename string // Config filename (can be overridden via the command line arguments)
|
||||
workDir string // Location of our directory, used to protect against CWD being somewhere else
|
||||
pidFileName string // PID file name. Empty if no PID file was created.
|
||||
controlLock sync.Mutex
|
||||
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
|
||||
client *http.Client
|
||||
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
|
||||
|
||||
// rdnsCh is the channel for receiving IPs for rDNS processing.
|
||||
rdnsCh chan netip.Addr
|
||||
|
||||
// whoisCh is the channel for receiving IPs for WHOIS processing.
|
||||
whoisCh chan netip.Addr
|
||||
configFilename string // Config filename (can be overridden via the command line arguments)
|
||||
workDir string // Location of our directory, used to protect against CWD being somewhere else
|
||||
pidFileName string // PID file name. Empty if no PID file was created.
|
||||
controlLock sync.Mutex
|
||||
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
|
||||
|
||||
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
|
||||
tlsCipherIDs []uint16
|
||||
|
||||
// disableUpdate, if true, tells AdGuard Home to not check for updates.
|
||||
disableUpdate bool
|
||||
|
||||
// firstRun, if true, tells AdGuard Home to only start the web interface
|
||||
// service, and only serve the first-run APIs.
|
||||
firstRun bool
|
||||
|
||||
// runningAsService flag is set to true when options are passed from the service runner
|
||||
runningAsService bool
|
||||
}
|
||||
|
||||
// getDataDir returns path to the directory where we store databases and filters
|
||||
@@ -122,11 +104,11 @@ func Main(clientBuildFS fs.FS) {
|
||||
// package flag.
|
||||
opts := loadCmdLineOpts()
|
||||
|
||||
Context.appSignalChannel = make(chan os.Signal)
|
||||
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
|
||||
go func() {
|
||||
for {
|
||||
sig := <-Context.appSignalChannel
|
||||
sig := <-signals
|
||||
log.Info("Received signal %q", sig)
|
||||
switch sig {
|
||||
case syscall.SIGHUP:
|
||||
@@ -141,7 +123,7 @@ func Main(clientBuildFS fs.FS) {
|
||||
}()
|
||||
|
||||
if opts.serviceControlAction != "" {
|
||||
handleServiceControlAction(opts, clientBuildFS)
|
||||
handleServiceControlAction(opts, clientBuildFS, signals)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -153,74 +135,48 @@ func Main(clientBuildFS fs.FS) {
|
||||
// setupContext initializes [Context] fields. It also reads and upgrades
|
||||
// config file if necessary.
|
||||
func setupContext(opts options) (err error) {
|
||||
setupContextFlags(opts)
|
||||
Context.firstRun = detectFirstRun()
|
||||
|
||||
Context.tlsRoots = aghtls.SystemRootCAs()
|
||||
Context.client = &http.Client{
|
||||
Timeout: time.Minute * 5,
|
||||
Transport: &http.Transport{
|
||||
DialContext: customDialContext,
|
||||
Proxy: getHTTPProxy,
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: Context.tlsRoots,
|
||||
CipherSuites: Context.tlsCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Context.mux = http.NewServeMux()
|
||||
|
||||
if !Context.firstRun {
|
||||
// Do the upgrade if necessary.
|
||||
err = upgradeConfig()
|
||||
if Context.firstRun {
|
||||
log.Info("This is the first time AdGuard Home is launched")
|
||||
checkPermissions()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do the upgrade if necessary.
|
||||
err = upgradeConfig()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
if err = parseConfig(); err != nil {
|
||||
log.Error("parsing configuration file: %s", err)
|
||||
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if opts.checkConfig {
|
||||
log.Info("configuration file is ok")
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if !opts.noEtcHosts && config.Clients.Sources.HostsFile {
|
||||
err = setupHostsContainer()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
if err = parseConfig(); err != nil {
|
||||
log.Error("parsing configuration file: %s", err)
|
||||
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if opts.checkConfig {
|
||||
log.Info("configuration file is ok")
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if !opts.noEtcHosts && config.Clients.Sources.HostsFile {
|
||||
err = setupHostsContainer()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupContextFlags sets global flags and prints their status to the log.
|
||||
func setupContextFlags(opts options) {
|
||||
Context.firstRun = detectFirstRun()
|
||||
if Context.firstRun {
|
||||
log.Info("This is the first time AdGuard Home is launched")
|
||||
checkPermissions()
|
||||
}
|
||||
|
||||
Context.runningAsService = opts.runningAsService
|
||||
// Don't print the runningAsService flag, since that has already been done
|
||||
// in [run].
|
||||
|
||||
Context.disableUpdate = opts.disableUpdate || version.Channel() == version.ChannelDevelopment
|
||||
if Context.disableUpdate {
|
||||
log.Info("AdGuard Home updates are disabled")
|
||||
}
|
||||
}
|
||||
|
||||
// logIfUnsupported logs a formatted warning if the error is one of the
|
||||
// unsupported errors and returns nil. If err is nil, logIfUnsupported returns
|
||||
// nil. Otherwise, it returns err.
|
||||
@@ -325,7 +281,7 @@ func initContextClients() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
//lint:ignore SA1019 Migration is not over.
|
||||
//lint:ignore SA1019 Migration is not over.
|
||||
config.DHCP.WorkDir = Context.workDir
|
||||
config.DHCP.DataDir = Context.getDataDir()
|
||||
config.DHCP.HTTPRegister = httpRegister
|
||||
@@ -340,18 +296,6 @@ func initContextClients() (err error) {
|
||||
return fmt.Errorf("initing dhcp: %w", err)
|
||||
}
|
||||
|
||||
Context.updater = updater.NewUpdater(&updater.Config{
|
||||
Client: Context.client,
|
||||
Version: version.Version(),
|
||||
Channel: version.Channel(),
|
||||
GOARCH: runtime.GOARCH,
|
||||
GOOS: runtime.GOOS,
|
||||
GOARM: version.GOARM(),
|
||||
GOMIPS: version.GOMIPS(),
|
||||
WorkDir: Context.workDir,
|
||||
ConfName: config.getConfigFilename(),
|
||||
})
|
||||
|
||||
var arpdb aghnet.ARPDB
|
||||
if config.Clients.Sources.ARP {
|
||||
arpdb = aghnet.NewARPDB()
|
||||
@@ -433,7 +377,7 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||
conf.Filters = slices.Clone(config.Filters)
|
||||
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
|
||||
conf.UserRules = slices.Clone(config.UserRules)
|
||||
conf.HTTPClient = Context.client
|
||||
conf.HTTPClient = httpClient()
|
||||
|
||||
cacheTime := time.Duration(conf.CacheTime) * time.Minute
|
||||
|
||||
@@ -515,7 +459,7 @@ func checkPorts() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
|
||||
func initWeb(opts options, clientBuildFS fs.FS, upd *updater.Updater) (web *webAPI, err error) {
|
||||
var clientFS fs.FS
|
||||
if opts.localFrontend {
|
||||
log.Info("warning: using local frontend files")
|
||||
@@ -528,8 +472,16 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
webConf := webConfig{
|
||||
firstRun: Context.firstRun,
|
||||
disableUpdate := opts.disableUpdate || version.Channel() == version.ChannelDevelopment
|
||||
if disableUpdate {
|
||||
log.Info("AdGuard Home updates are disabled")
|
||||
}
|
||||
|
||||
webConf := &webConfig{
|
||||
updater: upd,
|
||||
|
||||
clientFS: clientFS,
|
||||
|
||||
BindHost: config.HTTPConfig.Address.Addr(),
|
||||
BindPort: int(config.HTTPConfig.Address.Port()),
|
||||
|
||||
@@ -537,12 +489,13 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
|
||||
ReadHeaderTimeout: readHdrTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
|
||||
clientFS: clientFS,
|
||||
|
||||
serveHTTP3: config.DNS.ServeHTTP3,
|
||||
firstRun: Context.firstRun,
|
||||
disableUpdate: disableUpdate,
|
||||
runningAsService: opts.runningAsService,
|
||||
serveHTTP3: config.DNS.ServeHTTP3,
|
||||
}
|
||||
|
||||
web = newWebAPI(&webConf)
|
||||
web = newWebAPI(webConf)
|
||||
if web == nil {
|
||||
return nil, fmt.Errorf("initializing web: %w", err)
|
||||
}
|
||||
@@ -593,9 +546,21 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
err = setupOpts(opts)
|
||||
fatalOnError(err)
|
||||
|
||||
upd := updater.NewUpdater(&updater.Config{
|
||||
Client: config.DNS.DnsfilterConf.HTTPClient,
|
||||
Version: version.Version(),
|
||||
Channel: version.Channel(),
|
||||
GOARCH: runtime.GOARCH,
|
||||
GOOS: runtime.GOOS,
|
||||
GOARM: version.GOARM(),
|
||||
GOMIPS: version.GOMIPS(),
|
||||
WorkDir: Context.workDir,
|
||||
ConfName: config.getConfigFilename(),
|
||||
})
|
||||
|
||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||
// effect.
|
||||
cmdlineUpdate(opts)
|
||||
cmdlineUpdate(opts, upd)
|
||||
|
||||
if !Context.firstRun {
|
||||
// Save the updated config.
|
||||
@@ -624,7 +589,7 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
Context.web, err = initWeb(opts, clientBuildFS)
|
||||
Context.web, err = initWeb(opts, clientBuildFS, upd)
|
||||
fatalOnError(err)
|
||||
|
||||
if !Context.firstRun {
|
||||
@@ -634,10 +599,10 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
Context.tls.start()
|
||||
|
||||
go func() {
|
||||
sErr := startDNSServer()
|
||||
if sErr != nil {
|
||||
startErr := startDNSServer()
|
||||
if startErr != nil {
|
||||
closeDNSServer()
|
||||
fatalOnError(sErr)
|
||||
fatalOnError(startErr)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -996,62 +961,6 @@ func detectFirstRun() bool {
|
||||
return errors.Is(err, os.ErrNotExist)
|
||||
}
|
||||
|
||||
// Connect to a remote server resolving hostname using our own DNS server.
|
||||
//
|
||||
// TODO(e.burkov): This messy logic should be decomposed and clarified.
|
||||
//
|
||||
// TODO(a.garipov): Support network.
|
||||
func customDialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
log.Debug("home: customdial: dialing addr %q for network %s", addr, network)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Minute * 5,
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
addrs, err := Context.dnsServer.Resolve(host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving %q: %w", host, err)
|
||||
}
|
||||
|
||||
log.Debug("dnsServer.Resolve: %q: %v", host, addrs)
|
||||
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("couldn't lookup host: %q", host)
|
||||
}
|
||||
|
||||
var dialErrs []error
|
||||
for _, a := range addrs {
|
||||
addr = net.JoinHostPort(a.String(), port)
|
||||
conn, err = dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
dialErrs = append(dialErrs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
return conn, err
|
||||
}
|
||||
|
||||
return nil, errors.List(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
|
||||
}
|
||||
|
||||
func getHTTPProxy(_ *http.Request) (*url.URL, error) {
|
||||
if config.ProxyURL == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return url.Parse(config.ProxyURL)
|
||||
}
|
||||
|
||||
// jsonError is a generic JSON error response.
|
||||
//
|
||||
// TODO(a.garipov): Merge together with the implementations in [dhcpd] and other
|
||||
@@ -1062,7 +971,7 @@ type jsonError struct {
|
||||
}
|
||||
|
||||
// cmdlineUpdate updates current application and exits.
|
||||
func cmdlineUpdate(opts options) {
|
||||
func cmdlineUpdate(opts options, upd *updater.Updater) {
|
||||
if !opts.performUpdate {
|
||||
return
|
||||
}
|
||||
@@ -1077,10 +986,9 @@ func cmdlineUpdate(opts options) {
|
||||
|
||||
log.Info("cmdline update: performing update")
|
||||
|
||||
updater := Context.updater
|
||||
info, err := updater.VersionInfo(true)
|
||||
info, err := upd.VersionInfo(true)
|
||||
if err != nil {
|
||||
vcu := updater.VersionCheckURL()
|
||||
vcu := upd.VersionCheckURL()
|
||||
log.Error("getting version info from %s: %s", vcu, err)
|
||||
|
||||
os.Exit(1)
|
||||
@@ -1092,7 +1000,7 @@ func cmdlineUpdate(opts options) {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
err = updater.Update(Context.firstRun)
|
||||
err = upd.Update(Context.firstRun)
|
||||
fatalOnError(err)
|
||||
|
||||
err = restartService()
|
||||
|
||||
47
internal/home/httpclient.go
Normal file
47
internal/home/httpclient.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
// httpClient returns a new HTTP client that uses the AdGuard Home's own DNS
|
||||
// server for resolving hostnames. The resulting client should not be used
|
||||
// until [Context.dnsServer] is initialized.
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): This is rather messy. Refactor.
|
||||
func httpClient() (c *http.Client) {
|
||||
// Do not use Context.dnsServer.DialContext directly in the struct literal
|
||||
// below, since Context.dnsServer may be nil when this function is called.
|
||||
dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
return Context.dnsServer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
// TODO(a.garipov): Make configurable.
|
||||
Timeout: time.Minute * 5,
|
||||
Transport: &http.Transport{
|
||||
DialContext: dialContext,
|
||||
Proxy: httpProxy,
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: Context.tlsRoots,
|
||||
CipherSuites: Context.tlsCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// httpProxy returns parses and returns an HTTP proxy URL from the config, if
|
||||
// any.
|
||||
func httpProxy(_ *http.Request) (u *url.URL, err error) {
|
||||
if config.ProxyURL == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return url.Parse(config.ProxyURL)
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"runtime"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// startPprof launches the debug and profiling server on addr.
|
||||
func startPprof(addr string) {
|
||||
runtime.SetBlockProfileRate(1)
|
||||
runtime.SetMutexProfileFraction(1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
||||
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
||||
|
||||
// See profileSupportsDelta in src/net/http/pprof/pprof.go.
|
||||
mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs"))
|
||||
mux.Handle("/debug/pprof/block", pprof.Handler("block"))
|
||||
mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine"))
|
||||
mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
|
||||
mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
|
||||
mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))
|
||||
|
||||
go func() {
|
||||
defer log.OnPanic("pprof server")
|
||||
|
||||
log.Info("pprof: listening on %q", addr)
|
||||
err := http.ListenAndServe(addr, mux)
|
||||
log.Info("pprof server errors: %v", err)
|
||||
}()
|
||||
}
|
||||
@@ -33,9 +33,13 @@ const (
|
||||
// daemon.
|
||||
type program struct {
|
||||
clientBuildFS fs.FS
|
||||
signals chan os.Signal
|
||||
opts options
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ service.Interface = (*program)(nil)
|
||||
|
||||
// Start implements service.Interface interface for *program.
|
||||
func (p *program) Start(_ service.Service) (err error) {
|
||||
// Start should not block. Do the actual work async.
|
||||
@@ -48,14 +52,14 @@ func (p *program) Start(_ service.Service) (err error) {
|
||||
}
|
||||
|
||||
// Stop implements service.Interface interface for *program.
|
||||
func (p *program) Stop(_ service.Service) error {
|
||||
// Stop should not block. Return with a few seconds.
|
||||
if Context.appSignalChannel == nil {
|
||||
os.Exit(0)
|
||||
func (p *program) Stop(_ service.Service) (err error) {
|
||||
select {
|
||||
case p.signals <- syscall.SIGINT:
|
||||
// Go on.
|
||||
default:
|
||||
// Stop should not block.
|
||||
}
|
||||
|
||||
Context.appSignalChannel <- syscall.SIGINT
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -194,7 +198,7 @@ func restartService() (err error) {
|
||||
// - run: This is a special command that is not supposed to be used directly
|
||||
// it is specified when we register a service, and it indicates to the app
|
||||
// that it is being run as a service/daemon.
|
||||
func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
|
||||
func handleServiceControlAction(opts options, clientBuildFS fs.FS, signals chan os.Signal) {
|
||||
// Call chooseSystem explicitly to introduce OpenBSD support for service
|
||||
// package. It's a noop for other GOOS values.
|
||||
chooseSystem()
|
||||
@@ -226,7 +230,11 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
|
||||
}
|
||||
configureService(svcConfig)
|
||||
|
||||
s, err := service.New(&program{clientBuildFS: clientBuildFS, opts: runOpts}, svcConfig)
|
||||
s, err := service.New(&program{
|
||||
clientBuildFS: clientBuildFS,
|
||||
signals: signals,
|
||||
opts: runOpts,
|
||||
}, svcConfig)
|
||||
if err != nil {
|
||||
log.Fatalf("service: initializing service: %s", err)
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -6,16 +6,18 @@ import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/pprofutil"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
@@ -33,6 +35,8 @@ const (
|
||||
)
|
||||
|
||||
type webConfig struct {
|
||||
updater *updater.Updater
|
||||
|
||||
clientFS fs.FS
|
||||
|
||||
BindHost netip.Addr
|
||||
@@ -52,6 +56,13 @@ type webConfig struct {
|
||||
|
||||
firstRun bool
|
||||
|
||||
// disableUpdate, if true, tells AdGuard Home to not check for updates.
|
||||
disableUpdate bool
|
||||
|
||||
// runningAsService flag is set to true when options are passed from the
|
||||
// service runner.
|
||||
runningAsService bool
|
||||
|
||||
serveHTTP3 bool
|
||||
}
|
||||
|
||||
@@ -102,7 +113,7 @@ func newWebAPI(conf *webConfig) (w *webAPI) {
|
||||
Context.mux.Handle("/install.html", preInstallHandler(clientFS))
|
||||
w.registerInstallHandlers()
|
||||
} else {
|
||||
registerControlHandlers()
|
||||
registerControlHandlers(w)
|
||||
}
|
||||
|
||||
w.httpsServer.cond = sync.NewCond(&w.httpsServer.condLock)
|
||||
@@ -295,8 +306,27 @@ func (web *webAPI) mustStartHTTP3(address string) {
|
||||
|
||||
log.Debug("web: starting http/3 server")
|
||||
err := web.httpsServer.server3.ListenAndServe()
|
||||
if !errors.Is(err, quic.ErrServerClosed) {
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
cleanupAlways()
|
||||
log.Fatalf("web: http3: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// startPprof launches the debug and profiling server on addr.
|
||||
func startPprof(addr string) {
|
||||
runtime.SetBlockProfileRate(1)
|
||||
runtime.SetMutexProfileFraction(1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
pprofutil.RoutePprof(mux)
|
||||
|
||||
go func() {
|
||||
defer log.OnPanic("pprof server")
|
||||
|
||||
log.Info("pprof: listening on %q", addr)
|
||||
err := http.ListenAndServe(addr, mux)
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Error("pprof: shutting down: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/mathutil"
|
||||
"github.com/bluele/gcache"
|
||||
)
|
||||
|
||||
@@ -17,7 +18,7 @@ type Interface interface {
|
||||
Process(ip netip.Addr) (host string, changed bool)
|
||||
}
|
||||
|
||||
// Empty is an empty [Inteface] implementation which does nothing.
|
||||
// Empty is an empty [Interface] implementation which does nothing.
|
||||
type Empty struct{}
|
||||
|
||||
// type check
|
||||
@@ -32,7 +33,7 @@ func (Empty) Process(_ netip.Addr) (host string, changed bool) {
|
||||
type Exchanger interface {
|
||||
// Exchange tries to resolve the ip in a suitable way, i.e. either as local
|
||||
// or as external.
|
||||
Exchange(ip netip.Addr) (host string, err error)
|
||||
Exchange(ip netip.Addr) (host string, ttl time.Duration, err error)
|
||||
}
|
||||
|
||||
// Config is the configuration structure for Default.
|
||||
@@ -82,13 +83,16 @@ func (r *Default) Process(ip netip.Addr) (host string, changed bool) {
|
||||
return fromCache, false
|
||||
}
|
||||
|
||||
host, err := r.exchanger.Exchange(ip)
|
||||
host, ttl, err := r.exchanger.Exchange(ip)
|
||||
if err != nil {
|
||||
log.Debug("rdns: resolving %q: %s", ip, err)
|
||||
}
|
||||
|
||||
// TODO(s.chzhen): Use built-in function max in Go 1.21.
|
||||
ttl = mathutil.Max(ttl, r.cacheTTL)
|
||||
|
||||
item := &cacheItem{
|
||||
expiry: time.Now().Add(r.cacheTTL),
|
||||
expiry: time.Now().Add(ttl),
|
||||
host: host,
|
||||
}
|
||||
|
||||
|
||||
@@ -5,25 +5,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeRDNSExchanger is a mock [rdns.Exchanger] implementation for tests.
|
||||
type fakeRDNSExchanger struct {
|
||||
OnExchange func(ip netip.Addr) (host string, err error)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ rdns.Exchanger = (*fakeRDNSExchanger)(nil)
|
||||
|
||||
// Exchange implements [rdns.Exchanger] interface for *fakeRDNSExchanger.
|
||||
func (e *fakeRDNSExchanger) Exchange(ip netip.Addr) (host string, err error) {
|
||||
return e.OnExchange(ip)
|
||||
}
|
||||
|
||||
func TestDefault_Process(t *testing.T) {
|
||||
ip1 := netip.MustParseAddr("1.2.3.4")
|
||||
revAddr1, err := netutil.IPToReversedAddr(ip1.AsSlice())
|
||||
@@ -67,21 +55,21 @@ func TestDefault_Process(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
hit := 0
|
||||
onExchange := func(ip netip.Addr) (host string, err error) {
|
||||
onExchange := func(ip netip.Addr) (host string, ttl time.Duration, err error) {
|
||||
hit++
|
||||
|
||||
switch ip {
|
||||
case ip1:
|
||||
return revAddr1, nil
|
||||
return revAddr1, 0, nil
|
||||
case ip2:
|
||||
return revAddr2, nil
|
||||
return revAddr2, 0, nil
|
||||
case localIP:
|
||||
return localRevAddr1, nil
|
||||
return localRevAddr1, 0, nil
|
||||
default:
|
||||
return "", nil
|
||||
return "", 0, nil
|
||||
}
|
||||
}
|
||||
exchanger := &fakeRDNSExchanger{
|
||||
exchanger := &aghtest.Exchanger{
|
||||
OnExchange: onExchange,
|
||||
}
|
||||
|
||||
|
||||
@@ -9,9 +9,9 @@ require (
|
||||
github.com/kisielk/errcheck v1.6.3
|
||||
github.com/kyoh86/looppointer v0.2.1
|
||||
github.com/securego/gosec/v2 v2.16.0
|
||||
github.com/uudashr/gocognit v1.0.6
|
||||
github.com/uudashr/gocognit v1.0.7
|
||||
golang.org/x/tools v0.11.0
|
||||
golang.org/x/vuln v0.2.0
|
||||
golang.org/x/vuln v1.0.0
|
||||
// TODO(a.garipov): Return to tagged releases once a new one appears.
|
||||
honnef.co/go/tools v0.5.0-0.dev.0.20230709092525-bc759185c5ee
|
||||
mvdan.cc/gofumpt v0.5.0
|
||||
@@ -22,12 +22,12 @@ require (
|
||||
github.com/BurntSushi/toml v1.3.2 // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/gookit/color v1.5.3 // indirect
|
||||
github.com/gookit/color v1.5.4 // indirect
|
||||
github.com/kyoh86/nolint v0.0.1 // indirect
|
||||
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20230711023510-fffb14384f22 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20230725093048-515e97ebf090 // indirect
|
||||
golang.org/x/mod v0.12.0 // indirect
|
||||
golang.org/x/sync v0.3.0 // indirect
|
||||
golang.org/x/sys v0.10.0 // indirect
|
||||
|
||||
@@ -16,8 +16,8 @@ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE
|
||||
github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gookit/color v1.5.3 h1:twfIhZs4QLCtimkP7MOxlF3A0U/5cDPseRT9M/+2SCE=
|
||||
github.com/gookit/color v1.5.3/go.mod h1:NUzwzeehUfl7GIb36pqId+UGmRfQcU/WiiyTTeNjHtE=
|
||||
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
|
||||
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601 h1:mrEEilTAUmaAORhssPPkxj84TsHrPMLBGW2Z4SoTxm8=
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0=
|
||||
github.com/kisielk/errcheck v1.6.3 h1:dEKh+GLHcWm2oN34nMvDzn1sqI0i0WxPvrgiJA5JuM8=
|
||||
@@ -38,9 +38,9 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR
|
||||
github.com/securego/gosec/v2 v2.16.0 h1:Pi0JKoasQQ3NnoRao/ww/N/XdynIB9NRYYZT5CyOs5U=
|
||||
github.com/securego/gosec/v2 v2.16.0/go.mod h1:xvLcVZqUfo4aAQu56TNv7/Ltz6emAOQAEsrZrt7uGlI=
|
||||
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||
github.com/uudashr/gocognit v1.0.6 h1:2Cgi6MweCsdB6kpcVQp7EW4U23iBFQWfTXiWlyp842Y=
|
||||
github.com/uudashr/gocognit v1.0.6/go.mod h1:nAIUuVBnYU7pcninia3BHOvQkpQCeO76Uscky5BOwcY=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/uudashr/gocognit v1.0.7 h1:e9aFXgKgUJrQ5+bs61zBigmj7bFJ/5cC6HmMahVzuDo=
|
||||
github.com/uudashr/gocognit v1.0.7/go.mod h1:nAIUuVBnYU7pcninia3BHOvQkpQCeO76Uscky5BOwcY=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -52,8 +52,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
|
||||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/exp/typeparams v0.0.0-20230711023510-fffb14384f22 h1:e8iSCQYXZ4EB6q3kIfy2fgPFTvDbozqzRe4OuIOyrL4=
|
||||
golang.org/x/exp/typeparams v0.0.0-20230711023510-fffb14384f22/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/exp/typeparams v0.0.0-20230725093048-515e97ebf090 h1:qOYhjyK9OeXREdh7Zrta8JRvnmnFIzhkosQpp+852Ag=
|
||||
golang.org/x/exp/typeparams v0.0.0-20230725093048-515e97ebf090/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
|
||||
@@ -98,8 +98,8 @@ golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E
|
||||
golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4=
|
||||
golang.org/x/tools v0.11.0 h1:EMCa6U9S2LtZXLAMoWiR/R8dAQFRqbAitmbJ2UKhoi8=
|
||||
golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8=
|
||||
golang.org/x/vuln v0.2.0 h1:Dlz47lW0pvPHU7tnb10S8vbMn9GnV2B6eyT7Tem5XBI=
|
||||
golang.org/x/vuln v0.2.0/go.mod h1:V0eyhHwaAaHrt42J9bgrN6rd12f6GU4T0Lu0ex2wDg4=
|
||||
golang.org/x/vuln v1.0.0 h1:tYLAU3jD9LQr98Y+3el06lWyGMCnvzw06PIWP3LIy7g=
|
||||
golang.org/x/vuln v1.0.0/go.mod h1:V0eyhHwaAaHrt42J9bgrN6rd12f6GU4T0Lu0ex2wDg4=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -48,9 +49,8 @@ func (Empty) Process(_ context.Context, _ netip.Addr) (info *Info, changed bool)
|
||||
|
||||
// Config is the configuration structure for Default.
|
||||
type Config struct {
|
||||
// DialContext specifies the dial function for creating unencrypted TCP
|
||||
// connections.
|
||||
DialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
|
||||
// DialContext is used to create TCP connections to WHOIS servers.
|
||||
DialContext aghnet.DialContextFunc
|
||||
|
||||
// ServerAddr is the address of the WHOIS server.
|
||||
ServerAddr string
|
||||
@@ -86,9 +86,8 @@ type Default struct {
|
||||
// resolve the same IP.
|
||||
cache gcache.Cache
|
||||
|
||||
// dialContext connects to a remote server resolving hostname using our own
|
||||
// DNS server and unecrypted TCP connection.
|
||||
dialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
|
||||
// dialContext is used to create TCP connections to WHOIS servers.
|
||||
dialContext aghnet.DialContextFunc
|
||||
|
||||
// serverAddr is the address of the WHOIS server.
|
||||
serverAddr string
|
||||
@@ -215,7 +214,7 @@ func (w *Default) query(ctx context.Context, target, serverAddr string) (data []
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(w.timeout))
|
||||
_ = conn.SetDeadline(time.Now().Add(w.timeout))
|
||||
_, err = io.WriteString(conn, target+"\r\n")
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
@@ -310,7 +309,7 @@ func (w *Default) requestInfo(
|
||||
|
||||
kv, err := w.queryAll(ctx, ip.String())
|
||||
if err != nil {
|
||||
log.Debug("whois: quering about %q: %s", ip, err)
|
||||
log.Debug("whois: querying %q: %s", ip, err)
|
||||
|
||||
return nil, true
|
||||
}
|
||||
|
||||
@@ -113,20 +113,14 @@ func TestDefault_Process(t *testing.T) {
|
||||
|
||||
return copy(b, tc.data), io.EOF
|
||||
},
|
||||
OnWrite: func(b []byte) (n int, err error) {
|
||||
return len(b), nil
|
||||
},
|
||||
OnClose: func() (err error) {
|
||||
return nil
|
||||
},
|
||||
OnSetReadDeadline: func(t time.Time) (err error) {
|
||||
return nil
|
||||
},
|
||||
OnWrite: func(b []byte) (n int, err error) { return len(b), nil },
|
||||
OnClose: func() (err error) { return nil },
|
||||
OnSetDeadline: func(t time.Time) (err error) { return nil },
|
||||
}
|
||||
|
||||
w := whois.New(&whois.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
DialContext: func(_ context.Context, _, addr string) (_ net.Conn, _ error) {
|
||||
DialContext: func(_ context.Context, _, _ string) (_ net.Conn, _ error) {
|
||||
hit = 0
|
||||
|
||||
return fakeConn, nil
|
||||
|
||||
Reference in New Issue
Block a user