all: sync with master

This commit is contained in:
Ainar Garipov
2022-11-02 16:18:02 +03:00
parent 16755c37d8
commit c9314610d4
173 changed files with 11539 additions and 6928 deletions

View File

@@ -5,10 +5,11 @@ import (
"bytes"
"fmt"
"net"
"net/netip"
"sync"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/slices"
)
// ARPDB: The Network Neighborhood Database
@@ -54,7 +55,7 @@ type Neighbor struct {
Name string
// IP contains either IPv4 or IPv6.
IP net.IP
IP netip.Addr
// MAC contains the hardware address.
MAC net.HardwareAddr
@@ -64,8 +65,8 @@ type Neighbor struct {
func (n Neighbor) Clone() (clone Neighbor) {
return Neighbor{
Name: n.Name,
IP: netutil.CloneIP(n.IP),
MAC: netutil.CloneMAC(n.MAC),
IP: n.IP,
MAC: slices.Clone(n.MAC),
}
}

View File

@@ -5,6 +5,7 @@ package aghnet
import (
"bufio"
"net"
"net/netip"
"strings"
"sync"
@@ -47,22 +48,28 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
if ipStr := fields[1]; len(ipStr) < 2 {
continue
} else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil {
} else if ip, err := netip.ParseAddr(ipStr[1 : len(ipStr)-1]); err != nil {
log.Debug("arpdb: parsing arp output: ip: %s", err)
continue
} else {
n.IP = ip
}
hwStr := fields[3]
if mac, err := net.ParseMAC(hwStr); err != nil {
mac, err := net.ParseMAC(hwStr)
if err != nil {
log.Debug("arpdb: parsing arp output: mac: %s", err)
continue
} else {
n.MAC = mac
}
host := fields[0]
if err := netutil.ValidateDomainName(host); err != nil {
log.Debug("parsing arp output: %s", err)
err = netutil.ValidateDomainName(host)
if err != nil {
log.Debug("arpdb: parsing arp output: host: %s", err)
} else {
n.Name = host
}

View File

@@ -4,6 +4,7 @@ package aghnet
import (
"net"
"net/netip"
)
const arpAOutput = `
@@ -17,14 +18,14 @@ hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [
var wantNeighs = []Neighbor{{
Name: "hostname.one",
IP: net.IPv4(192, 168, 1, 2),
IP: netip.MustParseAddr("192.168.1.2"),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
Name: "hostname.two",
IP: net.ParseIP("::ffff:ffff"),
IP: netip.MustParseAddr("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}, {
Name: "",
IP: net.ParseIP("::1234"),
IP: netip.MustParseAddr("::1234"),
MAC: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
}}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io/fs"
"net"
"net/netip"
"strings"
"sync"
@@ -94,7 +95,8 @@ func (arp *fsysARPDB) Refresh() (err error) {
}
n := Neighbor{}
if n.IP = net.ParseIP(fields[0]); n.IP == nil || n.IP.IsUnspecified() {
n.IP, err = netip.ParseAddr(fields[0])
if err != nil || n.IP.IsUnspecified() {
continue
} else if n.MAC, err = net.ParseMAC(fields[3]); err != nil {
continue
@@ -135,15 +137,19 @@ func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil || n.IP.IsUnspecified() {
ip, err := netip.ParseAddr(fields[0])
if err != nil || n.IP.IsUnspecified() {
log.Debug("arpdb: parsing arp output: ip: %s", err)
continue
} else {
n.IP = ip
}
hwStr := fields[3]
if mac, err := net.ParseMAC(hwStr); err != nil {
log.Debug("parsing arp output: %s", err)
mac, err := net.ParseMAC(hwStr)
if err != nil {
log.Debug("arpdb: parsing arp output: mac: %s", err)
continue
} else {
@@ -174,7 +180,9 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
if ipStr := fields[1]; len(ipStr) < 2 {
continue
} else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil {
} else if ip, err := netip.ParseAddr(ipStr[1 : len(ipStr)-1]); err != nil {
log.Debug("arpdb: parsing arp output: ip: %s", err)
continue
} else {
n.IP = ip
@@ -182,7 +190,7 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
hwStr := fields[3]
if mac, err := net.ParseMAC(hwStr); err != nil {
log.Debug("parsing arp output: %s", err)
log.Debug("arpdb: parsing arp output: mac: %s", err)
continue
} else {
@@ -191,7 +199,7 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
host := fields[0]
if verr := netutil.ValidateDomainName(host); verr != nil {
log.Debug("parsing arp output: %s", verr)
log.Debug("arpdb: parsing arp output: host: %s", verr)
} else {
n.Name = host
}
@@ -218,14 +226,18 @@ func parseIPNeigh(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil {
ip, err := netip.ParseAddr(fields[0])
if err != nil {
log.Debug("arpdb: parsing arp output: ip: %s", err)
continue
} else {
n.IP = ip
}
if mac, err := net.ParseMAC(fields[4]); err != nil {
log.Debug("parsing arp output: %s", err)
mac, err := net.ParseMAC(fields[4])
if err != nil {
log.Debug("arpdb: parsing arp output: mac: %s", err)
continue
} else {

View File

@@ -4,6 +4,7 @@ package aghnet
import (
"net"
"net/netip"
"sync"
"testing"
"testing/fstest"
@@ -33,10 +34,10 @@ const ipNeighOutput = `
::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE`
var wantNeighs = []Neighbor{{
IP: net.IPv4(192, 168, 1, 2),
IP: netip.MustParseAddr("192.168.1.2"),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
IP: net.ParseIP("::ffff:ffff"),
IP: netip.MustParseAddr("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}}

View File

@@ -5,6 +5,7 @@ package aghnet
import (
"bufio"
"net"
"net/netip"
"strings"
"sync"
@@ -50,14 +51,18 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil {
ip, err := netip.ParseAddr(fields[0])
if err != nil {
log.Debug("arpdb: parsing arp output: ip: %s", err)
continue
} else {
n.IP = ip
}
if mac, err := net.ParseMAC(fields[1]); err != nil {
log.Debug("parsing arp output: %s", err)
mac, err := net.ParseMAC(fields[1])
if err != nil {
log.Debug("arpdb: parsing arp output: mac: %s", err)
continue
} else {

View File

@@ -4,6 +4,7 @@ package aghnet
import (
"net"
"net/netip"
)
const arpAOutput = `
@@ -15,9 +16,9 @@ Host Ethernet Address Netif Expire Flags
`
var wantNeighs = []Neighbor{{
IP: net.IPv4(192, 168, 1, 2),
IP: netip.MustParseAddr("192.168.1.2"),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
IP: net.ParseIP("::ffff:ffff"),
IP: netip.MustParseAddr("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}}

View File

@@ -2,6 +2,7 @@ package aghnet
import (
"net"
"net/netip"
"sync"
"testing"
@@ -35,7 +36,7 @@ func (arp *TestARPDB) Neighbors() (ns []Neighbor) {
}
func TestARPDBS(t *testing.T) {
knownIP := net.IP{1, 2, 3, 4}
knownIP := netip.MustParseAddr("1.2.3.4")
knownMAC := net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}
succRefrCount, failRefrCount := 0, 0

View File

@@ -5,6 +5,7 @@ package aghnet
import (
"bufio"
"net"
"net/netip"
"strings"
"sync"
)
@@ -43,13 +44,15 @@ func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
n := Neighbor{}
if ip := net.ParseIP(fields[0]); ip == nil {
ip, err := netip.ParseAddr(fields[0])
if err != nil {
continue
} else {
n.IP = ip
}
if mac, err := net.ParseMAC(fields[1]); err != nil {
mac, err := net.ParseMAC(fields[1])
if err != nil {
continue
} else {
n.MAC = mac

View File

@@ -4,6 +4,7 @@ package aghnet
import (
"net"
"net/netip"
)
const arpAOutput = `
@@ -14,9 +15,9 @@ Interface: 192.168.1.1 --- 0x7
::ffff:ffff ef-cd-ab-ef-cd-ab static`
var wantNeighs = []Neighbor{{
IP: net.IPv4(192, 168, 1, 2),
IP: netip.MustParseAddr("192.168.1.2"),
MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF},
}, {
IP: net.ParseIP("::ffff:ffff"),
IP: netip.MustParseAddr("::ffff:ffff"),
MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB},
}}

View File

@@ -6,6 +6,7 @@ import (
"bytes"
"fmt"
"net"
"net/netip"
"os"
"time"
@@ -38,48 +39,44 @@ func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
}
// ifaceIPv4Subnet returns the first suitable IPv4 subnetwork iface has.
func ifaceIPv4Subnet(iface *net.Interface) (subnet *net.IPNet, err error) {
func ifaceIPv4Subnet(iface *net.Interface) (subnet netip.Prefix, err error) {
var addrs []net.Addr
if addrs, err = iface.Addrs(); err != nil {
return nil, err
return netip.Prefix{}, err
}
for _, a := range addrs {
var ip net.IP
var maskLen int
switch a := a.(type) {
case *net.IPAddr:
subnet = &net.IPNet{
IP: a.IP,
Mask: a.IP.DefaultMask(),
}
ip = a.IP
maskLen, _ = ip.DefaultMask().Size()
case *net.IPNet:
subnet = a
ip = a.IP
maskLen, _ = a.Mask.Size()
default:
continue
}
if ip4 := subnet.IP.To4(); ip4 != nil {
subnet.IP = ip4
return subnet, nil
if ip = ip.To4(); ip != nil {
return netip.PrefixFrom(netip.AddrFrom4(*(*[4]byte)(ip)), maskLen), nil
}
}
return nil, fmt.Errorf("interface %s has no ipv4 addresses", iface.Name)
return netip.Prefix{}, fmt.Errorf("interface %s has no ipv4 addresses", iface.Name)
}
// checkOtherDHCPv4 sends a DHCP request to the specified network interface, and
// waits for a response for a period defined by defaultDiscoverTime.
func checkOtherDHCPv4(iface *net.Interface) (ok bool, err error) {
var subnet *net.IPNet
var subnet netip.Prefix
if subnet, err = ifaceIPv4Subnet(iface); err != nil {
return false, err
}
// Resolve broadcast addr.
dst := netutil.IPPort{
IP: BroadcastFromIPNet(subnet),
Port: 67,
}.String()
dst := netip.AddrPortFrom(BroadcastFromPref(subnet), 67).String()
var dstAddr *net.UDPAddr
if dstAddr, err = net.ResolveUDPAddr("udp4", dst); err != nil {
return false, fmt.Errorf("couldn't resolve UDP address %s: %w", dst, err)

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"io"
"io/fs"
"net"
"net/netip"
"path"
"strings"
"sync"
@@ -19,6 +19,7 @@ import (
"github.com/AdguardTeam/urlfilter/filterlist"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/maps"
)
// DefaultHostsPaths returns the slice of paths default for the operating system
@@ -106,10 +107,10 @@ type HostsContainer struct {
done chan struct{}
// updates is the channel for receiving updated hosts.
updates chan *netutil.IPMap
updates chan HostsRecords
// last is the set of hosts that was cached within last detected change.
last *netutil.IPMap
last HostsRecords
// fsys is the working file system to read hosts files from.
fsys fs.FS
@@ -124,6 +125,25 @@ type HostsContainer struct {
listID int
}
// HostsRecords is a mapping of an IP address to its hosts data.
type HostsRecords map[netip.Addr]*HostsRecord
// HostsRecord represents a single hosts file record.
type HostsRecord struct {
Aliases *stringutil.Set
Canonical string
}
// equal returns true if all fields of rec are equal to field in other or they
// both are nil.
func (rec *HostsRecord) equal(other *HostsRecord) (ok bool) {
if rec == nil {
return other == nil
}
return rec.Canonical == other.Canonical && rec.Aliases.Equal(other.Aliases)
}
// ErrNoHostsPaths is returned when there are no valid paths to watch passed to
// the HostsContainer.
const ErrNoHostsPaths errors.Error = "no valid paths to hosts files provided"
@@ -158,7 +178,7 @@ func NewHostsContainer(
},
listID: listID,
done: make(chan struct{}, 1),
updates: make(chan *netutil.IPMap, 1),
updates: make(chan HostsRecords, 1),
fsys: fsys,
w: w,
patterns: patterns,
@@ -196,9 +216,8 @@ func (hc *HostsContainer) Close() (err error) {
return nil
}
// Upd returns the channel into which the updates are sent. The receivable
// map's values are guaranteed to be of type of *HostsRecord.
func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) {
// Upd returns the channel into which the updates are sent.
func (hc *HostsContainer) Upd() (updates <-chan HostsRecords) {
return hc.updates
}
@@ -264,7 +283,7 @@ type hostsParser struct {
// table stores only the unique IP-hostname pairs. It's also sent to the
// updates channel afterwards.
table *netutil.IPMap
table HostsRecords
}
// newHostsParser creates a new *hostsParser with buffers of size taken from the
@@ -273,7 +292,7 @@ func (hc *HostsContainer) newHostsParser() (hp *hostsParser) {
return &hostsParser{
rulesBuilder: &strings.Builder{},
translations: map[string]string{},
table: netutil.NewIPMap(hc.last.Len()),
table: make(HostsRecords, len(hc.last)),
}
}
@@ -285,7 +304,7 @@ func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err
s := bufio.NewScanner(r)
for s.Scan() {
ip, hosts := hp.parseLine(s.Text())
if ip == nil || len(hosts) == 0 {
if ip == (netip.Addr{}) || len(hosts) == 0 {
continue
}
@@ -296,14 +315,15 @@ func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err
}
// parseLine parses the line having the hosts syntax ignoring invalid ones.
func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
func (hp *hostsParser) parseLine(line string) (ip netip.Addr, hosts []string) {
fields := strings.Fields(line)
if len(fields) < 2 {
return nil, nil
return netip.Addr{}, nil
}
if ip = net.ParseIP(fields[0]); ip == nil {
return nil, nil
ip, err := netip.ParseAddr(fields[0])
if err != nil {
return netip.Addr{}, nil
}
for _, f := range fields[1:] {
@@ -321,7 +341,7 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
// See https://github.com/AdguardTeam/AdGuardHome/issues/3946.
//
// TODO(e.burkov): Investigate if hosts may contain DNS-SD domains.
err := netutil.ValidateDomainName(f)
err = netutil.ValidateDomainName(f)
if err != nil {
log.Error("%s: host %q is invalid, ignoring", hostsContainerPref, f)
@@ -334,30 +354,13 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
return ip, hosts
}
// HostsRecord represents a single hosts file record.
type HostsRecord struct {
Aliases *stringutil.Set
Canonical string
}
// Equal returns true if all fields of rec are equal to field in other or they
// both are nil.
func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) {
if rec == nil {
return other == nil
}
return rec.Canonical == other.Canonical && rec.Aliases.Equal(other.Aliases)
}
// addRecord puts the record for the IP address to the rules builder if needed.
// The first host is considered to be the canonical name for the IP address.
// hosts must have at least one name.
func (hp *hostsParser) addRecord(ip net.IP, hosts []string) {
func (hp *hostsParser) addRecord(ip netip.Addr, hosts []string) {
line := strings.Join(append([]string{ip.String()}, hosts...), " ")
var rec *HostsRecord
v, ok := hp.table.Get(ip)
rec, ok := hp.table[ip]
if !ok {
rec = &HostsRecord{
Aliases: stringutil.NewSet(),
@@ -365,14 +368,7 @@ func (hp *hostsParser) addRecord(ip net.IP, hosts []string) {
rec.Canonical, hosts = hosts[0], hosts[1:]
hp.addRules(ip, rec.Canonical, line)
hp.table.Set(ip, rec)
} else {
rec, ok = v.(*HostsRecord)
if !ok {
log.Error("%s: adding pairs: unexpected type %T", hostsContainerPref, v)
return
}
hp.table[ip] = rec
}
for _, host := range hosts {
@@ -387,7 +383,7 @@ func (hp *hostsParser) addRecord(ip net.IP, hosts []string) {
}
// addRules adds rules and rule translations for the line.
func (hp *hostsParser) addRules(ip net.IP, host, line string) {
func (hp *hostsParser) addRules(ip netip.Addr, host, line string) {
rule, rulePtr := hp.writeRules(host, ip)
hp.translations[rule], hp.translations[rulePtr] = line, line
@@ -396,8 +392,9 @@ func (hp *hostsParser) addRules(ip net.IP, host, line string) {
// writeRules writes the actual rule for the qtype and the PTR for the host-ip
// pair into internal builders.
func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) {
arpa, err := netutil.IPToReversedAddr(ip)
func (hp *hostsParser) writeRules(host string, ip netip.Addr) (rule, rulePtr string) {
// TODO(a.garipov): Add a netip.Addr version to netutil.
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
if err != nil {
return "", ""
}
@@ -415,7 +412,7 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string)
var qtype string
// The validation of the IP address has been performed earlier so it is
// guaranteed to be either an IPv4 or an IPv6.
if ip.To4() != nil {
if ip.Is4() {
qtype = "A"
} else {
qtype = "AAAA"
@@ -442,51 +439,8 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string)
return rule, rulePtr
}
// equalSet returns true if the internal hosts table just parsed equals target.
// target's values must be of type *HostsRecord.
func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) {
if target == nil {
// hp.table shouldn't appear nil since it's initialized on each refresh.
return target == hp.table
}
if hp.table.Len() != target.Len() {
return false
}
hp.table.Range(func(ip net.IP, recVal any) (cont bool) {
var targetVal any
targetVal, ok = target.Get(ip)
if !ok {
return false
}
var rec *HostsRecord
rec, ok = recVal.(*HostsRecord)
if !ok {
log.Error("%s: comparing: unexpected type %T", hostsContainerPref, recVal)
return false
}
var targetRec *HostsRecord
targetRec, ok = targetVal.(*HostsRecord)
if !ok {
log.Error("%s: comparing: target: unexpected type %T", hostsContainerPref, targetVal)
return false
}
ok = rec.Equal(targetRec)
return ok
})
return ok
}
// sendUpd tries to send the parsed data to the ch.
func (hp *hostsParser) sendUpd(ch chan *netutil.IPMap) {
func (hp *hostsParser) sendUpd(ch chan HostsRecords) {
log.Debug("%s: sending upd", hostsContainerPref)
upd := hp.table
@@ -524,14 +478,14 @@ func (hc *HostsContainer) refresh() (err error) {
return fmt.Errorf("refreshing : %w", err)
}
if hp.equalSet(hc.last) {
if maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) {
log.Debug("%s: no changes detected", hostsContainerPref)
return nil
}
defer hp.sendUpd(hc.updates)
hc.last = hp.table.ShallowClone()
hc.last = maps.Clone(hp.table)
var rulesStrg *filterlist.RuleStorage
if rulesStrg, err = hp.newStrg(hc.listID); err != nil {

View File

@@ -3,6 +3,7 @@ package aghnet
import (
"io/fs"
"net"
"net/netip"
"path"
"strings"
"sync/atomic"
@@ -10,6 +11,7 @@ import (
"testing/fstest"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
@@ -135,7 +137,7 @@ func TestNewHostsContainer(t *testing.T) {
func TestHostsContainer_refresh(t *testing.T) {
// TODO(e.burkov): Test the case with no actual updates.
ip := net.IP{127, 0, 0, 1}
ip := netutil.IPv4Localhost()
ipStr := ip.String()
testFS := fstest.MapFS{"dir/file1": &fstest.MapFile{Data: []byte(ipStr + ` hostname` + nl)}}
@@ -163,27 +165,17 @@ func TestHostsContainer_refresh(t *testing.T) {
checkRefresh := func(t *testing.T, want *HostsRecord) {
t.Helper()
var ok bool
var upd *netutil.IPMap
select {
case upd, ok = <-hc.Upd():
require.True(t, ok)
require.NotNil(t, upd)
case <-time.After(1 * time.Second):
t.Fatal("did not receive after 1s")
}
assert.Equal(t, 1, upd.Len())
v, ok := upd.Get(ip)
upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
require.True(t, ok)
require.NotNil(t, upd)
require.IsType(t, (*HostsRecord)(nil), v)
assert.Len(t, upd, 1)
rec, _ := v.(*HostsRecord)
rec, ok := upd[ip]
require.True(t, ok)
require.NotNil(t, rec)
assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want)
assert.Truef(t, rec.equal(want), "%+v != %+v", rec, want)
}
t.Run("initial_refresh", func(t *testing.T) {
@@ -568,13 +560,13 @@ func TestHostsContainer(t *testing.T) {
}
func TestUniqueRules_ParseLine(t *testing.T) {
ip := net.IP{127, 0, 0, 1}
ip := netutil.IPv4Localhost()
ipStr := ip.String()
testCases := []struct {
name string
line string
wantIP net.IP
wantIP netip.Addr
wantHosts []string
}{{
name: "simple",
@@ -589,7 +581,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
}, {
name: "invalid_line",
line: ipStr,
wantIP: nil,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "invalid_line_hostname",
@@ -604,7 +596,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
}, {
name: "whole_comment",
line: `# ` + ipStr + ` hostname`,
wantIP: nil,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "partial_comment",
@@ -614,7 +606,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
}, {
name: "empty",
line: ``,
wantIP: nil,
wantIP: netip.Addr{},
wantHosts: nil,
}}
@@ -622,7 +614,7 @@ func TestUniqueRules_ParseLine(t *testing.T) {
hp := hostsParser{}
t.Run(tc.name, func(t *testing.T) {
got, hosts := hp.parseLine(tc.line)
assert.True(t, tc.wantIP.Equal(got))
assert.Equal(t, tc.wantIP, got)
assert.Equal(t, tc.wantHosts, hosts)
})
}

View File

@@ -7,12 +7,12 @@ import (
"fmt"
"io"
"net"
"net/netip"
"syscall"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
)
// Variables and functions to substitute in tests.
@@ -31,6 +31,12 @@ var (
// the IP being static is available.
const ErrNoStaticIPInfo errors.Error = "no information about static ip"
// IPv4Localhost returns 127.0.0.1, which returns true for [netip.Addr.Is4].
func IPv4Localhost() (ip netip.Addr) { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
// IPv6Localhost returns ::1, which returns true for [netip.Addr.Is6].
func IPv6Localhost() (ip netip.Addr) { return netip.AddrFrom16([16]byte{15: 1}) }
// IfaceHasStaticIP checks if interface is configured to have static IP address.
// If it can't give a definitive answer, it returns false and an error for which
// errors.Is(err, ErrNoStaticIPInfo) is true.
@@ -47,26 +53,31 @@ func IfaceSetStaticIP(ifaceName string) (err error) {
//
// TODO(e.burkov): Investigate if the gateway address may be fetched in another
// way since not every machine has the software installed.
func GatewayIP(ifaceName string) (ip net.IP) {
func GatewayIP(ifaceName string) (ip netip.Addr) {
code, out, err := aghosRunCommand("ip", "route", "show", "dev", ifaceName)
if err != nil {
log.Debug("%s", err)
return nil
return netip.Addr{}
} else if code != 0 {
log.Debug("fetching gateway ip: unexpected exit code: %d", code)
return nil
return netip.Addr{}
}
fields := bytes.Fields(out)
// The meaningful "ip route" command output should contain the word
// "default" at first field and default gateway IP address at third field.
if len(fields) < 3 || string(fields[0]) != "default" {
return nil
return netip.Addr{}
}
return net.ParseIP(string(fields[2]))
ip, err = netip.ParseAddr(string(fields[2]))
if err != nil {
return netip.Addr{}
}
return ip
}
// CanBindPrivilegedPorts checks if current process can bind to privileged
@@ -78,9 +89,9 @@ func CanBindPrivilegedPorts() (can bool, err error) {
// NetInterface represents an entry of network interfaces map.
type NetInterface struct {
// Addresses are the network interface addresses.
Addresses []net.IP `json:"ip_addresses,omitempty"`
Addresses []netip.Addr `json:"ip_addresses,omitempty"`
// Subnets are the IP networks for this network interface.
Subnets []*net.IPNet `json:"-"`
Subnets []netip.Prefix `json:"-"`
Name string `json:"name"`
HardwareAddr net.HardwareAddr `json:"hardware_address"`
Flags net.Flags `json:"flags"`
@@ -101,57 +112,79 @@ func (iface NetInterface) MarshalJSON() ([]byte, error) {
})
}
func NetInterfaceFrom(iface *net.Interface) (niface *NetInterface, err error) {
niface = &NetInterface{
Name: iface.Name,
HardwareAddr: iface.HardwareAddr,
Flags: iface.Flags,
MTU: iface.MTU,
}
addrs, err := iface.Addrs()
if err != nil {
return nil, fmt.Errorf("failed to get addresses for interface %s: %w", iface.Name, err)
}
// Collect network interface addresses.
for _, addr := range addrs {
n, ok := addr.(*net.IPNet)
if !ok {
// Should be *net.IPNet, this is weird.
return nil, fmt.Errorf("expected %[2]s to be %[1]T, got %[2]T", n, addr)
} else if ip4 := n.IP.To4(); ip4 != nil {
n.IP = ip4
}
ip, ok := netip.AddrFromSlice(n.IP)
if !ok {
return nil, fmt.Errorf("bad address %s", n.IP)
}
ip = ip.Unmap()
if ip.IsLinkLocalUnicast() {
// Ignore link-local IPv4.
if ip.Is4() {
continue
}
ip = ip.WithZone(iface.Name)
}
ones, _ := n.Mask.Size()
p := netip.PrefixFrom(ip, ones)
niface.Addresses = append(niface.Addresses, ip)
niface.Subnets = append(niface.Subnets, p)
}
return niface, nil
}
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and
// WEB only we do not return link-local addresses here.
//
// TODO(e.burkov): Can't properly test the function since it's nontrivial to
// substitute net.Interface.Addrs and the net.InterfaceAddrs can't be used.
func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) {
func GetValidNetInterfacesForWeb() (nifaces []*NetInterface, err error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("couldn't get interfaces: %w", err)
return nil, fmt.Errorf("getting interfaces: %w", err)
} else if len(ifaces) == 0 {
return nil, errors.Error("couldn't find any legible interface")
return nil, errors.Error("no legible interfaces")
}
for _, iface := range ifaces {
var addrs []net.Addr
addrs, err = iface.Addrs()
for i := range ifaces {
var niface *NetInterface
niface, err = NetInterfaceFrom(&ifaces[i])
if err != nil {
return nil, fmt.Errorf("failed to get addresses for interface %s: %w", iface.Name, err)
}
netIface := &NetInterface{
MTU: iface.MTU,
Name: iface.Name,
HardwareAddr: iface.HardwareAddr,
Flags: iface.Flags,
}
// Collect network interface addresses.
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
// Should be net.IPNet, this is weird.
return nil, fmt.Errorf("got %s that is not net.IPNet, it is %T", addr, addr)
}
// Ignore link-local.
if ipNet.IP.IsLinkLocalUnicast() {
continue
}
netIface.Addresses = append(netIface.Addresses, ipNet.IP)
netIface.Subnets = append(netIface.Subnets, ipNet)
}
// Discard interfaces with no addresses.
if len(netIface.Addresses) != 0 {
netIfaces = append(netIfaces, netIface)
return nil, err
} else if len(niface.Addresses) != 0 {
// Discard interfaces with no addresses.
nifaces = append(nifaces, niface)
}
}
return netIfaces, nil
return nifaces, nil
}
// InterfaceByIP returns the name of the interface bound to ip.
@@ -160,7 +193,7 @@ func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) {
// IP address can be shared by multiple interfaces in some configurations.
//
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func InterfaceByIP(ip net.IP) (ifaceName string) {
func InterfaceByIP(ip netip.Addr) (ifaceName string) {
ifaces, err := GetValidNetInterfacesForWeb()
if err != nil {
return ""
@@ -168,7 +201,7 @@ func InterfaceByIP(ip net.IP) (ifaceName string) {
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
if ip.Equal(addr) {
if ip == addr {
return iface.Name
}
}
@@ -177,15 +210,16 @@ func InterfaceByIP(ip net.IP) (ifaceName string) {
return ""
}
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if
// GetSubnet returns the subnet corresponding to the interface of zero prefix if
// the search fails.
//
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func GetSubnet(ifaceName string) *net.IPNet {
func GetSubnet(ifaceName string) (p netip.Prefix) {
netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil {
log.Error("Could not get network interfaces info: %v", err)
return nil
return p
}
for _, netIface := range netIfaces {
@@ -194,14 +228,14 @@ func GetSubnet(ifaceName string) *net.IPNet {
}
}
return nil
return p
}
// CheckPort checks if the port is available for binding. network is expected
// to be one of "udp" and "tcp".
func CheckPort(network string, ip net.IP, port int) (err error) {
func CheckPort(network string, ipp netip.AddrPort) (err error) {
var c io.Closer
addr := netutil.IPPort{IP: ip, Port: port}.String()
addr := ipp.String()
switch network {
case "tcp":
c, err = net.Listen(network, addr)
@@ -251,18 +285,23 @@ func CollectAllIfacesAddrs() (addrs []string, err error) {
return addrs, nil
}
// BroadcastFromIPNet calculates the broadcast IP address for n.
func BroadcastFromIPNet(n *net.IPNet) (dc net.IP) {
dc = netutil.CloneIP(n.IP)
mask := n.Mask
if mask == nil {
mask = dc.DefaultMask()
// BroadcastFromPref calculates the broadcast IP address for p.
func BroadcastFromPref(p netip.Prefix) (bc netip.Addr) {
bc = p.Addr().Unmap()
if !bc.IsValid() {
return netip.Addr{}
}
for i, b := range mask {
dc[i] |= ^b
maskLen, addrLen := p.Bits(), bc.BitLen()
if maskLen == addrLen {
return bc
}
return dc
ipBytes := bc.AsSlice()
for i := maskLen; i < addrLen; i++ {
ipBytes[i/8] |= 1 << (7 - (i % 8))
}
bc, _ = netip.AddrFromSlice(ipBytes)
return bc
}

View File

@@ -6,7 +6,7 @@ import (
"bufio"
"fmt"
"io"
"net"
"net/netip"
"os"
"strings"
@@ -151,7 +151,7 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
// interface through dhcpcd.conf.
func ifaceSetStaticIP(ifaceName string) (err error) {
ipNet := GetSubnet(ifaceName)
if ipNet.IP == nil {
if !ipNet.Addr().IsValid() {
return errors.Error("can't get IP address")
}
@@ -174,7 +174,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
// dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that
// configure the interface to have a static IP.
func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf string) {
func dhcpcdConfIface(ifaceName string, subnet netip.Prefix, gateway netip.Addr) (conf string) {
b := &strings.Builder{}
stringutil.WriteToBuilder(
b,
@@ -183,15 +183,15 @@ func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf stri
" added by AdGuard Home.\ninterface ",
ifaceName,
"\nstatic ip_address=",
ipNet.String(),
subnet.String(),
"\n",
)
if gwIP != nil {
stringutil.WriteToBuilder(b, "static routers=", gwIP.String(), "\n")
if gateway != (netip.Addr{}) {
stringutil.WriteToBuilder(b, "static routers=", gateway.String(), "\n")
}
stringutil.WriteToBuilder(b, "static domain_name_servers=", ipNet.IP.String(), "\n\n")
stringutil.WriteToBuilder(b, "static domain_name_servers=", subnet.Addr().String(), "\n\n")
return b.String()
}

View File

@@ -6,11 +6,11 @@ import (
"fmt"
"io/fs"
"net"
"net/netip"
"os"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
@@ -19,7 +19,7 @@ import (
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
testutil.DiscardLogOutput(m)
}
// testdata is the filesystem containing data for testing the package.
@@ -93,34 +93,29 @@ func TestGatewayIP(t *testing.T) {
const cmd = "ip route show dev " + ifaceName
testCases := []struct {
name string
shell mapShell
want net.IP
want netip.Addr
name string
}{{
name: "success_v4",
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: net.IP{1, 2, 3, 4}.To16(),
want: netip.MustParseAddr("1.2.3.4"),
name: "success_v4",
}, {
name: "success_v6",
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: net.IP{
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0xFF, 0xFF,
},
want: netip.MustParseAddr("::ffff"),
name: "success_v6",
}, {
name: "bad_output",
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: nil,
want: netip.Addr{},
name: "bad_output",
}, {
name: "err_runcmd",
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: nil,
want: netip.Addr{},
name: "err_runcmd",
}, {
name: "bad_code",
shell: theOnlyCmd(cmd, 1, "", nil),
want: nil,
want: netip.Addr{},
name: "bad_code",
}}
for _, tc := range testCases {
@@ -150,65 +145,61 @@ func TestInterfaceByIP(t *testing.T) {
}
func TestBroadcastFromIPNet(t *testing.T) {
known6 := net.IP{
1, 2, 3, 4,
5, 6, 7, 8,
9, 10, 11, 12,
13, 14, 15, 16,
}
known4 := netip.MustParseAddr("192.168.0.1")
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
testCases := []struct {
name string
subnet *net.IPNet
want net.IP
pref netip.Prefix
want netip.Addr
name string
}{{
pref: netip.PrefixFrom(known4, 0),
want: fullBroadcast4,
name: "full",
subnet: &net.IPNet{
IP: net.IP{192, 168, 0, 1},
Mask: net.IPMask{255, 255, 15, 0},
},
want: net.IP{192, 168, 240, 255},
}, {
name: "ipv6_no_mask",
subnet: &net.IPNet{
IP: known6,
},
pref: netip.PrefixFrom(known4, 20),
want: netip.MustParseAddr("192.168.15.255"),
name: "full",
}, {
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
want: known6,
name: "ipv6_no_mask",
}, {
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
want: known4,
name: "ipv4_no_mask",
subnet: &net.IPNet{
IP: net.IP{192, 168, 1, 2},
},
want: net.IP{192, 168, 1, 255},
}, {
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
want: fullBroadcast4,
name: "unspecified",
subnet: &net.IPNet{
IP: net.IP{0, 0, 0, 0},
Mask: net.IPMask{0, 0, 0, 0},
},
want: net.IPv4bcast,
}, {
pref: netip.Prefix{},
want: netip.Addr{},
name: "invalid",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
bc := BroadcastFromIPNet(tc.subnet)
assert.True(t, bc.Equal(tc.want), bc)
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
})
}
}
func TestCheckPort(t *testing.T) {
laddr := netip.AddrPortFrom(IPv4Localhost(), 0)
t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:")
l, err := net.Listen("tcp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
ipp := netutil.IPPortFromAddr(l.Addr())
require.NotNil(t, ipp)
require.NotNil(t, ipp.IP)
require.NotZero(t, ipp.Port)
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("tcp", ipp.IP, ipp.Port)
err = CheckPort("tcp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
@@ -216,16 +207,15 @@ func TestCheckPort(t *testing.T) {
})
t.Run("udp_bound", func(t *testing.T) {
conn, err := net.ListenPacket("udp", "127.0.0.1:")
conn, err := net.ListenPacket("udp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, conn.Close)
ipp := netutil.IPPortFromAddr(conn.LocalAddr())
require.NotNil(t, ipp)
require.NotNil(t, ipp.IP)
require.NotZero(t, ipp.Port)
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("udp", ipp.IP, ipp.Port)
err = CheckPort("udp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
@@ -233,12 +223,12 @@ func TestCheckPort(t *testing.T) {
})
t.Run("bad_network", func(t *testing.T) {
err := CheckPort("bad_network", nil, 0)
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
assert.NoError(t, err)
})
t.Run("can_bind", func(t *testing.T) {
err := CheckPort("udp", net.IP{0, 0, 0, 0}, 0)
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
assert.NoError(t, err)
})
}
@@ -322,18 +312,18 @@ func TestNetInterface_MarshalJSON(t *testing.T) {
`"mtu":1500` +
`}` + "\n"
ip4, ip6 := net.IP{1, 2, 3, 4}, net.IP{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
mask4, mask6 := net.CIDRMask(24, netutil.IPv4BitLen), net.CIDRMask(8, netutil.IPv6BitLen)
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
require.True(t, ok)
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
require.True(t, ok)
net4 := netip.PrefixFrom(ip4, 24)
net6 := netip.PrefixFrom(ip6, 8)
iface := &NetInterface{
Addresses: []net.IP{ip4, ip6},
Subnets: []*net.IPNet{{
IP: ip4.Mask(mask4),
Mask: mask4,
}, {
IP: ip6.Mask(mask6),
Mask: mask6,
}},
Addresses: []netip.Addr{ip4, ip6},
Subnets: []netip.Prefix{net4, net6},
Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast,

View File

@@ -6,6 +6,7 @@ import (
"context"
"testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -17,9 +18,8 @@ func createTestSystemResolversImpl(
t.Helper()
sr := createTestSystemResolvers(t, hostGenFunc)
require.IsType(t, (*systemResolvers)(nil), sr)
return sr.(*systemResolvers)
return testutil.RequireTypeAssert[*systemResolvers](t, sr)
}
func TestSystemResolvers_Refresh(t *testing.T) {