Merge branch 'master' into 2476-rwmutex

This commit is contained in:
Ainar Garipov
2021-03-03 16:10:29 +03:00
485 changed files with 35236 additions and 6572 deletions

View File

@@ -29,10 +29,12 @@ type AutoHosts struct {
// TODO(a.garipov): Make better use of newtypes. Perhaps a custom map.
tableReverse map[string][]string
hostsFn string // path to the main hosts-file
hostsDirs []string // paths to OS-specific directories with hosts-files
watcher *fsnotify.Watcher // file and directory watcher object
updateChan chan bool // signal for 'updateLoop' goroutine
hostsFn string // path to the main hosts-file
hostsDirs []string // paths to OS-specific directories with hosts-files
watcher *fsnotify.Watcher // file and directory watcher object
// onlyWritesChan used to contain only writing events from watcher.
onlyWritesChan chan fsnotify.Event
onChanged onChangedT // notification to other modules
}
@@ -54,7 +56,7 @@ func (a *AutoHosts) notify() {
// hostsFn: Override default name for the hosts-file (optional)
func (a *AutoHosts) Init(hostsFn string) {
a.table = make(map[string][]net.IP)
a.updateChan = make(chan bool, 2)
a.onlyWritesChan = make(chan fsnotify.Event, 2)
a.hostsFn = "/etc/hosts"
if runtime.GOOS == "windows" {
@@ -82,8 +84,7 @@ func (a *AutoHosts) Init(hostsFn string) {
func (a *AutoHosts) Start() {
log.Debug("Start AutoHosts module")
go a.updateLoop()
a.updateChan <- true
a.updateHosts()
if a.watcher != nil {
go a.watcherLoop()
@@ -104,11 +105,10 @@ func (a *AutoHosts) Start() {
// Close - close module
func (a *AutoHosts) Close() {
a.updateChan <- false
close(a.updateChan)
if a.watcher != nil {
_ = a.watcher.Close()
}
close(a.onlyWritesChan)
}
// Process returns the list of IP addresses for the hostname or nil if nothing
@@ -273,20 +273,32 @@ func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string][]string
}
}
// onlyWrites is a filter for (*fsnotify.Watcher).Events.
func (a *AutoHosts) onlyWrites() {
for event := range a.watcher.Events {
if event.Op&fsnotify.Write == fsnotify.Write {
a.onlyWritesChan <- event
}
}
}
// Receive notifications from fsnotify package
func (a *AutoHosts) watcherLoop() {
go a.onlyWrites()
for {
select {
case event, ok := <-a.watcher.Events:
case event, ok := <-a.onlyWritesChan:
if !ok {
return
}
// Assume that we sometimes have the same event occurred
// several times.
repeat := true
for repeat {
select {
case <-a.watcher.Events:
// Skip this duplicating event
case _, ok = <-a.onlyWritesChan:
repeat = ok
default:
repeat = false
}
@@ -294,12 +306,7 @@ func (a *AutoHosts) watcherLoop() {
if event.Op&fsnotify.Write == fsnotify.Write {
log.Debug("AutoHosts: modified: %s", event.Name)
select {
case a.updateChan <- true:
// sent a signal to 'updateLoop' goroutine
default:
// queue is full
}
a.updateHosts()
}
case err, ok := <-a.watcher.Errors:
@@ -311,18 +318,6 @@ func (a *AutoHosts) watcherLoop() {
}
}
// updateLoop reads static hosts from system files.
func (a *AutoHosts) updateLoop() {
for ok := range a.updateChan {
if !ok {
log.Debug("Finished AutoHosts update loop")
return
}
a.updateHosts()
}
}
// updateHosts - loads system hosts
func (a *AutoHosts) updateHosts() {
table := make(map[string][]net.IP)

View File

@@ -8,117 +8,165 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
aghtest.DiscardLogOutput(m)
}
func prepareTestDir() string {
const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0o755)
return dir
func prepareTestFile(t *testing.T) (f *os.File) {
t.Helper()
dir := aghtest.PrepareTestDir(t)
f, err := ioutil.TempFile(dir, "")
require.Nil(t, err)
require.NotNil(t, f)
t.Cleanup(func() {
assert.Nil(t, f.Close())
})
return f
}
func assertWriting(t *testing.T, f *os.File, strs ...string) {
t.Helper()
for _, str := range strs {
n, err := f.WriteString(str)
require.Nil(t, err)
assert.Equal(t, n, len(str))
}
}
func TestAutoHostsResolution(t *testing.T) {
ah := AutoHosts{}
ah := &AutoHosts{}
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
f, _ := ioutil.TempFile(dir, "")
defer func() { _ = os.Remove(f.Name()) }()
defer f.Close()
_, _ = f.WriteString(" 127.0.0.1 host localhost # comment \n")
_, _ = f.WriteString(" ::1 localhost#comment \n")
f := prepareTestFile(t)
assertWriting(t, f,
" 127.0.0.1 host localhost # comment \n",
" ::1 localhost#comment \n",
)
ah.Init(f.Name())
// Existing host
ips := ah.Process("localhost", dns.TypeA)
assert.NotNil(t, ips)
assert.Equal(t, 1, len(ips))
assert.Equal(t, net.ParseIP("127.0.0.1"), ips[0])
t.Run("existing_host", func(t *testing.T) {
ips := ah.Process("localhost", dns.TypeA)
require.Len(t, ips, 1)
assert.Equal(t, net.IPv4(127, 0, 0, 1), ips[0])
})
// Unknown host
ips = ah.Process("newhost", dns.TypeA)
assert.Nil(t, ips)
t.Run("unknown_host", func(t *testing.T) {
ips := ah.Process("newhost", dns.TypeA)
assert.Nil(t, ips)
// Unknown host (comment)
ips = ah.Process("comment", dns.TypeA)
assert.Nil(t, ips)
// Comment.
ips = ah.Process("comment", dns.TypeA)
assert.Nil(t, ips)
})
// Test hosts file
table := ah.List()
names, ok := table["127.0.0.1"]
assert.True(t, ok)
assert.Equal(t, []string{"host", "localhost"}, names)
t.Run("hosts_file", func(t *testing.T) {
names, ok := ah.List()["127.0.0.1"]
require.True(t, ok)
assert.Equal(t, []string{"host", "localhost"}, names)
})
// Test PTR
a, _ := dns.ReverseAddr("127.0.0.1")
a = strings.TrimSuffix(a, ".")
hosts := ah.ProcessReverse(a, dns.TypePTR)
if assert.Len(t, hosts, 2) {
assert.Equal(t, hosts[0], "host")
}
t.Run("ptr", func(t *testing.T) {
testCases := []struct {
wantIP string
wantLen int
wantHost string
}{
{wantIP: "127.0.0.1", wantLen: 2, wantHost: "host"},
{wantIP: "::1", wantLen: 1, wantHost: "localhost"},
}
a, _ = dns.ReverseAddr("::1")
a = strings.TrimSuffix(a, ".")
hosts = ah.ProcessReverse(a, dns.TypePTR)
if assert.Len(t, hosts, 1) {
assert.Equal(t, hosts[0], "localhost")
}
for _, tc := range testCases {
a, err := dns.ReverseAddr(tc.wantIP)
require.Nil(t, err)
a = strings.TrimSuffix(a, ".")
hosts := ah.ProcessReverse(a, dns.TypePTR)
require.Len(t, hosts, tc.wantLen)
assert.Equal(t, tc.wantHost, hosts[0])
}
})
}
func TestAutoHostsFSNotify(t *testing.T) {
ah := AutoHosts{}
ah := &AutoHosts{}
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
f := prepareTestFile(t)
f, _ := ioutil.TempFile(dir, "")
defer func() { _ = os.Remove(f.Name()) }()
defer f.Close()
// Init
_, _ = f.WriteString(" 127.0.0.1 host localhost \n")
assertWriting(t, f, " 127.0.0.1 host localhost \n")
ah.Init(f.Name())
// Unknown host
ips := ah.Process("newhost", dns.TypeA)
assert.Nil(t, ips)
t.Run("unknown_host", func(t *testing.T) {
ips := ah.Process("newhost", dns.TypeA)
assert.Nil(t, ips)
})
// Stat monitoring for changes
// Start monitoring for changes.
ah.Start()
defer ah.Close()
t.Cleanup(ah.Close)
// Update file
_, _ = f.WriteString("127.0.0.2 newhost\n")
_ = f.Sync()
assertWriting(t, f, "127.0.0.2 newhost\n")
require.Nil(t, f.Sync())
// wait until fsnotify has triggerred and processed the file-modification event
// Wait until fsnotify has triggerred and processed the
// file-modification event.
time.Sleep(50 * time.Millisecond)
// Check if we are notified about changes
ips = ah.Process("newhost", dns.TypeA)
assert.NotNil(t, ips)
assert.Equal(t, 1, len(ips))
assert.Equal(t, "127.0.0.2", ips[0].String())
t.Run("notified", func(t *testing.T) {
ips := ah.Process("newhost", dns.TypeA)
assert.NotNil(t, ips)
require.Len(t, ips, 1)
assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0]))
})
}
func TestIP(t *testing.T) {
assert.Equal(t, "127.0.0.1", DNSUnreverseAddr("1.0.0.127.in-addr.arpa").String())
assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String())
assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String())
func TestDNSReverseAddr(t *testing.T) {
testCases := []struct {
name string
have string
want net.IP
}{{
name: "good_ipv4",
have: "1.0.0.127.in-addr.arpa",
want: net.IP{127, 0, 0, 1},
}, {
name: "good_ipv6",
have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa",
want: net.ParseIP("::abcd:1234"),
}, {
name: "good_ipv6_case",
have: "4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa",
want: net.ParseIP("::abcd:1234"),
}, {
name: "bad_ipv4_dot",
have: "1.0.0.127.in-addr.arpa.",
}, {
name: "wrong_ipv4",
have: ".0.0.127.in-addr.arpa",
}, {
name: "wrong_ipv6",
have: ".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa",
}, {
name: "bad_ipv6_dot",
have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa",
}, {
name: "bad_ipv6_space",
have: "4.3.2.1.d.c.b. .0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa",
}}
assert.Nil(t, DNSUnreverseAddr("1.0.0.127.in-addr.arpa."))
assert.Nil(t, DNSUnreverseAddr(".0.0.127.in-addr.arpa"))
assert.Nil(t, DNSUnreverseAddr(".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa"))
assert.Nil(t, DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa"))
assert.Nil(t, DNSUnreverseAddr("4.3.2.1.d.c.b. .0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa"))
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ip := DNSUnreverseAddr(tc.have)
assert.True(t, tc.want.Equal(ip))
})
}
}

View File

@@ -5,10 +5,12 @@
package util
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
)
@@ -26,7 +28,7 @@ func ContainsString(strs []string, str string) bool {
// FileExists returns true if file exists.
func FileExists(fn string) bool {
_, err := os.Stat(fn)
return err == nil
return err == nil || !os.IsNotExist(err)
}
// RunCommand runs shell command.
@@ -64,16 +66,43 @@ func SplitNext(str *string, splitBy byte) string {
return strings.TrimSpace(s)
}
// IsOpenWRT checks if OS is OpenWRT.
// IsOpenWRT returns true if host OS is OpenWRT.
func IsOpenWRT() bool {
if runtime.GOOS != "linux" {
return false
}
body, err := ioutil.ReadFile("/etc/os-release")
const etcDir = "/etc"
// TODO(e.burkov): Take care of dealing with fs package after updating
// Go version to 1.16.
fileInfos, err := ioutil.ReadDir(etcDir)
if err != nil {
return false
}
return strings.Contains(string(body), "OpenWrt")
// fNameSubstr is a part of a name of the desired file.
const fNameSubstr = "release"
osNameData := []byte("OpenWrt")
for _, fileInfo := range fileInfos {
if fileInfo.IsDir() {
continue
}
if !strings.Contains(fileInfo.Name(), fNameSubstr) {
continue
}
body, err := ioutil.ReadFile(filepath.Join(etcDir, fileInfo.Name()))
if err != nil {
continue
}
if bytes.Contains(body, osNameData) {
return true
}
}
return false
}

View File

@@ -4,11 +4,14 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSplitNext(t *testing.T) {
s := " a,b , c "
assert.True(t, SplitNext(&s, ',') == "a")
assert.True(t, SplitNext(&s, ',') == "b")
assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0)
assert.Equal(t, "a", SplitNext(&s, ','))
assert.Equal(t, "b", SplitNext(&s, ','))
assert.Equal(t, "c", SplitNext(&s, ','))
require.Empty(t, s)
}

View File

@@ -1,6 +1,7 @@
package util
import (
"encoding/json"
"errors"
"fmt"
"net"
@@ -13,14 +14,30 @@ import (
"github.com/AdguardTeam/golibs/log"
)
// NetInterface represents a list of network interfaces
// NetInterface represents an entry of network interfaces map.
type NetInterface struct {
Name string // Network interface name
MTU int // MTU
HardwareAddr string // Hardware address
Addresses []string // Array with the network interface addresses
Subnets []string // Array with CIDR addresses of this network interface
Flags string // Network interface flags (up, broadcast, etc)
MTU int `json:"mtu"`
Name string `json:"name"`
HardwareAddr net.HardwareAddr `json:"hardware_address"`
Flags net.Flags `json:"flags"`
// Array with the network interface addresses.
Addresses []net.IP `json:"ip_addresses,omitempty"`
// Array with IP networks for this network interface.
Subnets []*net.IPNet `json:"-"`
}
// MarshalJSON implements the json.Marshaler interface for *NetInterface.
func (iface *NetInterface) MarshalJSON() ([]byte, error) {
type netInterface NetInterface
return json.Marshal(&struct {
HardwareAddr string `json:"hardware_address"`
Flags string `json:"flags"`
*netInterface
}{
HardwareAddr: iface.HardwareAddr.String(),
Flags: iface.Flags.String(),
netInterface: (*netInterface)(iface),
})
}
// GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP
@@ -40,7 +57,7 @@ func GetValidNetInterfaces() ([]net.Interface, error) {
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and WEB only
// we do not return link-local addresses here
func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
ifaces, err := GetValidNetInterfaces()
if err != nil {
return nil, fmt.Errorf("couldn't get interfaces: %w", err)
@@ -49,7 +66,7 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
return nil, errors.New("couldn't find any legible interface")
}
var netInterfaces []NetInterface
var netInterfaces []*NetInterface
for _, iface := range ifaces {
addrs, err := iface.Addrs()
@@ -57,32 +74,29 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
return nil, fmt.Errorf("failed to get addresses for interface %s: %w", iface.Name, err)
}
netIface := NetInterface{
Name: iface.Name,
netIface := &NetInterface{
MTU: iface.MTU,
HardwareAddr: iface.HardwareAddr.String(),
Name: iface.Name,
HardwareAddr: iface.HardwareAddr,
Flags: iface.Flags,
}
if iface.Flags != 0 {
netIface.Flags = iface.Flags.String()
}
// Collect network interface addresses
// Collect network interface addresses.
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
// not an IPNet, should not happen
// Should be net.IPNet, this is weird.
return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
}
// ignore link-local
// Ignore link-local.
if ipNet.IP.IsLinkLocalUnicast() {
continue
}
netIface.Addresses = append(netIface.Addresses, ipNet.IP.String())
netIface.Subnets = append(netIface.Subnets, ipNet.String())
netIface.Addresses = append(netIface.Addresses, ipNet.IP)
netIface.Subnets = append(netIface.Subnets, ipNet)
}
// Discard interfaces with no addresses
// Discard interfaces with no addresses.
if len(netIface.Addresses) != 0 {
netInterfaces = append(netInterfaces, netIface)
}
@@ -91,8 +105,8 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
return netInterfaces, nil
}
// GetInterfaceByIP - Get interface name by its IP address.
func GetInterfaceByIP(ip string) string {
// GetInterfaceByIP returns the name of interface containing provided ip.
func GetInterfaceByIP(ip net.IP) string {
ifaces, err := GetValidNetInterfacesForWeb()
if err != nil {
return ""
@@ -100,7 +114,7 @@ func GetInterfaceByIP(ip string) string {
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
if ip == addr {
if ip.Equal(addr) {
return iface.Name
}
}
@@ -109,13 +123,13 @@ func GetInterfaceByIP(ip string) string {
return ""
}
// GetSubnet - Get IP address with netmask for the specified interface
// Returns an empty string if it fails to find it
func GetSubnet(ifaceName string) string {
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if
// the search fails.
func GetSubnet(ifaceName string) *net.IPNet {
netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil {
log.Error("Could not get network interfaces info: %v", err)
return ""
return nil
}
for _, netIface := range netIfaces {
@@ -124,12 +138,12 @@ func GetSubnet(ifaceName string) string {
}
}
return ""
return nil
}
// CheckPortAvailable - check if TCP port is available
func CheckPortAvailable(host string, port int) error {
ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port)))
func CheckPortAvailable(host net.IP, port int) error {
ln, err := net.Listen("tcp", net.JoinHostPort(host.String(), strconv.Itoa(port)))
if err != nil {
return err
}
@@ -142,8 +156,8 @@ func CheckPortAvailable(host string, port int) error {
}
// CheckPacketPortAvailable - check if UDP port is available
func CheckPacketPortAvailable(host string, port int) error {
ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port)))
func CheckPacketPortAvailable(host net.IP, port int) error {
ln, err := net.ListenPacket("udp", net.JoinHostPort(host.String(), strconv.Itoa(port)))
if err != nil {
return err
}

View File

@@ -2,22 +2,15 @@ package util
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGetValidNetInterfacesForWeb(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb()
if err != nil {
t.Fatalf("Cannot get net interfaces: %s", err)
}
if len(ifaces) == 0 {
t.Fatalf("No net interfaces found")
}
require.Nilf(t, err, "Cannot get net interfaces: %s", err)
require.NotEmpty(t, ifaces, "No net interfaces found")
for _, iface := range ifaces {
if len(iface.Addresses) == 0 {
t.Fatalf("No addresses found for %s", iface.Name)
}
t.Logf("%v", iface)
require.NotEmptyf(t, iface.Addresses, "No addresses found for %s", iface.Name)
}
}