all: sync with master; upd chlog

This commit is contained in:
Ainar Garipov
2023-10-11 17:31:41 +03:00
parent 258eecc55b
commit 760d466b38
139 changed files with 39736 additions and 18364 deletions

View File

@@ -12,12 +12,12 @@ import (
// listenPacketReusable announces on the local network address additionally
// configuring the socket to have a reusable binding.
func listenPacketReusable(ifaceName, network, address string) (c net.PacketConn, err error) {
var port int
var port uint16
_, port, err = netutil.SplitHostPort(address)
if err != nil {
return nil, err
}
// TODO(e.burkov): Inspect nclient4.NewRawUDPConn and implement here.
return nclient4.NewRawUDPConn(ifaceName, port)
return nclient4.NewRawUDPConn(ifaceName, int(port))
}

View File

@@ -1,31 +0,0 @@
package aghnet
import (
"net"
)
// IpsetManager is the ipset manager interface.
//
// TODO(a.garipov): Perhaps generalize this into some kind of a NetFilter type,
// since ipset is exclusive to Linux?
type IpsetManager interface {
Add(host string, ip4s, ip6s []net.IP) (n int, err error)
Close() (err error)
}
// NewIpsetManager returns a new ipset. IPv4 addresses are added to an ipset
// with an ipv4 family; IPv6 addresses, to an ipv6 ipset. ipset must exist.
//
// The syntax of the ipsetConf is:
//
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
//
// If ipsetConf is empty, msg and err are nil. The error is of type
// *aghos.UnsupportedError if the OS is not supported.
func NewIpsetManager(ipsetConf []string) (mgr IpsetManager, err error) {
if len(ipsetConf) == 0 {
return nil, nil
}
return newIpsetMgr(ipsetConf)
}

View File

@@ -1,394 +0,0 @@
//go:build linux
package aghnet
import (
"fmt"
"net"
"strings"
"sync"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/digineo/go-ipset/v2"
"github.com/mdlayher/netlink"
"github.com/ti-mo/netfilter"
"golang.org/x/sys/unix"
)
// How to test on a real Linux machine:
//
// 1. Run "sudo ipset create example_set hash:ip family ipv4".
//
// 2. Run "sudo ipset list example_set". The Members field should be empty.
//
// 3. Add the line "example.com/example_set" to your AdGuardHome.yaml.
//
// 4. Start AdGuardHome.
//
// 5. Make requests to example.com and its subdomains.
//
// 6. Run "sudo ipset list example_set". The Members field should contain the
// resolved IP addresses.
// newIpsetMgr returns a new Linux ipset manager.
func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) {
return newIpsetMgrWithDialer(ipsetConf, defaultDial)
}
// defaultDial is the default netfilter dialing function.
func defaultDial(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error) {
conn, err = ipset.Dial(pf, conf)
if err != nil {
return nil, err
}
return conn, nil
}
// ipsetConn is the ipset conn interface.
type ipsetConn interface {
Add(name string, entries ...*ipset.Entry) (err error)
Close() (err error)
Header(name string) (p *ipset.HeaderPolicy, err error)
}
// ipsetDialer creates an ipsetConn.
type ipsetDialer func(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error)
// ipsetProps contains one Linux Netfilter ipset properties.
type ipsetProps struct {
name string
family netfilter.ProtoFamily
}
// unit is a convenient alias for struct{}.
type unit = struct{}
// ipsInIpset is the type of a set of IP-address-to-ipset mappings.
type ipsInIpset map[ipInIpsetEntry]unit
// ipInIpsetEntry is the type for entries in an ipsInIpset set.
type ipInIpsetEntry struct {
ipsetName string
ipArr [net.IPv6len]byte
}
// ipsetMgr is the Linux Netfilter ipset manager.
type ipsetMgr struct {
nameToIpset map[string]ipsetProps
domainToIpsets map[string][]ipsetProps
dial ipsetDialer
// mu protects all properties below.
mu *sync.Mutex
// TODO(a.garipov): Currently, the ipset list is static, and we don't
// read the IPs already in sets, so we can assume that all incoming IPs
// are either added to all corresponding ipsets or not. When that stops
// being the case, for example if we add dynamic reconfiguration of
// ipsets, this map will need to become a per-ipset-name one.
addedIPs ipsInIpset
ipv4Conn ipsetConn
ipv6Conn ipsetConn
}
// dialNetfilter establishes connections to Linux's netfilter module.
func (m *ipsetMgr) dialNetfilter(conf *netlink.Config) (err error) {
// The kernel API does not actually require two sockets but package
// github.com/digineo/go-ipset does.
//
// TODO(a.garipov): Perhaps we can ditch package ipset altogether and just
// use packages netfilter and netlink.
m.ipv4Conn, err = m.dial(netfilter.ProtoIPv4, conf)
if err != nil {
return fmt.Errorf("dialing v4: %w", err)
}
m.ipv6Conn, err = m.dial(netfilter.ProtoIPv6, conf)
if err != nil {
return fmt.Errorf("dialing v6: %w", err)
}
return nil
}
// parseIpsetConfig parses one ipset configuration string.
func parseIpsetConfig(confStr string) (hosts, ipsetNames []string, err error) {
confStr = strings.TrimSpace(confStr)
hostsAndNames := strings.Split(confStr, "/")
if len(hostsAndNames) != 2 {
return nil, nil, fmt.Errorf("invalid value %q: expected one slash", confStr)
}
hosts = strings.Split(hostsAndNames[0], ",")
ipsetNames = strings.Split(hostsAndNames[1], ",")
if len(ipsetNames) == 0 {
return nil, nil, nil
}
for i := range ipsetNames {
ipsetNames[i] = strings.TrimSpace(ipsetNames[i])
if len(ipsetNames[i]) == 0 {
return nil, nil, fmt.Errorf("invalid value %q: empty ipset name", confStr)
}
}
for i := range hosts {
hosts[i] = strings.ToLower(strings.TrimSpace(hosts[i]))
}
return hosts, ipsetNames, nil
}
// ipsetProps returns the properties of an ipset with the given name.
func (m *ipsetMgr) ipsetProps(name string) (set ipsetProps, err error) {
// The family doesn't seem to matter when we use a header query, so
// query only the IPv4 one.
//
// TODO(a.garipov): Find out if this is a bug or a feature.
var res *ipset.HeaderPolicy
res, err = m.ipv4Conn.Header(name)
if err != nil {
return set, err
}
if res == nil || res.Family == nil {
return set, errors.Error("empty response or no family data")
}
family := netfilter.ProtoFamily(res.Family.Value)
if family != netfilter.ProtoIPv4 && family != netfilter.ProtoIPv6 {
return set, fmt.Errorf("unexpected ipset family %d", family)
}
return ipsetProps{
name: name,
family: family,
}, nil
}
// ipsets returns currently known ipsets.
func (m *ipsetMgr) ipsets(names []string) (sets []ipsetProps, err error) {
for _, name := range names {
set, ok := m.nameToIpset[name]
if ok {
sets = append(sets, set)
continue
}
set, err = m.ipsetProps(name)
if err != nil {
return nil, fmt.Errorf("querying ipset %q: %w", name, err)
}
m.nameToIpset[name] = set
sets = append(sets, set)
}
return sets, nil
}
// newIpsetMgrWithDialer returns a new Linux ipset manager using the provided
// dialer.
func newIpsetMgrWithDialer(ipsetConf []string, dial ipsetDialer) (mgr IpsetManager, err error) {
defer func() { err = errors.Annotate(err, "ipset: %w") }()
m := &ipsetMgr{
mu: &sync.Mutex{},
nameToIpset: make(map[string]ipsetProps),
domainToIpsets: make(map[string][]ipsetProps),
dial: dial,
addedIPs: make(ipsInIpset),
}
err = m.dialNetfilter(&netlink.Config{})
if err != nil {
if errors.Is(err, unix.EPROTONOSUPPORT) {
// The implementation doesn't support this protocol version. Just
// issue a warning.
log.Info("ipset: dialing netfilter: warning: %s", err)
return nil, nil
}
return nil, fmt.Errorf("dialing netfilter: %w", err)
}
for i, confStr := range ipsetConf {
var hosts, ipsetNames []string
hosts, ipsetNames, err = parseIpsetConfig(confStr)
if err != nil {
return nil, fmt.Errorf("config line at idx %d: %w", i, err)
}
var ipsets []ipsetProps
ipsets, err = m.ipsets(ipsetNames)
if err != nil {
return nil, fmt.Errorf(
"getting ipsets from config line at idx %d: %w",
i,
err,
)
}
for _, host := range hosts {
m.domainToIpsets[host] = append(m.domainToIpsets[host], ipsets...)
}
}
return m, nil
}
// lookupHost find the ipsets for the host, taking subdomain wildcards into
// account.
func (m *ipsetMgr) lookupHost(host string) (sets []ipsetProps) {
// Search for matching ipset hosts starting with most specific domain.
// We could use a trie here but the simple, inefficient solution isn't
// that expensive: ~10 ns for TLD + SLD vs. ~140 ns for 10 subdomains on
// an AMD Ryzen 7 PRO 4750U CPU; ~120 ns vs. ~ 1500 ns on a Raspberry
// Pi's ARMv7 rev 4 CPU.
for i := 0; ; i++ {
host = host[i:]
sets = m.domainToIpsets[host]
if sets != nil {
return sets
}
i = strings.Index(host, ".")
if i == -1 {
break
}
}
// Check the root catch-all one.
return m.domainToIpsets[""]
}
// addIPs adds the IP addresses for the host to the ipset. set must be same
// family as set's family.
func (m *ipsetMgr) addIPs(host string, set ipsetProps, ips []net.IP) (n int, err error) {
if len(ips) == 0 {
return 0, nil
}
var entries []*ipset.Entry
var newAddedEntries []ipInIpsetEntry
for _, ip := range ips {
e := ipInIpsetEntry{
ipsetName: set.name,
}
copy(e.ipArr[:], ip.To16())
if _, added := m.addedIPs[e]; added {
continue
}
entries = append(entries, ipset.NewEntry(ipset.EntryIP(ip)))
newAddedEntries = append(newAddedEntries, e)
}
n = len(entries)
if n == 0 {
return 0, nil
}
var conn ipsetConn
switch set.family {
case netfilter.ProtoIPv4:
conn = m.ipv4Conn
case netfilter.ProtoIPv6:
conn = m.ipv6Conn
default:
return 0, fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name)
}
err = conn.Add(set.name, entries...)
if err != nil {
return 0, fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err)
}
// Only add these to the cache once we're sure that all of them were
// actually sent to the ipset.
for _, e := range newAddedEntries {
m.addedIPs[e] = unit{}
}
return n, nil
}
// addToSets adds the IP addresses to the corresponding ipset.
func (m *ipsetMgr) addToSets(
host string,
ip4s []net.IP,
ip6s []net.IP,
sets []ipsetProps,
) (n int, err error) {
for _, set := range sets {
var nn int
switch set.family {
case netfilter.ProtoIPv4:
nn, err = m.addIPs(host, set, ip4s)
if err != nil {
return n, err
}
case netfilter.ProtoIPv6:
nn, err = m.addIPs(host, set, ip6s)
if err != nil {
return n, err
}
default:
return n, fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name)
}
log.Debug("ipset: added %d ips to set %s", nn, set.name)
n += nn
}
return n, nil
}
// Add implements the IpsetManager interface for *ipsetMgr
func (m *ipsetMgr) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
sets := m.lookupHost(host)
if len(sets) == 0 {
return 0, nil
}
log.Debug("ipset: found %d sets", len(sets))
return m.addToSets(host, ip4s, ip6s, sets)
}
// Close implements the IpsetManager interface for *ipsetMgr.
func (m *ipsetMgr) Close() (err error) {
m.mu.Lock()
defer m.mu.Unlock()
var errs []error
// Close both and collect errors so that the errors from closing one
// don't interfere with closing the other.
err = m.ipv4Conn.Close()
if err != nil {
errs = append(errs, err)
}
err = m.ipv6Conn.Close()
if err != nil {
errs = append(errs, err)
}
return errors.Annotate(errors.Join(errs...), "closing ipsets: %w")
}

