Merge branch 'master' into 2476-rwmutex
This commit is contained in:
@@ -29,10 +29,12 @@ type AutoHosts struct {
|
||||
// TODO(a.garipov): Make better use of newtypes. Perhaps a custom map.
|
||||
tableReverse map[string][]string
|
||||
|
||||
hostsFn string // path to the main hosts-file
|
||||
hostsDirs []string // paths to OS-specific directories with hosts-files
|
||||
watcher *fsnotify.Watcher // file and directory watcher object
|
||||
updateChan chan bool // signal for 'updateLoop' goroutine
|
||||
hostsFn string // path to the main hosts-file
|
||||
hostsDirs []string // paths to OS-specific directories with hosts-files
|
||||
watcher *fsnotify.Watcher // file and directory watcher object
|
||||
|
||||
// onlyWritesChan used to contain only writing events from watcher.
|
||||
onlyWritesChan chan fsnotify.Event
|
||||
|
||||
onChanged onChangedT // notification to other modules
|
||||
}
|
||||
@@ -54,7 +56,7 @@ func (a *AutoHosts) notify() {
|
||||
// hostsFn: Override default name for the hosts-file (optional)
|
||||
func (a *AutoHosts) Init(hostsFn string) {
|
||||
a.table = make(map[string][]net.IP)
|
||||
a.updateChan = make(chan bool, 2)
|
||||
a.onlyWritesChan = make(chan fsnotify.Event, 2)
|
||||
|
||||
a.hostsFn = "/etc/hosts"
|
||||
if runtime.GOOS == "windows" {
|
||||
@@ -82,8 +84,7 @@ func (a *AutoHosts) Init(hostsFn string) {
|
||||
func (a *AutoHosts) Start() {
|
||||
log.Debug("Start AutoHosts module")
|
||||
|
||||
go a.updateLoop()
|
||||
a.updateChan <- true
|
||||
a.updateHosts()
|
||||
|
||||
if a.watcher != nil {
|
||||
go a.watcherLoop()
|
||||
@@ -104,11 +105,10 @@ func (a *AutoHosts) Start() {
|
||||
|
||||
// Close - close module
|
||||
func (a *AutoHosts) Close() {
|
||||
a.updateChan <- false
|
||||
close(a.updateChan)
|
||||
if a.watcher != nil {
|
||||
_ = a.watcher.Close()
|
||||
}
|
||||
close(a.onlyWritesChan)
|
||||
}
|
||||
|
||||
// Process returns the list of IP addresses for the hostname or nil if nothing
|
||||
@@ -273,20 +273,32 @@ func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string][]string
|
||||
}
|
||||
}
|
||||
|
||||
// onlyWrites is a filter for (*fsnotify.Watcher).Events.
|
||||
func (a *AutoHosts) onlyWrites() {
|
||||
for event := range a.watcher.Events {
|
||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||
a.onlyWritesChan <- event
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Receive notifications from fsnotify package
|
||||
func (a *AutoHosts) watcherLoop() {
|
||||
go a.onlyWrites()
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-a.watcher.Events:
|
||||
case event, ok := <-a.onlyWritesChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Assume that we sometimes have the same event occurred
|
||||
// several times.
|
||||
repeat := true
|
||||
for repeat {
|
||||
select {
|
||||
case <-a.watcher.Events:
|
||||
// Skip this duplicating event
|
||||
case _, ok = <-a.onlyWritesChan:
|
||||
repeat = ok
|
||||
default:
|
||||
repeat = false
|
||||
}
|
||||
@@ -294,12 +306,7 @@ func (a *AutoHosts) watcherLoop() {
|
||||
|
||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||
log.Debug("AutoHosts: modified: %s", event.Name)
|
||||
select {
|
||||
case a.updateChan <- true:
|
||||
// sent a signal to 'updateLoop' goroutine
|
||||
default:
|
||||
// queue is full
|
||||
}
|
||||
a.updateHosts()
|
||||
}
|
||||
|
||||
case err, ok := <-a.watcher.Errors:
|
||||
@@ -311,18 +318,6 @@ func (a *AutoHosts) watcherLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// updateLoop reads static hosts from system files.
|
||||
func (a *AutoHosts) updateLoop() {
|
||||
for ok := range a.updateChan {
|
||||
if !ok {
|
||||
log.Debug("Finished AutoHosts update loop")
|
||||
return
|
||||
}
|
||||
|
||||
a.updateHosts()
|
||||
}
|
||||
}
|
||||
|
||||
// updateHosts - loads system hosts
|
||||
func (a *AutoHosts) updateHosts() {
|
||||
table := make(map[string][]net.IP)
|
||||
|
||||
@@ -8,117 +8,165 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func prepareTestDir() string {
|
||||
const dir = "./agh-test"
|
||||
_ = os.RemoveAll(dir)
|
||||
_ = os.MkdirAll(dir, 0o755)
|
||||
return dir
|
||||
func prepareTestFile(t *testing.T) (f *os.File) {
|
||||
t.Helper()
|
||||
|
||||
dir := aghtest.PrepareTestDir(t)
|
||||
|
||||
f, err := ioutil.TempFile(dir, "")
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, f)
|
||||
t.Cleanup(func() {
|
||||
assert.Nil(t, f.Close())
|
||||
})
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func assertWriting(t *testing.T, f *os.File, strs ...string) {
|
||||
t.Helper()
|
||||
|
||||
for _, str := range strs {
|
||||
n, err := f.WriteString(str)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, n, len(str))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoHostsResolution(t *testing.T) {
|
||||
ah := AutoHosts{}
|
||||
ah := &AutoHosts{}
|
||||
|
||||
dir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
f, _ := ioutil.TempFile(dir, "")
|
||||
defer func() { _ = os.Remove(f.Name()) }()
|
||||
defer f.Close()
|
||||
|
||||
_, _ = f.WriteString(" 127.0.0.1 host localhost # comment \n")
|
||||
_, _ = f.WriteString(" ::1 localhost#comment \n")
|
||||
f := prepareTestFile(t)
|
||||
|
||||
assertWriting(t, f,
|
||||
" 127.0.0.1 host localhost # comment \n",
|
||||
" ::1 localhost#comment \n",
|
||||
)
|
||||
ah.Init(f.Name())
|
||||
|
||||
// Existing host
|
||||
ips := ah.Process("localhost", dns.TypeA)
|
||||
assert.NotNil(t, ips)
|
||||
assert.Equal(t, 1, len(ips))
|
||||
assert.Equal(t, net.ParseIP("127.0.0.1"), ips[0])
|
||||
t.Run("existing_host", func(t *testing.T) {
|
||||
ips := ah.Process("localhost", dns.TypeA)
|
||||
require.Len(t, ips, 1)
|
||||
assert.Equal(t, net.IPv4(127, 0, 0, 1), ips[0])
|
||||
})
|
||||
|
||||
// Unknown host
|
||||
ips = ah.Process("newhost", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
t.Run("unknown_host", func(t *testing.T) {
|
||||
ips := ah.Process("newhost", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
|
||||
// Unknown host (comment)
|
||||
ips = ah.Process("comment", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
// Comment.
|
||||
ips = ah.Process("comment", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
})
|
||||
|
||||
// Test hosts file
|
||||
table := ah.List()
|
||||
names, ok := table["127.0.0.1"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"host", "localhost"}, names)
|
||||
t.Run("hosts_file", func(t *testing.T) {
|
||||
names, ok := ah.List()["127.0.0.1"]
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, []string{"host", "localhost"}, names)
|
||||
})
|
||||
|
||||
// Test PTR
|
||||
a, _ := dns.ReverseAddr("127.0.0.1")
|
||||
a = strings.TrimSuffix(a, ".")
|
||||
hosts := ah.ProcessReverse(a, dns.TypePTR)
|
||||
if assert.Len(t, hosts, 2) {
|
||||
assert.Equal(t, hosts[0], "host")
|
||||
}
|
||||
t.Run("ptr", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
wantIP string
|
||||
wantLen int
|
||||
wantHost string
|
||||
}{
|
||||
{wantIP: "127.0.0.1", wantLen: 2, wantHost: "host"},
|
||||
{wantIP: "::1", wantLen: 1, wantHost: "localhost"},
|
||||
}
|
||||
|
||||
a, _ = dns.ReverseAddr("::1")
|
||||
a = strings.TrimSuffix(a, ".")
|
||||
hosts = ah.ProcessReverse(a, dns.TypePTR)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
assert.Equal(t, hosts[0], "localhost")
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
a, err := dns.ReverseAddr(tc.wantIP)
|
||||
require.Nil(t, err)
|
||||
|
||||
a = strings.TrimSuffix(a, ".")
|
||||
hosts := ah.ProcessReverse(a, dns.TypePTR)
|
||||
require.Len(t, hosts, tc.wantLen)
|
||||
assert.Equal(t, tc.wantHost, hosts[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAutoHostsFSNotify(t *testing.T) {
|
||||
ah := AutoHosts{}
|
||||
ah := &AutoHosts{}
|
||||
|
||||
dir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
f := prepareTestFile(t)
|
||||
|
||||
f, _ := ioutil.TempFile(dir, "")
|
||||
defer func() { _ = os.Remove(f.Name()) }()
|
||||
defer f.Close()
|
||||
|
||||
// Init
|
||||
_, _ = f.WriteString(" 127.0.0.1 host localhost \n")
|
||||
assertWriting(t, f, " 127.0.0.1 host localhost \n")
|
||||
ah.Init(f.Name())
|
||||
|
||||
// Unknown host
|
||||
ips := ah.Process("newhost", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
t.Run("unknown_host", func(t *testing.T) {
|
||||
ips := ah.Process("newhost", dns.TypeA)
|
||||
assert.Nil(t, ips)
|
||||
})
|
||||
|
||||
// Stat monitoring for changes
|
||||
// Start monitoring for changes.
|
||||
ah.Start()
|
||||
defer ah.Close()
|
||||
t.Cleanup(ah.Close)
|
||||
|
||||
// Update file
|
||||
_, _ = f.WriteString("127.0.0.2 newhost\n")
|
||||
_ = f.Sync()
|
||||
assertWriting(t, f, "127.0.0.2 newhost\n")
|
||||
require.Nil(t, f.Sync())
|
||||
|
||||
// wait until fsnotify has triggerred and processed the file-modification event
|
||||
// Wait until fsnotify has triggerred and processed the
|
||||
// file-modification event.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Check if we are notified about changes
|
||||
ips = ah.Process("newhost", dns.TypeA)
|
||||
assert.NotNil(t, ips)
|
||||
assert.Equal(t, 1, len(ips))
|
||||
assert.Equal(t, "127.0.0.2", ips[0].String())
|
||||
t.Run("notified", func(t *testing.T) {
|
||||
ips := ah.Process("newhost", dns.TypeA)
|
||||
assert.NotNil(t, ips)
|
||||
require.Len(t, ips, 1)
|
||||
assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0]))
|
||||
})
|
||||
}
|
||||
|
||||
func TestIP(t *testing.T) {
|
||||
assert.Equal(t, "127.0.0.1", DNSUnreverseAddr("1.0.0.127.in-addr.arpa").String())
|
||||
assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String())
|
||||
assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String())
|
||||
func TestDNSReverseAddr(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
have string
|
||||
want net.IP
|
||||
}{{
|
||||
name: "good_ipv4",
|
||||
have: "1.0.0.127.in-addr.arpa",
|
||||
want: net.IP{127, 0, 0, 1},
|
||||
}, {
|
||||
name: "good_ipv6",
|
||||
have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa",
|
||||
want: net.ParseIP("::abcd:1234"),
|
||||
}, {
|
||||
name: "good_ipv6_case",
|
||||
have: "4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa",
|
||||
want: net.ParseIP("::abcd:1234"),
|
||||
}, {
|
||||
name: "bad_ipv4_dot",
|
||||
have: "1.0.0.127.in-addr.arpa.",
|
||||
}, {
|
||||
name: "wrong_ipv4",
|
||||
have: ".0.0.127.in-addr.arpa",
|
||||
}, {
|
||||
name: "wrong_ipv6",
|
||||
have: ".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa",
|
||||
}, {
|
||||
name: "bad_ipv6_dot",
|
||||
have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa",
|
||||
}, {
|
||||
name: "bad_ipv6_space",
|
||||
have: "4.3.2.1.d.c.b. .0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa",
|
||||
}}
|
||||
|
||||
assert.Nil(t, DNSUnreverseAddr("1.0.0.127.in-addr.arpa."))
|
||||
assert.Nil(t, DNSUnreverseAddr(".0.0.127.in-addr.arpa"))
|
||||
assert.Nil(t, DNSUnreverseAddr(".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa"))
|
||||
assert.Nil(t, DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa"))
|
||||
assert.Nil(t, DNSUnreverseAddr("4.3.2.1.d.c.b. .0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa"))
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip := DNSUnreverseAddr(tc.have)
|
||||
assert.True(t, tc.want.Equal(ip))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,10 +5,12 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
@@ -26,7 +28,7 @@ func ContainsString(strs []string, str string) bool {
|
||||
// FileExists returns true if file exists.
|
||||
func FileExists(fn string) bool {
|
||||
_, err := os.Stat(fn)
|
||||
return err == nil
|
||||
return err == nil || !os.IsNotExist(err)
|
||||
}
|
||||
|
||||
// RunCommand runs shell command.
|
||||
@@ -64,16 +66,43 @@ func SplitNext(str *string, splitBy byte) string {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
// IsOpenWRT checks if OS is OpenWRT.
|
||||
// IsOpenWRT returns true if host OS is OpenWRT.
|
||||
func IsOpenWRT() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
return false
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadFile("/etc/os-release")
|
||||
const etcDir = "/etc"
|
||||
|
||||
// TODO(e.burkov): Take care of dealing with fs package after updating
|
||||
// Go version to 1.16.
|
||||
fileInfos, err := ioutil.ReadDir(etcDir)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.Contains(string(body), "OpenWrt")
|
||||
// fNameSubstr is a part of a name of the desired file.
|
||||
const fNameSubstr = "release"
|
||||
osNameData := []byte("OpenWrt")
|
||||
|
||||
for _, fileInfo := range fileInfos {
|
||||
if fileInfo.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.Contains(fileInfo.Name(), fNameSubstr) {
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadFile(filepath.Join(etcDir, fileInfo.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Contains(body, osNameData) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -4,11 +4,14 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSplitNext(t *testing.T) {
|
||||
s := " a,b , c "
|
||||
assert.True(t, SplitNext(&s, ',') == "a")
|
||||
assert.True(t, SplitNext(&s, ',') == "b")
|
||||
assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0)
|
||||
|
||||
assert.Equal(t, "a", SplitNext(&s, ','))
|
||||
assert.Equal(t, "b", SplitNext(&s, ','))
|
||||
assert.Equal(t, "c", SplitNext(&s, ','))
|
||||
require.Empty(t, s)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -13,14 +14,30 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// NetInterface represents a list of network interfaces
|
||||
// NetInterface represents an entry of network interfaces map.
|
||||
type NetInterface struct {
|
||||
Name string // Network interface name
|
||||
MTU int // MTU
|
||||
HardwareAddr string // Hardware address
|
||||
Addresses []string // Array with the network interface addresses
|
||||
Subnets []string // Array with CIDR addresses of this network interface
|
||||
Flags string // Network interface flags (up, broadcast, etc)
|
||||
MTU int `json:"mtu"`
|
||||
Name string `json:"name"`
|
||||
HardwareAddr net.HardwareAddr `json:"hardware_address"`
|
||||
Flags net.Flags `json:"flags"`
|
||||
// Array with the network interface addresses.
|
||||
Addresses []net.IP `json:"ip_addresses,omitempty"`
|
||||
// Array with IP networks for this network interface.
|
||||
Subnets []*net.IPNet `json:"-"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for *NetInterface.
|
||||
func (iface *NetInterface) MarshalJSON() ([]byte, error) {
|
||||
type netInterface NetInterface
|
||||
return json.Marshal(&struct {
|
||||
HardwareAddr string `json:"hardware_address"`
|
||||
Flags string `json:"flags"`
|
||||
*netInterface
|
||||
}{
|
||||
HardwareAddr: iface.HardwareAddr.String(),
|
||||
Flags: iface.Flags.String(),
|
||||
netInterface: (*netInterface)(iface),
|
||||
})
|
||||
}
|
||||
|
||||
// GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP
|
||||
@@ -40,7 +57,7 @@ func GetValidNetInterfaces() ([]net.Interface, error) {
|
||||
|
||||
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and WEB only
|
||||
// we do not return link-local addresses here
|
||||
func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
|
||||
func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
|
||||
ifaces, err := GetValidNetInterfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't get interfaces: %w", err)
|
||||
@@ -49,7 +66,7 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
|
||||
return nil, errors.New("couldn't find any legible interface")
|
||||
}
|
||||
|
||||
var netInterfaces []NetInterface
|
||||
var netInterfaces []*NetInterface
|
||||
|
||||
for _, iface := range ifaces {
|
||||
addrs, err := iface.Addrs()
|
||||
@@ -57,32 +74,29 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
|
||||
return nil, fmt.Errorf("failed to get addresses for interface %s: %w", iface.Name, err)
|
||||
}
|
||||
|
||||
netIface := NetInterface{
|
||||
Name: iface.Name,
|
||||
netIface := &NetInterface{
|
||||
MTU: iface.MTU,
|
||||
HardwareAddr: iface.HardwareAddr.String(),
|
||||
Name: iface.Name,
|
||||
HardwareAddr: iface.HardwareAddr,
|
||||
Flags: iface.Flags,
|
||||
}
|
||||
|
||||
if iface.Flags != 0 {
|
||||
netIface.Flags = iface.Flags.String()
|
||||
}
|
||||
|
||||
// Collect network interface addresses
|
||||
// Collect network interface addresses.
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if !ok {
|
||||
// not an IPNet, should not happen
|
||||
// Should be net.IPNet, this is weird.
|
||||
return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
|
||||
}
|
||||
// ignore link-local
|
||||
// Ignore link-local.
|
||||
if ipNet.IP.IsLinkLocalUnicast() {
|
||||
continue
|
||||
}
|
||||
netIface.Addresses = append(netIface.Addresses, ipNet.IP.String())
|
||||
netIface.Subnets = append(netIface.Subnets, ipNet.String())
|
||||
netIface.Addresses = append(netIface.Addresses, ipNet.IP)
|
||||
netIface.Subnets = append(netIface.Subnets, ipNet)
|
||||
}
|
||||
|
||||
// Discard interfaces with no addresses
|
||||
// Discard interfaces with no addresses.
|
||||
if len(netIface.Addresses) != 0 {
|
||||
netInterfaces = append(netInterfaces, netIface)
|
||||
}
|
||||
@@ -91,8 +105,8 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
|
||||
return netInterfaces, nil
|
||||
}
|
||||
|
||||
// GetInterfaceByIP - Get interface name by its IP address.
|
||||
func GetInterfaceByIP(ip string) string {
|
||||
// GetInterfaceByIP returns the name of interface containing provided ip.
|
||||
func GetInterfaceByIP(ip net.IP) string {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
return ""
|
||||
@@ -100,7 +114,7 @@ func GetInterfaceByIP(ip string) string {
|
||||
|
||||
for _, iface := range ifaces {
|
||||
for _, addr := range iface.Addresses {
|
||||
if ip == addr {
|
||||
if ip.Equal(addr) {
|
||||
return iface.Name
|
||||
}
|
||||
}
|
||||
@@ -109,13 +123,13 @@ func GetInterfaceByIP(ip string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetSubnet - Get IP address with netmask for the specified interface
|
||||
// Returns an empty string if it fails to find it
|
||||
func GetSubnet(ifaceName string) string {
|
||||
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if
|
||||
// the search fails.
|
||||
func GetSubnet(ifaceName string) *net.IPNet {
|
||||
netIfaces, err := GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
log.Error("Could not get network interfaces info: %v", err)
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, netIface := range netIfaces {
|
||||
@@ -124,12 +138,12 @@ func GetSubnet(ifaceName string) string {
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPortAvailable - check if TCP port is available
|
||||
func CheckPortAvailable(host string, port int) error {
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
func CheckPortAvailable(host net.IP, port int) error {
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(host.String(), strconv.Itoa(port)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -142,8 +156,8 @@ func CheckPortAvailable(host string, port int) error {
|
||||
}
|
||||
|
||||
// CheckPacketPortAvailable - check if UDP port is available
|
||||
func CheckPacketPortAvailable(host string, port int) error {
|
||||
ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
func CheckPacketPortAvailable(host net.IP, port int) error {
|
||||
ln, err := net.ListenPacket("udp", net.JoinHostPort(host.String(), strconv.Itoa(port)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,22 +2,15 @@ package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetValidNetInterfacesForWeb(t *testing.T) {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot get net interfaces: %s", err)
|
||||
}
|
||||
if len(ifaces) == 0 {
|
||||
t.Fatalf("No net interfaces found")
|
||||
}
|
||||
|
||||
require.Nilf(t, err, "Cannot get net interfaces: %s", err)
|
||||
require.NotEmpty(t, ifaces, "No net interfaces found")
|
||||
for _, iface := range ifaces {
|
||||
if len(iface.Addresses) == 0 {
|
||||
t.Fatalf("No addresses found for %s", iface.Name)
|
||||
}
|
||||
|
||||
t.Logf("%v", iface)
|
||||
require.NotEmptyf(t, iface.Addresses, "No addresses found for %s", iface.Name)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user