all: sync with master; upd chlog

This commit is contained in:
Ainar Garipov
2023-07-26 13:18:44 +03:00
parent ec83d0eb86
commit 48ee2f8a42
99 changed files with 3202 additions and 1886 deletions

View File

@@ -141,9 +141,9 @@ type HostsRecord struct {
Canonical string
}
// equal returns true if all fields of rec are equal to field in other or they
// Equal returns true if all fields of rec are equal to field in other or they
// both are nil.
func (rec *HostsRecord) equal(other *HostsRecord) (ok bool) {
func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) {
if rec == nil {
return other == nil
} else if other == nil {
@@ -495,7 +495,7 @@ func (hc *HostsContainer) refresh() (err error) {
}
// hc.last is nil on the first refresh, so let that one through.
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) {
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).Equal) {
log.Debug("%s: no changes detected", hostsContainerPrefix)
return nil

View File

@@ -0,0 +1,144 @@
package aghnet
import (
"io/fs"
"net/netip"
"path"
"testing"
"testing/fstest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil/fakefs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const nl = "\n"
func TestHostsContainer_PathsToPatterns(t *testing.T) {
gsfs := fstest.MapFS{
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
}
testCases := []struct {
name string
paths []string
want []string
}{{
name: "no_paths",
paths: nil,
want: nil,
}, {
name: "single_file",
paths: []string{"dir_0/file_1"},
want: []string{"dir_0/file_1"},
}, {
name: "several_files",
paths: []string{"dir_0/file_1", "dir_0/file_2"},
want: []string{"dir_0/file_1", "dir_0/file_2"},
}, {
name: "whole_dir",
paths: []string{"dir_0"},
want: []string{"dir_0/*"},
}, {
name: "file_and_dir",
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
}, {
name: "non-existing",
paths: []string{path.Join("dir_0", "file_3")},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
patterns, err := pathsToPatterns(gsfs, tc.paths)
require.NoError(t, err)
assert.Equal(t, tc.want, patterns)
})
}
t.Run("bad_file", func(t *testing.T) {
const errStat errors.Error = "bad file"
badFS := &fakefs.StatFS{
OnOpen: func(_ string) (f fs.File, err error) { panic("not implemented") },
OnStat: func(name string) (fi fs.FileInfo, err error) {
return nil, errStat
},
}
_, err := pathsToPatterns(badFS, []string{""})
assert.ErrorIs(t, err, errStat)
})
}
func TestUniqueRules_ParseLine(t *testing.T) {
ip := netutil.IPv4Localhost()
ipStr := ip.String()
testCases := []struct {
name string
line string
wantIP netip.Addr
wantHosts []string
}{{
name: "simple",
line: ipStr + ` hostname`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "aliases",
line: ipStr + ` hostname alias`,
wantIP: ip,
wantHosts: []string{"hostname", "alias"},
}, {
name: "invalid_line",
line: ipStr,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "invalid_line_hostname",
line: ipStr + ` # hostname`,
wantIP: ip,
wantHosts: nil,
}, {
name: "commented_aliases",
line: ipStr + ` hostname # alias`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "whole_comment",
line: `# ` + ipStr + ` hostname`,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "partial_comment",
line: ipStr + ` host#name`,
wantIP: ip,
wantHosts: []string{"host"},
}, {
name: "empty",
line: ``,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "bad_hosts",
line: ipStr + ` bad..host bad._tld empty.tld. ok.host`,
wantIP: ip,
wantHosts: []string{"ok.host"},
}}
for _, tc := range testCases {
hp := hostsParser{}
t.Run(tc.name, func(t *testing.T) {
got, hosts := hp.parseLine(tc.line)
assert.Equal(t, tc.wantIP, got)
assert.Equal(t, tc.wantHosts, hosts)
})
}
}

View File

@@ -1,9 +1,7 @@
package aghnet
package aghnet_test
import (
"io/fs"
"net"
"net/netip"
"path"
"strings"
"sync/atomic"
@@ -12,6 +10,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
@@ -24,10 +23,7 @@ import (
"github.com/stretchr/testify/require"
)
const (
nl = "\n"
sp = " "
)
const nl = "\n"
func TestNewHostsContainer(t *testing.T) {
const dirname = "dir"
@@ -48,11 +44,11 @@ func TestNewHostsContainer(t *testing.T) {
name: "one_file",
paths: []string{p},
}, {
wantErr: ErrNoHostsPaths,
wantErr: aghnet.ErrNoHostsPaths,
name: "no_files",
paths: []string{},
}, {
wantErr: ErrNoHostsPaths,
wantErr: aghnet.ErrNoHostsPaths,
name: "non-existent_file",
paths: []string{path.Join(dirname, filename+"2")},
}, {
@@ -77,7 +73,7 @@ func TestNewHostsContainer(t *testing.T) {
return eventsCh
}
hc, err := NewHostsContainer(0, testFS, &aghtest.FSWatcher{
hc, err := aghnet.NewHostsContainer(0, testFS, &aghtest.FSWatcher{
OnEvents: onEvents,
OnAdd: onAdd,
OnClose: func() (err error) { return nil },
@@ -103,7 +99,7 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("nil_fs", func(t *testing.T) {
require.Panics(t, func() {
_, _ = NewHostsContainer(0, nil, &aghtest.FSWatcher{
_, _ = aghnet.NewHostsContainer(0, nil, &aghtest.FSWatcher{
// Those shouldn't panic.
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(name string) (err error) { return nil },
@@ -114,7 +110,7 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("nil_watcher", func(t *testing.T) {
require.Panics(t, func() {
_, _ = NewHostsContainer(0, testFS, nil, p)
_, _ = aghnet.NewHostsContainer(0, testFS, nil, p)
})
})
@@ -127,7 +123,7 @@ func TestNewHostsContainer(t *testing.T) {
OnClose: func() (err error) { return nil },
}
hc, err := NewHostsContainer(0, testFS, errWatcher, p)
hc, err := aghnet.NewHostsContainer(0, testFS, errWatcher, p)
require.ErrorIs(t, err, errOnAdd)
assert.Nil(t, hc)
@@ -158,11 +154,11 @@ func TestHostsContainer_refresh(t *testing.T) {
OnClose: func() (err error) { return nil },
}
hc, err := NewHostsContainer(0, testFS, w, "dir")
hc, err := aghnet.NewHostsContainer(0, testFS, w, "dir")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
checkRefresh := func(t *testing.T, want *HostsRecord) {
checkRefresh := func(t *testing.T, want *aghnet.HostsRecord) {
t.Helper()
upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
@@ -175,11 +171,11 @@ func TestHostsContainer_refresh(t *testing.T) {
require.True(t, ok)
require.NotNil(t, rec)
assert.Truef(t, rec.equal(want), "%+v != %+v", rec, want)
assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want)
}
t.Run("initial_refresh", func(t *testing.T) {
checkRefresh(t, &HostsRecord{
checkRefresh(t, &aghnet.HostsRecord{
Aliases: stringutil.NewSet(),
Canonical: "hostname",
})
@@ -189,7 +185,7 @@ func TestHostsContainer_refresh(t *testing.T) {
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
eventsCh <- event{}
checkRefresh(t, &HostsRecord{
checkRefresh(t, &aghnet.HostsRecord{
Aliases: stringutil.NewSet("alias"),
Canonical: "hostname",
})
@@ -228,66 +224,6 @@ func TestHostsContainer_refresh(t *testing.T) {
})
}
func TestHostsContainer_PathsToPatterns(t *testing.T) {
gsfs := fstest.MapFS{
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
}
testCases := []struct {
name string
paths []string
want []string
}{{
name: "no_paths",
paths: nil,
want: nil,
}, {
name: "single_file",
paths: []string{"dir_0/file_1"},
want: []string{"dir_0/file_1"},
}, {
name: "several_files",
paths: []string{"dir_0/file_1", "dir_0/file_2"},
want: []string{"dir_0/file_1", "dir_0/file_2"},
}, {
name: "whole_dir",
paths: []string{"dir_0"},
want: []string{"dir_0/*"},
}, {
name: "file_and_dir",
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
}, {
name: "non-existing",
paths: []string{path.Join("dir_0", "file_3")},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
patterns, err := pathsToPatterns(gsfs, tc.paths)
require.NoError(t, err)
assert.Equal(t, tc.want, patterns)
})
}
t.Run("bad_file", func(t *testing.T) {
const errStat errors.Error = "bad file"
badFS := &aghtest.StatFS{
OnStat: func(name string) (fs.FileInfo, error) {
return nil, errStat
},
}
_, err := pathsToPatterns(badFS, []string{""})
assert.ErrorIs(t, err, errStat)
})
}
func TestHostsContainer_Translate(t *testing.T) {
stubWatcher := aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) { return nil },
@@ -297,7 +233,7 @@ func TestHostsContainer_Translate(t *testing.T) {
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
hc, err := aghnet.NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
@@ -527,7 +463,7 @@ func TestHostsContainer(t *testing.T) {
OnClose: func() (err error) { return nil },
}
hc, err := NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts")
hc, err := aghnet.NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
@@ -558,69 +494,3 @@ func TestHostsContainer(t *testing.T) {
})
}
}
func TestUniqueRules_ParseLine(t *testing.T) {
ip := netutil.IPv4Localhost()
ipStr := ip.String()
testCases := []struct {
name string
line string
wantIP netip.Addr
wantHosts []string
}{{
name: "simple",
line: ipStr + ` hostname`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "aliases",
line: ipStr + ` hostname alias`,
wantIP: ip,
wantHosts: []string{"hostname", "alias"},
}, {
name: "invalid_line",
line: ipStr,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "invalid_line_hostname",
line: ipStr + ` # hostname`,
wantIP: ip,
wantHosts: nil,
}, {
name: "commented_aliases",
line: ipStr + ` hostname # alias`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "whole_comment",
line: `# ` + ipStr + ` hostname`,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "partial_comment",
line: ipStr + ` host#name`,
wantIP: ip,
wantHosts: []string{"host"},
}, {
name: "empty",
line: ``,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "bad_hosts",
line: ipStr + ` bad..host bad._tld empty.tld. ok.host`,
wantIP: ip,
wantHosts: []string{"ok.host"},
}}
for _, tc := range testCases {
hp := hostsParser{}
t.Run(tc.name, func(t *testing.T) {
got, hosts := hp.parseLine(tc.line)
assert.Equal(t, tc.wantIP, got)
assert.Equal(t, tc.wantHosts, hosts)
})
}
}

View File

@@ -3,6 +3,7 @@ package aghnet
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -15,6 +16,10 @@ import (
"github.com/AdguardTeam/golibs/log"
)
// DialContextFunc is the semantic alias for dialing functions, such as
// [http.Transport.DialContext].
type DialContextFunc = func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// Variables and functions to substitute in tests.
var (
// aghosRunCommand is the function to run shell commands.

View File

@@ -5,9 +5,9 @@ import (
"testing"
"testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakefs"
"github.com/stretchr/testify/assert"
)
@@ -118,7 +118,7 @@ func TestIfaceSetStaticIP(t *testing.T) {
Data: []byte(`nameserver 1.1.1.1`),
},
}
panicFsys := &aghtest.FS{
panicFsys := &fakefs.FS{
OnOpen: func(name string) (fs.File, error) { panic("not implemented") },
}

View File

@@ -0,0 +1,334 @@
package aghnet
import (
"bytes"
"encoding/json"
"fmt"
"io/fs"
"net"
"net/netip"
"os"
"strings"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata")
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
// package with fsys for tests ran under t.
func substRootDirFS(t testing.TB, fsys fs.FS) {
t.Helper()
prev := rootDirFS
t.Cleanup(func() { rootDirFS = prev })
rootDirFS = fsys
}
// RunCmdFunc is the signature of aghos.RunCommand function.
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
// substShell replaces the the aghos.RunCommand function used throughout the
// package with rc for tests ran under t.
func substShell(t testing.TB, rc RunCmdFunc) {
t.Helper()
prev := aghosRunCommand
t.Cleanup(func() { aghosRunCommand = prev })
aghosRunCommand = rc
}
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
// execution result. It's only needed to simplify testing.
//
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
type mapShell map[string]struct {
err error
out string
code int
}
// theOnlyCmd returns mapShell that only handles a single command and arguments
// combination from cmd.
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
return mapShell{cmd: {code: code, out: out, err: err}}
}
// RunCmd is a RunCmdFunc handled by s.
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
key := strings.Join(append([]string{cmd}, args...), " ")
ret, ok := s[key]
if !ok {
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
}
return ret.code, []byte(ret.out), ret.err
}
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
// throughout the package with f for tests ran under t.
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
t.Helper()
prev := netInterfaceAddrs
t.Cleanup(func() { netInterfaceAddrs = prev })
netInterfaceAddrs = f
}
func TestGatewayIP(t *testing.T) {
const ifaceName = "ifaceName"
const cmd = "ip route show dev " + ifaceName
testCases := []struct {
shell mapShell
want netip.Addr
name string
}{{
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: netip.MustParseAddr("1.2.3.4"),
name: "success_v4",
}, {
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: netip.MustParseAddr("::ffff"),
name: "success_v6",
}, {
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: netip.Addr{},
name: "bad_output",
}, {
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: netip.Addr{},
name: "err_runcmd",
}, {
shell: theOnlyCmd(cmd, 1, "", nil),
want: netip.Addr{},
name: "bad_code",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
assert.Equal(t, tc.want, GatewayIP(ifaceName))
})
}
}
func TestInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb()
require.NoError(t, err)
require.NotEmpty(t, ifaces)
for _, iface := range ifaces {
t.Run(iface.Name, func(t *testing.T) {
require.NotEmpty(t, iface.Addresses)
for _, ip := range iface.Addresses {
ifaceName := InterfaceByIP(ip)
require.Equal(t, iface.Name, ifaceName)
}
})
}
}
func TestBroadcastFromIPNet(t *testing.T) {
known4 := netip.MustParseAddr("192.168.0.1")
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
testCases := []struct {
pref netip.Prefix
want netip.Addr
name string
}{{
pref: netip.PrefixFrom(known4, 0),
want: fullBroadcast4,
name: "full",
}, {
pref: netip.PrefixFrom(known4, 20),
want: netip.MustParseAddr("192.168.15.255"),
name: "full",
}, {
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
want: known6,
name: "ipv6_no_mask",
}, {
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
want: known4,
name: "ipv4_no_mask",
}, {
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
want: fullBroadcast4,
name: "unspecified",
}, {
pref: netip.Prefix{},
want: netip.Addr{},
name: "invalid",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
})
}
}
func TestCheckPort(t *testing.T) {
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("tcp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("udp_bound", func(t *testing.T) {
conn, err := net.ListenPacket("udp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, conn.Close)
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("udp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("bad_network", func(t *testing.T) {
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
assert.NoError(t, err)
})
t.Run("can_bind", func(t *testing.T) {
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
assert.NoError(t, err)
})
}
func TestCollectAllIfacesAddrs(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
}{{
name: "success",
wantErrMsg: ``,
addrs: []net.Addr{&net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
}, &net.IPNet{
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
wantAddrs: nil,
}, {
name: "empty",
wantErrMsg: ``,
addrs: []net.Addr{},
wantAddrs: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
addrs, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantAddrs, addrs)
})
}
t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
}
func TestIsAddrInUse(t *testing.T) {
t.Run("addr_in_use", func(t *testing.T) {
l, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
assert.True(t, IsAddrInUse(err))
})
t.Run("another", func(t *testing.T) {
const anotherErr errors.Error = "not addr in use"
assert.False(t, IsAddrInUse(anotherErr))
})
}
func TestNetInterface_MarshalJSON(t *testing.T) {
const want = `{` +
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
`"flags":"up|multicast",` +
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
`"name":"iface0",` +
`"mtu":1500` +
`}` + "\n"
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
require.True(t, ok)
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
require.True(t, ok)
net4 := netip.PrefixFrom(ip4, 24)
net6 := netip.PrefixFrom(ip6, 8)
iface := &NetInterface{
Addresses: []netip.Addr{ip4, ip6},
Subnets: []netip.Prefix{net4, net6},
Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast,
MTU: 1500,
}
b := &bytes.Buffer{}
err := json.NewEncoder(b).Encode(iface)
require.NoError(t, err)
assert.Equal(t, want, b.String())
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/renameio/maybe"
"github.com/google/renameio/v2/maybe"
"golang.org/x/sys/unix"
)

View File

@@ -1,21 +1,11 @@
package aghnet
package aghnet_test
import (
"bytes"
"encoding/json"
"fmt"
"io/fs"
"net"
"net/netip"
"os"
"strings"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
@@ -24,315 +14,3 @@ func TestMain(m *testing.M) {
// testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata")
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
// package with fsys for tests ran under t.
func substRootDirFS(t testing.TB, fsys fs.FS) {
t.Helper()
prev := rootDirFS
t.Cleanup(func() { rootDirFS = prev })
rootDirFS = fsys
}
// RunCmdFunc is the signature of aghos.RunCommand function.
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
// substShell replaces the the aghos.RunCommand function used throughout the
// package with rc for tests ran under t.
func substShell(t testing.TB, rc RunCmdFunc) {
t.Helper()
prev := aghosRunCommand
t.Cleanup(func() { aghosRunCommand = prev })
aghosRunCommand = rc
}
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
// execution result. It's only needed to simplify testing.
//
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
type mapShell map[string]struct {
err error
out string
code int
}
// theOnlyCmd returns mapShell that only handles a single command and arguments
// combination from cmd.
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
return mapShell{cmd: {code: code, out: out, err: err}}
}
// RunCmd is a RunCmdFunc handled by s.
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
key := strings.Join(append([]string{cmd}, args...), " ")
ret, ok := s[key]
if !ok {
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
}
return ret.code, []byte(ret.out), ret.err
}
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
// throughout the package with f for tests ran under t.
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
t.Helper()
prev := netInterfaceAddrs
t.Cleanup(func() { netInterfaceAddrs = prev })
netInterfaceAddrs = f
}
func TestGatewayIP(t *testing.T) {
const ifaceName = "ifaceName"
const cmd = "ip route show dev " + ifaceName
testCases := []struct {
shell mapShell
want netip.Addr
name string
}{{
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: netip.MustParseAddr("1.2.3.4"),
name: "success_v4",
}, {
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: netip.MustParseAddr("::ffff"),
name: "success_v6",
}, {
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: netip.Addr{},
name: "bad_output",
}, {
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: netip.Addr{},
name: "err_runcmd",
}, {
shell: theOnlyCmd(cmd, 1, "", nil),
want: netip.Addr{},
name: "bad_code",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
assert.Equal(t, tc.want, GatewayIP(ifaceName))
})
}
}
func TestInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb()
require.NoError(t, err)
require.NotEmpty(t, ifaces)
for _, iface := range ifaces {
t.Run(iface.Name, func(t *testing.T) {
require.NotEmpty(t, iface.Addresses)
for _, ip := range iface.Addresses {
ifaceName := InterfaceByIP(ip)
require.Equal(t, iface.Name, ifaceName)
}
})
}
}
func TestBroadcastFromIPNet(t *testing.T) {
known4 := netip.MustParseAddr("192.168.0.1")
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
testCases := []struct {
pref netip.Prefix
want netip.Addr
name string
}{{
pref: netip.PrefixFrom(known4, 0),
want: fullBroadcast4,
name: "full",
}, {
pref: netip.PrefixFrom(known4, 20),
want: netip.MustParseAddr("192.168.15.255"),
name: "full",
}, {
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
want: known6,
name: "ipv6_no_mask",
}, {
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
want: known4,
name: "ipv4_no_mask",
}, {
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
want: fullBroadcast4,
name: "unspecified",
}, {
pref: netip.Prefix{},
want: netip.Addr{},
name: "invalid",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
})
}
}
func TestCheckPort(t *testing.T) {
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("tcp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("udp_bound", func(t *testing.T) {
conn, err := net.ListenPacket("udp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, conn.Close)
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("udp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("bad_network", func(t *testing.T) {
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
assert.NoError(t, err)
})
t.Run("can_bind", func(t *testing.T) {
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
assert.NoError(t, err)
})
}
func TestCollectAllIfacesAddrs(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
}{{
name: "success",
wantErrMsg: ``,
addrs: []net.Addr{&net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
}, &net.IPNet{
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
wantAddrs: nil,
}, {
name: "empty",
wantErrMsg: ``,
addrs: []net.Addr{},
wantAddrs: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
addrs, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantAddrs, addrs)
})
}
t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
}
func TestIsAddrInUse(t *testing.T) {
t.Run("addr_in_use", func(t *testing.T) {
l, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
assert.True(t, IsAddrInUse(err))
})
t.Run("another", func(t *testing.T) {
const anotherErr errors.Error = "not addr in use"
assert.False(t, IsAddrInUse(anotherErr))
})
}
func TestNetInterface_MarshalJSON(t *testing.T) {
const want = `{` +
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
`"flags":"up|multicast",` +
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
`"name":"iface0",` +
`"mtu":1500` +
`}` + "\n"
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
require.True(t, ok)
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
require.True(t, ok)
net4 := netip.PrefixFrom(ip4, 24)
net6 := netip.PrefixFrom(ip6, 8)
iface := &NetInterface{
Addresses: []netip.Addr{ip4, ip6},
Subnets: []netip.Prefix{net4, net6},
Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast,
MTU: 1500,
}
b := &bytes.Buffer{}
err := json.NewEncoder(b).Encode(iface)
require.NoError(t, err)
assert.Equal(t, want, b.String())
}