View File

@@ -1,154 +0,0 @@
//go:build linux
package aghnet
import (
"net"
"strings"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/digineo/go-ipset/v2"
"github.com/mdlayher/netlink"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ti-mo/netfilter"
)
// fakeIpsetConn is a fake ipsetConn for tests.
type fakeIpsetConn struct {
ipv4Header *ipset.HeaderPolicy
ipv4Entries *[]*ipset.Entry
ipv6Header *ipset.HeaderPolicy
ipv6Entries *[]*ipset.Entry
}
// Add implements the ipsetConn interface for *fakeIpsetConn.
func (c *fakeIpsetConn) Add(name string, entries ...*ipset.Entry) (err error) {
if strings.Contains(name, "ipv4") {
*c.ipv4Entries = append(*c.ipv4Entries, entries...)
return nil
} else if strings.Contains(name, "ipv6") {
*c.ipv6Entries = append(*c.ipv6Entries, entries...)
return nil
}
return errors.Error("test: ipset not found")
}
// Close implements the ipsetConn interface for *fakeIpsetConn.
func (c *fakeIpsetConn) Close() (err error) {
return nil
}
// Header implements the ipsetConn interface for *fakeIpsetConn.
func (c *fakeIpsetConn) Header(name string) (p *ipset.HeaderPolicy, err error) {
if strings.Contains(name, "ipv4") {
return c.ipv4Header, nil
} else if strings.Contains(name, "ipv6") {
return c.ipv6Header, nil
}
return nil, errors.Error("test: ipset not found")
}
func TestIpsetMgr_Add(t *testing.T) {
ipsetConf := []string{
"example.com,example.net/ipv4set",
"example.org,example.biz/ipv6set",
}
var ipv4Entries []*ipset.Entry
var ipv6Entries []*ipset.Entry
fakeDial := func(
pf netfilter.ProtoFamily,
conf *netlink.Config,
) (conn ipsetConn, err error) {
return &fakeIpsetConn{
ipv4Header: &ipset.HeaderPolicy{
Family: ipset.NewUInt8Box(uint8(netfilter.ProtoIPv4)),
},
ipv4Entries: &ipv4Entries,
ipv6Header: &ipset.HeaderPolicy{
Family: ipset.NewUInt8Box(uint8(netfilter.ProtoIPv6)),
},
ipv6Entries: &ipv6Entries,
}, nil
}
m, err := newIpsetMgrWithDialer(ipsetConf, fakeDial)
require.NoError(t, err)
ip4 := net.IP{1, 2, 3, 4}
ip6 := net.IP{
0x12, 0x34, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x56, 0x78,
}
n, err := m.Add("example.net", []net.IP{ip4}, nil)
require.NoError(t, err)
assert.Equal(t, 1, n)
require.Len(t, ipv4Entries, 1)
gotIP4 := ipv4Entries[0].IP.Value
assert.Equal(t, ip4, gotIP4)
n, err = m.Add("example.biz", nil, []net.IP{ip6})
require.NoError(t, err)
assert.Equal(t, 1, n)
require.Len(t, ipv6Entries, 1)
gotIP6 := ipv6Entries[0].IP.Value
assert.Equal(t, ip6, gotIP6)
err = m.Close()
assert.NoError(t, err)
}
var ipsetPropsSink []ipsetProps
func BenchmarkIpsetMgr_lookupHost(b *testing.B) {
propsLong := []ipsetProps{{
name: "example.com",
family: netfilter.ProtoIPv4,
}}
propsShort := []ipsetProps{{
name: "example.net",
family: netfilter.ProtoIPv4,
}}
m := &ipsetMgr{
domainToIpsets: map[string][]ipsetProps{
"": propsLong,
"example.net": propsShort,
},
}
b.Run("long", func(b *testing.B) {
const name = "a.very.long.domain.name.inside.the.domain.example.com"
for i := 0; i < b.N; i++ {
ipsetPropsSink = m.lookupHost(name)
}
require.Equal(b, propsLong, ipsetPropsSink)
})
b.Run("short", func(b *testing.B) {
const name = "example.net"
for i := 0; i < b.N; i++ {
ipsetPropsSink = m.lookupHost(name)
}
require.Equal(b, propsShort, ipsetPropsSink)
})
}

View File

@@ -1,11 +0,0 @@
//go:build !linux
package aghnet
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func newIpsetMgr(_ []string) (mgr IpsetManager, err error) {
return nil, aghos.Unsupported("ipset")
}

View File

@@ -9,6 +9,7 @@ import (
"io"
"net"
"net/netip"
"net/url"
"syscall"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
@@ -263,7 +264,7 @@ func IsAddrInUse(err error) (ok bool) {
// CollectAllIfacesAddrs returns the slice of all network interfaces IP
// addresses without port number.
func CollectAllIfacesAddrs() (addrs []string, err error) {
func CollectAllIfacesAddrs() (addrs []netip.Addr, err error) {
var ifaceAddrs []net.Addr
ifaceAddrs, err = netInterfaceAddrs()
if err != nil {
@@ -271,19 +272,41 @@ func CollectAllIfacesAddrs() (addrs []string, err error) {
}
for _, addr := range ifaceAddrs {
cidr := addr.String()
var ip net.IP
ip, _, err = net.ParseCIDR(cidr)
var p netip.Prefix
p, err = netip.ParsePrefix(addr.String())
if err != nil {
return nil, fmt.Errorf("parsing cidr: %w", err)
// Don't wrap the error since it's informative enough as is.
return nil, err
}
addrs = append(addrs, ip.String())
addrs = append(addrs, p.Addr())
}
return addrs, nil
}
// ParseAddrPort parses an [netip.AddrPort] from s, which should be either a
// valid IP, optionally with port, or a valid URL with plain IP address. The
// defaultPort is used if s doesn't contain port number.
func ParseAddrPort(s string, defaultPort uint16) (ipp netip.AddrPort, err error) {
u, err := url.Parse(s)
if err == nil && u.Host != "" {
s = u.Host
}
ipp, err = netip.ParseAddrPort(s)
if err != nil {
ip, parseErr := netip.ParseAddr(s)
if parseErr != nil {
return ipp, errors.Join(err, parseErr)
}
return netip.AddrPortFrom(ip, defaultPort), nil
}
return ipp, nil
}
// BroadcastFromPref calculates the broadcast IP address for p.
func BroadcastFromPref(p netip.Prefix) (bc netip.Addr) {
bc = p.Addr().Unmap()

View File

@@ -230,7 +230,7 @@ func TestCollectAllIfacesAddrs(t *testing.T) {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
wantAddrs []netip.Addr
}{{
name: "success",
wantErrMsg: ``,
@@ -241,10 +241,13 @@ func TestCollectAllIfacesAddrs(t *testing.T) {
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
wantAddrs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("4.3.2.1"),
},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
wantErrMsg: `netip.ParsePrefix("1.2.3.4"): no '/'`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
@@ -269,12 +272,11 @@ func TestCollectAllIfacesAddrs(t *testing.T) {
t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
assert.ErrorIs(t, err, errAddrs)
})
}

View File

@@ -2,10 +2,16 @@ package aghnet_test
import (
"io/fs"
"net"
"net/netip"
"net/url"
"os"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
)
func TestMain(m *testing.M) {
@@ -14,3 +20,76 @@ func TestMain(m *testing.M) {
// testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata")
func TestParseAddrPort(t *testing.T) {
const defaultPort = 1
v4addr := netip.MustParseAddr("1.2.3.4")
testCases := []struct {
name string
input string
wantErrMsg string
want netip.AddrPort
}{{
name: "success_ip",
input: v4addr.String(),
wantErrMsg: "",
want: netip.AddrPortFrom(v4addr, defaultPort),
}, {
name: "success_ip_port",
input: netutil.JoinHostPort(v4addr.String(), 5),
wantErrMsg: "",
want: netip.AddrPortFrom(v4addr, 5),
}, {
name: "success_url",
input: (&url.URL{
Scheme: "tcp",
Host: v4addr.String(),
}).String(),
wantErrMsg: "",
want: netip.AddrPortFrom(v4addr, defaultPort),
}, {
name: "success_url_port",
input: (&url.URL{
Scheme: "tcp",
Host: netutil.JoinHostPort(v4addr.String(), 5),
}).String(),
wantErrMsg: "",
want: netip.AddrPortFrom(v4addr, 5),
}, {
name: "error_invalid_ip",
input: "256.256.256.256",
wantErrMsg: `not an ip:port
ParseAddr("256.256.256.256"): IPv4 field has value >255`,
want: netip.AddrPort{},
}, {
name: "error_invalid_port",
input: net.JoinHostPort(v4addr.String(), "-5"),
wantErrMsg: `invalid port "-5" parsing "1.2.3.4:-5"
ParseAddr("1.2.3.4:-5"): unexpected character (at ":-5")`,
want: netip.AddrPort{},
}, {
name: "error_invalid_url",
input: "tcp:://1.2.3.4",
wantErrMsg: `invalid port "//1.2.3.4" parsing "tcp:://1.2.3.4"
ParseAddr("tcp:://1.2.3.4"): each colon-separated field must have at least ` +
`one digit (at "tcp:://1.2.3.4")`,
want: netip.AddrPort{},
}, {
name: "empty",
input: "",
want: netip.AddrPort{},
wantErrMsg: `not an ip:port
ParseAddr(""): unable to parse IP`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ap, err := aghnet.ParseAddrPort(tc.input, defaultPort)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, ap)
})
}
}

View File

@@ -1,36 +0,0 @@
package aghnet
// DefaultRefreshIvl is the default period of time between refreshing cached
// addresses.
// const DefaultRefreshIvl = 5 * time.Minute
// HostGenFunc is the signature for functions generating fake hostnames. The
// implementation must be safe for concurrent use.
type HostGenFunc func() (host string)
// SystemResolvers helps to work with local resolvers' addresses provided by OS.
type SystemResolvers interface {
// Get returns the slice of local resolvers' addresses. It must be safe for
// concurrent use.
Get() (rs []string)
// refresh refreshes the local resolvers' addresses cache. It must be safe
// for concurrent use.
refresh() (err error)
}
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
// defined by refreshIvl. It disables auto-refreshing if refreshIvl is 0. If
// nil is passed for hostGenFunc, the default generator will be used.
func NewSystemResolvers(
hostGenFunc HostGenFunc,
) (sr SystemResolvers, err error) {
sr = newSystemResolvers(hostGenFunc)
// Fill cache.
err = sr.refresh()
if err != nil {
return nil, err
}
return sr, nil
}

View File

@@ -1,146 +0,0 @@
//go:build !windows
package aghnet
import (
"context"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
)
// defaultHostGen is the default method of generating host for Refresh.
func defaultHostGen() (host string) {
// TODO(e.burkov): Use strings.Builder.
return fmt.Sprintf("test%d.org", time.Now().UnixNano())
}
// systemResolvers is a default implementation of SystemResolvers interface.
type systemResolvers struct {
// addrsLock protects addrs.
addrsLock sync.RWMutex
// addrs is the set that contains cached local resolvers' addresses.
addrs *stringutil.Set
// resolver is used to fetch the resolvers' addresses.
resolver *net.Resolver
// hostGenFunc generates hosts to resolve.
hostGenFunc HostGenFunc
}
const (
// errBadAddrPassed is returned when dialFunc can't parse an IP address.
errBadAddrPassed errors.Error = "the passed string is not a valid IP address"
// errFakeDial is an error which dialFunc is expected to return.
errFakeDial errors.Error = "this error signals the successful dialFunc work"
// errUnexpectedHostFormat is returned by validateDialedHost when the host has
// more than one percent sign.
errUnexpectedHostFormat errors.Error = "unexpected host format"
)
// refresh implements the SystemResolvers interface for *systemResolvers.
func (sr *systemResolvers) refresh() (err error) {
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
dnserr := &net.DNSError{}
if errors.As(err, &dnserr) && dnserr.Err == errFakeDial.Error() {
return nil
}
return err
}
func newSystemResolvers(hostGenFunc HostGenFunc) (sr SystemResolvers) {
if hostGenFunc == nil {
hostGenFunc = defaultHostGen
}
s := &systemResolvers{
resolver: &net.Resolver{
PreferGo: true,
},
hostGenFunc: hostGenFunc,
addrs: stringutil.NewSet(),
}
s.resolver.Dial = s.dialFunc
return s
}
// validateDialedHost validated the host used by resolvers in dialFunc.
func validateDialedHost(host string) (err error) {
defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()
parts := strings.Split(host, "%")
switch len(parts) {
case 1:
// host
case 2:
// Remove the zone and check the IP address part.
host = parts[0]
default:
return errUnexpectedHostFormat
}
if _, err = netutil.ParseIP(host); err != nil {
return errBadAddrPassed
}
return nil
}
// dockerEmbeddedDNS is the address of Docker's embedded DNS server.
//
// See
// https://github.com/moby/moby/blob/v1.12.0/docs/userguide/networking/dockernetworks.md.
const dockerEmbeddedDNS = "127.0.0.11"
// dialFunc gets the resolver's address and puts it into internal cache.
func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
// Just validate the passed address is a valid IP.
var host string
host, err = netutil.SplitHost(address)
if err != nil {
// TODO(e.burkov): Maybe use a structured errBadAddrPassed to
// allow unwrapping of the real error.
return nil, fmt.Errorf("%s: %w", err, errBadAddrPassed)
}
// Exclude Docker's embedded DNS server, as it may cause recursion if
// the container is set as the host system's default DNS server.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/3064.
//
// TODO(a.garipov): Perhaps only do this when we are in the container?
// Maybe use an environment variable?
if host == dockerEmbeddedDNS {
return nil, errFakeDial
}
err = validateDialedHost(host)
if err != nil {
return nil, fmt.Errorf("validating dialed host: %w", err)
}
sr.addrsLock.Lock()
defer sr.addrsLock.Unlock()
sr.addrs.Add(host)
return nil, errFakeDial
}
func (sr *systemResolvers) Get() (rs []string) {
sr.addrsLock.RLock()
defer sr.addrsLock.RUnlock()
return sr.addrs.Values()
}

View File

@@ -1,81 +0,0 @@
//go:build !windows
package aghnet
import (
"context"
"testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func createTestSystemResolversImpl(
t *testing.T,
hostGenFunc HostGenFunc,
) (imp *systemResolvers) {
t.Helper()
sr := createTestSystemResolvers(t, hostGenFunc)
return testutil.RequireTypeAssert[*systemResolvers](t, sr)
}
func TestSystemResolvers_Refresh(t *testing.T) {
t.Run("expected_error", func(t *testing.T) {
sr := createTestSystemResolvers(t, nil)
assert.NoError(t, sr.refresh())
})
t.Run("unexpected_error", func(t *testing.T) {
_, err := NewSystemResolvers(func() string {
return "127.0.0.1::123"
})
assert.Error(t, err)
})
}
func TestSystemResolvers_DialFunc(t *testing.T) {
imp := createTestSystemResolversImpl(t, nil)
testCases := []struct {
want error
name string
address string
}{{
want: errFakeDial,
name: "valid_ipv4",
address: "127.0.0.1",
}, {
want: errFakeDial,
name: "valid_ipv6_port",
address: "[::1]:53",
}, {
want: errFakeDial,
name: "valid_ipv6_zone_port",
address: "[::1%lo0]:53",
}, {
want: errBadAddrPassed,
name: "invalid_split_host",
address: "127.0.0.1::123",
}, {
want: errUnexpectedHostFormat,
name: "invalid_ipv6_zone_port",
address: "[::1%%lo0]:53",
}, {
want: errBadAddrPassed,
name: "invalid_parse_ip",
address: "not-ip",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conn, err := imp.dialFunc(context.Background(), "", tc.address)
require.Nil(t, conn)
assert.ErrorIs(t, err, tc.want)
})
}
}

View File

@@ -1,37 +0,0 @@
package aghnet
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func createTestSystemResolvers(
t *testing.T,
hostGenFunc HostGenFunc,
) (sr SystemResolvers) {
t.Helper()
var err error
sr, err = NewSystemResolvers(hostGenFunc)
require.NoError(t, err)
require.NotNil(t, sr)
return sr
}
func TestSystemResolvers_Get(t *testing.T) {
sr := createTestSystemResolvers(t, nil)
var rs []string
require.NotPanics(t, func() {
rs = sr.Get()
})
assert.NotEmpty(t, rs)
}
// TODO(e.burkov): Write tests for refreshWithTicker.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.

View File

@@ -1,163 +0,0 @@
//go:build windows
package aghnet
import (
"bufio"
"fmt"
"io"
"net"
"os/exec"
"strings"
"sync"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// systemResolvers implementation differs for Windows since Go's resolver
// doesn't work there.
//
// See https://github.com/golang/go/issues/33097.
type systemResolvers struct {
// addrs is the slice of cached local resolvers' addresses.
addrs []string
addrsLock sync.RWMutex
}
func newSystemResolvers(_ HostGenFunc) (sr SystemResolvers) {
return &systemResolvers{}
}
func (sr *systemResolvers) Get() (rs []string) {
sr.addrsLock.RLock()
defer sr.addrsLock.RUnlock()
addrs := sr.addrs
rs = make([]string, len(addrs))
copy(rs, addrs)
return rs
}
// writeExit writes "exit" to w and closes it. It is supposed to be run in
// a goroutine.
func writeExit(w io.WriteCloser) {
defer log.OnPanic("systemResolvers: writeExit")
defer func() {
derr := w.Close()
if derr != nil {
log.Error("systemResolvers: writeExit: closing: %s", derr)
}
}()
_, err := io.WriteString(w, "exit")
if err != nil {
log.Error("systemResolvers: writeExit: writing: %s", err)
}
}
// scanAddrs scans the DNS addresses from nslookup's output. The expected
// output of nslookup looks like this:
//
// Default Server: 192-168-1-1.qualified.domain.ru
// Address: 192.168.1.1
func scanAddrs(s *bufio.Scanner) (addrs []string) {
for s.Scan() {
line := strings.TrimSpace(s.Text())
fields := strings.Fields(line)
if len(fields) != 2 || fields[0] != "Address:" {
continue
}
// If the address contains port then it is separated with '#'.
ipPort := strings.Split(fields[1], "#")
if len(ipPort) == 0 {
continue
}
addr := ipPort[0]
if net.ParseIP(addr) == nil {
log.Debug("systemResolvers: %q is not a valid ip", addr)
continue
}
addrs = append(addrs, addr)
}
return addrs
}
// getAddrs gets local resolvers' addresses from OS in a special Windows way.
//
// TODO(e.burkov): This whole function needs more detailed research on getting
// local resolvers addresses on Windows. We execute the external command for
// now that is not the most accurate way.
func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
var cmdPath string
cmdPath, err = exec.LookPath("nslookup.exe")
if err != nil {
return nil, fmt.Errorf("looking up cmd path: %w", err)
}
cmd := exec.Command(cmdPath)
var stdin io.WriteCloser
stdin, err = cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("getting the command's stdin pipe: %w", err)
}
var stdout io.ReadCloser
stdout, err = cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("getting the command's stdout pipe: %w", err)
}
go writeExit(stdin)
err = cmd.Start()
if err != nil {
return nil, fmt.Errorf("start command executing: %w", err)
}
s := bufio.NewScanner(stdout)
addrs = scanAddrs(s)
err = cmd.Wait()
if err != nil {
return nil, fmt.Errorf("executing the command: %w", err)
}
err = s.Err()
if err != nil {
return nil, fmt.Errorf("scanning output: %w", err)
}
// Don't close StdoutPipe since Wait do it for us in ¿most? cases.
//
// See go doc os/exec.Cmd.StdoutPipe.
return addrs, nil
}
func (sr *systemResolvers) refresh() (err error) {
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
got, err := sr.getAddrs()
if err != nil {
return fmt.Errorf("can't get addresses: %w", err)
}
if len(got) == 0 {
return nil
}
sr.addrsLock.Lock()
defer sr.addrsLock.Unlock()
sr.addrs = got
return nil
}

View File

@@ -1,7 +0,0 @@
//go:build windows
package aghnet
// TODO(e.burkov): Write tests for Windows implementation.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.