all: sync with master

This commit is contained in:
Ainar Garipov
2022-11-23 16:52:05 +03:00
parent 67fe064fcf
commit bd64b8b014
58 changed files with 410 additions and 220 deletions

View File

@@ -129,7 +129,7 @@ type RuntimeClientWHOISInfo struct {
type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for
// different types (string, net.IP, and so on).
// different types (string, netip.Addr, and so on).
list map[string]*Client // name -> client
idIndex map[string]*Client // ID -> client
@@ -333,7 +333,7 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
}
// exists checks if client with this IP address already exists.
func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool) {
func (clients *clientsContainer) exists(ip netip.Addr, source clientSource) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -342,7 +342,7 @@ func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool
return true
}
rc, ok := clients.findRuntimeClientLocked(ip)
rc, ok := clients.ipToRC[ip]
if !ok {
return false
}
@@ -371,7 +371,8 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
var artClient *querylog.Client
var art bool
for _, id := range ids {
c, art = clients.clientOrArtificial(net.ParseIP(id), id)
ip, _ := netip.ParseAddr(id)
c, art = clients.clientOrArtificial(ip, id)
if art {
artClient = c
@@ -389,7 +390,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
// records about this client besides maybe whether or not it is blocked. c is
// never nil.
func (clients *clientsContainer) clientOrArtificial(
ip net.IP,
ip netip.Addr,
id string,
) (c *querylog.Client, art bool) {
defer func() {
@@ -406,13 +407,6 @@ func (clients *clientsContainer) clientOrArtificial(
}, false
}
if ip == nil {
// Technically should never happen, but still.
return &querylog.Client{
Name: "",
}, true
}
var rc *RuntimeClient
rc, ok = clients.findRuntimeClient(ip)
if ok {
@@ -492,19 +486,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
return c, true
}
ip := net.ParseIP(id)
if ip == nil {
ip, err := netip.ParseAddr(id)
if err != nil {
return nil, false
}
for _, c = range clients.list {
for _, id := range c.IDs {
_, ipnet, err := net.ParseCIDR(id)
var n netip.Prefix
n, err = netip.ParsePrefix(id)
if err != nil {
continue
}
if ipnet.Contains(ip) {
if n.Contains(ip) {
return c, true
}
}
@@ -514,19 +509,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
return nil, false
}
macFound := clients.dhcpServer.FindMACbyIP(ip)
macFound := clients.dhcpServer.FindMACbyIP(ip.AsSlice())
if macFound == nil {
return nil, false
}
for _, c = range clients.list {
for _, id := range c.IDs {
hwAddr, err := net.ParseMAC(id)
var mac net.HardwareAddr
mac, err = net.ParseMAC(id)
if err != nil {
continue
}
if bytes.Equal(hwAddr, macFound) {
if bytes.Equal(mac, macFound) {
return c, true
}
}
@@ -535,32 +531,18 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
return nil, false
}
// findRuntimeClientLocked finds a runtime client by their IP address. For
// internal use only.
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
log.Error("clients: bad client ip %v: %s", ip, err)
return nil, false
}
rc, ok = clients.ipToRC[ipAddr]
return rc, ok
}
// findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) {
if ip == nil {
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
if ip == (netip.Addr{}) {
return nil, false
}
clients.lock.Lock()
defer clients.lock.Unlock()
return clients.findRuntimeClientLocked(ip)
rc, ok = clients.ipToRC[ip]
return rc, ok
}
// check validates the client.
@@ -578,14 +560,16 @@ func (clients *clientsContainer) check(c *Client) (err error) {
for i, id := range c.IDs {
// Normalize structured data.
var ip net.IP
var ipnet *net.IPNet
var mac net.HardwareAddr
if ip = net.ParseIP(id); ip != nil {
var (
ip netip.Addr
n netip.Prefix
mac net.HardwareAddr
)
if ip, err = netip.ParseAddr(id); err == nil {
c.IDs[i] = ip.String()
} else if ip, ipnet, err = net.ParseCIDR(id); err == nil {
ipnet.IP = ip
c.IDs[i] = ipnet.String()
} else if n, err = netip.ParsePrefix(id); err == nil {
c.IDs[i] = n.String()
} else if mac, err = net.ParseMAC(id); err == nil {
c.IDs[i] = mac.String()
} else if err = dnsforward.ValidateClientID(id); err == nil {
@@ -750,7 +734,7 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
}
// setWHOISInfo sets the WHOIS information for a client.
func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) {
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *RuntimeClientWHOISInfo) {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -760,7 +744,7 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
return
}
rc, ok := clients.findRuntimeClientLocked(ip)
rc, ok := clients.ipToRC[ip]
if ok {
rc.WHOISInfo = wi
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
@@ -776,32 +760,22 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
rc.WHOISInfo = wi
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
log.Error("clients: bad client ip %v: %s", ip, err)
return
}
clients.ipToRC[ipAddr] = rc
clients.ipToRC[ip] = rc
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
}
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
// taken into account. ok is true if the pairing was added.
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
func (clients *clientsContainer) AddHost(
ip netip.Addr,
host string,
src clientSource,
) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
return false, fmt.Errorf("adding host: %w", err)
}
return clients.addHostLocked(ipAddr, host, src), nil
return clients.addHostLocked(ip, host, src)
}
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be

View File

@@ -22,8 +22,18 @@ func TestClients(t *testing.T) {
clients.Init(nil, nil, nil, nil)
t.Run("add_success", func(t *testing.T) {
var (
cliNone = "1.2.3.4"
cli1 = "1.1.1.1"
cli2 = "2.2.2.2"
cliNoneIP = netip.MustParseAddr(cliNone)
cli1IP = netip.MustParseAddr(cli1)
cli2IP = netip.MustParseAddr(cli2)
)
c := &Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
IDs: []string{cli1, "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
}
@@ -33,7 +43,7 @@ func TestClients(t *testing.T) {
assert.True(t, ok)
c = &Client{
IDs: []string{"2.2.2.2"},
IDs: []string{cli2},
Name: "client2",
}
@@ -42,7 +52,7 @@ func TestClients(t *testing.T) {
assert.True(t, ok)
c, ok = clients.Find("1.1.1.1")
c, ok = clients.Find(cli1)
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
@@ -52,14 +62,14 @@ func TestClients(t *testing.T) {
assert.Equal(t, "client1", c.Name)
c, ok = clients.Find("2.2.2.2")
c, ok = clients.Find(cli2)
require.True(t, ok)
assert.Equal(t, "client2", c.Name)
assert.False(t, clients.exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile))
assert.True(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
assert.True(t, clients.exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile))
assert.False(t, clients.exists(cliNoneIP, ClientSourceHostsFile))
assert.True(t, clients.exists(cli1IP, ClientSourceHostsFile))
assert.True(t, clients.exists(cli2IP, ClientSourceHostsFile))
})
t.Run("add_fail_name", func(t *testing.T) {
@@ -103,23 +113,31 @@ func TestClients(t *testing.T) {
})
t.Run("update_success", func(t *testing.T) {
var (
cliOld = "1.1.1.1"
cliNew = "1.1.1.2"
cliOldIP = netip.MustParseAddr(cliOld)
cliNewIP = netip.MustParseAddr(cliNew)
)
err := clients.Update("client1", &Client{
IDs: []string{"1.1.1.2"},
IDs: []string{cliNew},
Name: "client1",
})
require.NoError(t, err)
assert.False(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
assert.True(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
assert.False(t, clients.exists(cliOldIP, ClientSourceHostsFile))
assert.True(t, clients.exists(cliNewIP, ClientSourceHostsFile))
err = clients.Update("client1", &Client{
IDs: []string{"1.1.1.2"},
IDs: []string{cliNew},
Name: "client1-renamed",
UseOwnSettings: true,
})
require.NoError(t, err)
c, ok := clients.Find("1.1.1.2")
c, ok := clients.Find(cliNew)
require.True(t, ok)
assert.Equal(t, "client1-renamed", c.Name)
@@ -132,14 +150,14 @@ func TestClients(t *testing.T) {
require.Len(t, c.IDs, 1)
assert.Equal(t, "1.1.1.2", c.IDs[0])
assert.Equal(t, cliNew, c.IDs[0])
})
t.Run("del_success", func(t *testing.T) {
ok := clients.Del("client1-renamed")
require.True(t, ok)
assert.False(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
assert.False(t, clients.exists(netip.MustParseAddr("1.1.1.2"), ClientSourceHostsFile))
})
t.Run("del_fail", func(t *testing.T) {
@@ -148,45 +166,33 @@ func TestClients(t *testing.T) {
})
t.Run("addhost_success", func(t *testing.T) {
ip := net.IP{1, 1, 1, 1}
ok, err := clients.AddHost(ip, "host", ClientSourceARP)
require.NoError(t, err)
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.AddHost(ip, "host", ClientSourceARP)
assert.True(t, ok)
ok, err = clients.AddHost(ip, "host2", ClientSourceARP)
require.NoError(t, err)
ok = clients.AddHost(ip, "host2", ClientSourceARP)
assert.True(t, ok)
ok, err = clients.AddHost(ip, "host3", ClientSourceHostsFile)
require.NoError(t, err)
ok = clients.AddHost(ip, "host3", ClientSourceHostsFile)
assert.True(t, ok)
assert.True(t, clients.exists(ip, ClientSourceHostsFile))
})
t.Run("dhcp_replaces_arp", func(t *testing.T) {
ip := net.IP{1, 2, 3, 4}
ok, err := clients.AddHost(ip, "from_arp", ClientSourceARP)
require.NoError(t, err)
ip := netip.MustParseAddr("1.2.3.4")
ok := clients.AddHost(ip, "from_arp", ClientSourceARP)
assert.True(t, ok)
assert.True(t, clients.exists(ip, ClientSourceARP))
ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
require.NoError(t, err)
ok = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
assert.True(t, ok)
assert.True(t, clients.exists(ip, ClientSourceDHCP))
})
t.Run("addhost_fail", func(t *testing.T) {
ok, err := clients.AddHost(net.IP{1, 1, 1, 1}, "host1", ClientSourceRDNS)
require.NoError(t, err)
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.AddHost(ip, "host1", ClientSourceRDNS)
assert.False(t, ok)
})
}
@@ -203,7 +209,7 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("new_client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.255")
clients.setWHOISInfo(ip.AsSlice(), whois)
clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
require.NotNil(t, rc)
@@ -212,12 +218,10 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("existing_auto-client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok, err := clients.AddHost(ip.AsSlice(), "host", ClientSourceRDNS)
require.NoError(t, err)
ok := clients.AddHost(ip, "host", ClientSourceRDNS)
assert.True(t, ok)
clients.setWHOISInfo(ip.AsSlice(), whois)
clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
require.NotNil(t, rc)
@@ -234,7 +238,7 @@ func TestClientsWHOIS(t *testing.T) {
require.NoError(t, err)
assert.True(t, ok)
clients.setWHOISInfo(ip.AsSlice(), whois)
clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
require.Nil(t, rc)
@@ -249,7 +253,7 @@ func TestClientsAddExisting(t *testing.T) {
clients.Init(nil, nil, nil, nil)
t.Run("simple", func(t *testing.T) {
ip := net.IP{1, 1, 1, 1}
ip := netip.MustParseAddr("1.1.1.1")
// Add a client.
ok, err := clients.Add(&Client{
@@ -260,8 +264,7 @@ func TestClientsAddExisting(t *testing.T) {
assert.True(t, ok)
// Now add an auto-client with the same IP.
ok, err = clients.AddHost(ip, "test", ClientSourceRDNS)
require.NoError(t, err)
ok = clients.AddHost(ip, "test", ClientSourceRDNS)
assert.True(t, ok)
})

View File

@@ -3,8 +3,8 @@ package home
import (
"encoding/json"
"fmt"
"net"
"net/http"
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
)
@@ -47,8 +47,8 @@ type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
Name string `json:"name"`
IP netip.Addr `json:"ip"`
Source clientSource `json:"source"`
IP net.IP `json:"ip"`
}
type clientListJSON struct {
@@ -75,7 +75,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
Name: rc.Host,
Source: rc.Source,
IP: ip.AsSlice(),
IP: ip,
}
data.RuntimeClients = append(data.RuntimeClients, cj)
@@ -218,7 +218,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
break
}
ip := net.ParseIP(idStr)
ip, _ := netip.ParseAddr(idStr)
c, ok := clients.Find(idStr)
var cj *clientJSON
if !ok {
@@ -240,7 +240,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
// findRuntime looks up the IP in runtime and temporary storages, like
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
// non-nil.
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
rc, ok := clients.findRuntimeClient(ip)
if !ok {
// It is still possible that the IP used to be in the runtime clients

View File

@@ -71,9 +71,7 @@ func appendDNSAddrsWithIfaces(dst []string, src []netip.Addr) (res []string, err
// on, including the addresses on all interfaces in cases of unspecified IPs.
func collectDNSAddresses() (addrs []string, err error) {
if hosts := config.DNS.BindHosts; len(hosts) == 0 {
addr := aghnet.IPv4Localhost()
addrs = appendDNSAddrs(addrs, addr)
addrs = appendDNSAddrs(addrs, netutil.IPv4Localhost())
} else {
addrs, err = appendDNSAddrsWithIfaces(addrs, hosts)
if err != nil {

View File

@@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -150,8 +151,8 @@ func isRunning() bool {
}
func onDNSRequest(pctx *proxy.DNSContext) {
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
if ip == nil {
ip := netutil.NetAddrToAddrPort(pctx.Addr).Addr()
if ip == (netip.Addr{}) {
// This would be quite weird if we get here.
return
}
@@ -160,7 +161,8 @@ func onDNSRequest(pctx *proxy.DNSContext) {
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
Context.whois.Begin(ip)
}
}
@@ -193,11 +195,7 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
dnsConf := config.DNS
hosts := dnsConf.BindHosts
if len(hosts) == 0 {
hosts = []netip.Addr{aghnet.IPv4Localhost()}
}
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
newConf = dnsforward.ServerConfig{
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
@@ -400,15 +398,12 @@ func startDNSServer() error {
const topClientsNumber = 100 // the number of clients to get
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
if ip == nil {
continue
}
srcs := config.Clients.Sources
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
Context.whois.Begin(ip)
}
}

View File

@@ -542,6 +542,11 @@ func run(opts options, clientBuildFS fs.FS) {
}
}
// TODO(a.garipov): This could be made much earlier and could be done on
// the first run as well, but to achieve this we need to bypass requests
// over dnsforward resolver.
cmdlineUpdate(opts)
Context.web.Start()
// wait indefinitely for other go-routines to complete their job
@@ -576,7 +581,7 @@ func checkPermissions() {
}
// We should check if AdGuard Home is able to bind to port 53
err := aghnet.CheckPort("tcp", netip.AddrPortFrom(aghnet.IPv4Localhost(), defaultPortDNS))
err := aghnet.CheckPort("tcp", netip.AddrPortFrom(netutil.IPv4Localhost(), defaultPortDNS))
if err != nil {
if errors.Is(err, os.ErrPermission) {
log.Fatal(`Permission check failed.
@@ -927,3 +932,37 @@ type jsonError struct {
// Message is the error message, an opaque string.
Message string `json:"message"`
}
// cmdlineUpdate updates current application and exits.
func cmdlineUpdate(opts options) {
if !opts.performUpdate {
return
}
log.Info("starting update")
if Context.firstRun {
log.Info("update not allowed on first run")
os.Exit(0)
}
_, err := Context.updater.VersionInfo(true)
if err != nil {
vcu := Context.updater.VersionCheckURL()
log.Error("getting version info from %s: %s", vcu, err)
os.Exit(0)
}
if Context.updater.NewVersion() == "" {
log.Info("no updates available")
os.Exit(0)
}
err = Context.updater.Update()
fatalOnError(err)
os.Exit(0)
}

View File

@@ -47,6 +47,9 @@ type options struct {
// disableUpdate, if set, makes AdGuard Home not check for updates.
disableUpdate bool
// performUpdate, if set, updates AdGuard Home without GUI and exits.
performUpdate bool
// verbose shows if verbose logging is enabled.
verbose bool
@@ -221,6 +224,14 @@ var cmdLineOpts = []cmdLineOpt{{
description: "Don't check for updates.",
longName: "no-check-update",
shortName: "",
}, {
updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.performUpdate = true; return o, nil },
effect: nil,
serialize: func(o options) (val string, ok bool) { return "", o.performUpdate },
description: "Update application and exit.",
longName: "update",
shortName: "",
}, {
updateWithValue: nil,
updateNoValue: nil,

View File

@@ -103,6 +103,11 @@ func TestParseDisableUpdate(t *testing.T) {
assert.True(t, testParseOK(t, "--no-check-update").disableUpdate, "--no-check-update is disable update")
}
func TestParsePerformUpdate(t *testing.T) {
assert.False(t, testParseOK(t).performUpdate, "empty is not perform update")
assert.True(t, testParseOK(t, "--update").performUpdate, "--update is perform update")
}
// TODO(e.burkov): Remove after v0.108.0.
func TestParseDisableMemoryOptimization(t *testing.T) {
o, eff, err := parseCmdOpts("", []string{"--no-mem-optimization"})
@@ -169,6 +174,10 @@ func TestOptsToArgs(t *testing.T) {
name: "disable_update",
args: []string{"--no-check-update"},
opts: options{disableUpdate: true},
}, {
name: "perform_update",
args: []string{"--update"},
opts: options{performUpdate: true},
}, {
name: "control_action",
args: []string{"-s", "run"},

View File

@@ -2,7 +2,7 @@ package home
import (
"encoding/binary"
"net"
"net/netip"
"sync/atomic"
"time"
@@ -21,7 +21,7 @@ type RDNS struct {
usePrivate uint32
// ipCh used to pass client's IP to rDNS workerLoop.
ipCh chan net.IP
ipCh chan netip.Addr
// ipCache caches the IP addresses to be resolved by rDNS. The resolved
// address stays here while it's inside clients. After leaving clients the
@@ -50,7 +50,7 @@ func NewRDNS(
EnableLRU: true,
MaxCount: defaultRDNSCacheSize,
}),
ipCh: make(chan net.IP, defaultRDNSIPChSize),
ipCh: make(chan netip.Addr, defaultRDNSIPChSize),
}
if usePrivate {
rDNS.usePrivate = 1
@@ -80,9 +80,10 @@ func (r *RDNS) ensurePrivateCache() {
// isCached returns true if ip is already cached and not expired yet. It also
// caches it otherwise.
func (r *RDNS) isCached(ip net.IP) (ok bool) {
func (r *RDNS) isCached(ip netip.Addr) (ok bool) {
ipBytes := ip.AsSlice()
now := uint64(time.Now().Unix())
if expire := r.ipCache.Get(ip); len(expire) != 0 {
if expire := r.ipCache.Get(ipBytes); len(expire) != 0 {
if binary.BigEndian.Uint64(expire) > now {
return true
}
@@ -91,14 +92,14 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) {
// The cache entry either expired or doesn't exist.
ttl := make([]byte, 8)
binary.BigEndian.PutUint64(ttl, now+defaultRDNSCacheTTL)
r.ipCache.Set(ip, ttl)
r.ipCache.Set(ipBytes, ttl)
return false
}
// Begin adds the ip to the resolving queue if it is not cached or already
// resolved.
func (r *RDNS) Begin(ip net.IP) {
func (r *RDNS) Begin(ip netip.Addr) {
r.ensurePrivateCache()
if r.isCached(ip) || r.clients.exists(ip, ClientSourceRDNS) {
@@ -107,9 +108,9 @@ func (r *RDNS) Begin(ip net.IP) {
select {
case r.ipCh <- ip:
log.Tracef("rdns: %q added to queue", ip)
log.Debug("rdns: %q added to queue", ip)
default:
log.Tracef("rdns: queue is full")
log.Debug("rdns: queue is full")
}
}
@@ -119,7 +120,7 @@ func (r *RDNS) workerLoop() {
defer log.OnPanic("rdns")
for ip := range r.ipCh {
host, err := r.exchanger.Exchange(ip)
host, err := r.exchanger.Exchange(ip.AsSlice())
if err != nil {
log.Debug("rdns: resolving %q: %s", ip, err)
@@ -128,8 +129,6 @@ func (r *RDNS) workerLoop() {
continue
}
// Don't handle any errors since AddHost doesn't return non-nil errors
// for now.
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
_ = r.clients.AddHost(ip, host, ClientSourceRDNS)
}
}

View File

@@ -27,14 +27,14 @@ func TestRDNS_Begin(t *testing.T) {
w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w)
ip1234, ip1235 := net.IP{1, 2, 3, 4}, net.IP{1, 2, 3, 5}
ip1234, ip1235 := netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5")
testCases := []struct {
cliIDIndex map[string]*Client
customChan chan net.IP
customChan chan netip.Addr
name string
wantLog string
req net.IP
ip netip.Addr
wantCacheHit int
wantCacheMiss int
}{{
@@ -42,7 +42,7 @@ func TestRDNS_Begin(t *testing.T) {
customChan: nil,
name: "cached",
wantLog: "",
req: ip1234,
ip: ip1234,
wantCacheHit: 1,
wantCacheMiss: 0,
}, {
@@ -50,7 +50,7 @@ func TestRDNS_Begin(t *testing.T) {
customChan: nil,
name: "not_cached",
wantLog: "rdns: queue is full",
req: ip1235,
ip: ip1235,
wantCacheHit: 0,
wantCacheMiss: 1,
}, {
@@ -58,15 +58,15 @@ func TestRDNS_Begin(t *testing.T) {
customChan: nil,
name: "already_in_clients",
wantLog: "",
req: ip1235,
ip: ip1235,
wantCacheHit: 0,
wantCacheMiss: 1,
}, {
cliIDIndex: map[string]*Client{},
customChan: make(chan net.IP, 1),
customChan: make(chan netip.Addr, 1),
name: "add_to_queue",
wantLog: `rdns: "1.2.3.5" added to queue`,
req: ip1235,
ip: ip1235,
wantCacheHit: 0,
wantCacheMiss: 1,
}}
@@ -102,7 +102,7 @@ func TestRDNS_Begin(t *testing.T) {
}
t.Run(tc.name, func(t *testing.T) {
rdns.Begin(tc.req)
rdns.Begin(tc.ip)
assert.Equal(t, tc.wantCacheHit, ipCache.Stats().Hit)
assert.Equal(t, tc.wantCacheMiss, ipCache.Stats().Miss)
assert.Contains(t, w.String(), tc.wantLog)
@@ -179,8 +179,8 @@ func TestRDNS_WorkerLoop(t *testing.T) {
w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w)
localIP := net.IP{192, 168, 1, 1}
revIPv4, err := netutil.IPToReversedAddr(localIP)
localIP := netip.MustParseAddr("192.168.1.1")
revIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
require.NoError(t, err)
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
@@ -201,24 +201,24 @@ func TestRDNS_WorkerLoop(t *testing.T) {
testCases := []struct {
ups upstream.Upstream
cliIP netip.Addr
wantLog string
name string
cliIP net.IP
}{{
ups: locUpstream,
cliIP: localIP,
wantLog: "",
name: "all_good",
cliIP: localIP,
}, {
ups: errUpstream,
cliIP: netip.MustParseAddr("192.168.1.2"),
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
name: "resolve_error",
cliIP: net.IP{192, 168, 1, 2},
}, {
ups: locUpstream,
cliIP: netip.MustParseAddr("2a00:1450:400c:c06::93"),
wantLog: "",
name: "ipv6_good",
cliIP: net.ParseIP("2a00:1450:400c:c06::93"),
}}
for _, tc := range testCases {
@@ -230,7 +230,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
ipToRC: map[netip.Addr]*RuntimeClient{},
allTags: stringutil.NewSet(),
}
ch := make(chan net.IP)
ch := make(chan netip.Addr)
rdns := &RDNS{
exchanger: &rDNSExchanger{
ex: tc.ups,

View File

@@ -513,6 +513,11 @@ func validateCertChain(certs []*x509.Certificate, srvName string) (err error) {
return nil
}
// errNoIPInCert is the error that is returned from [parseCertChain] if the leaf
// certificate doesn't contain IPs.
const errNoIPInCert errors.Error = `certificates has no IP addresses; ` +
`DNS-over-TLS won't be advertised via DDR`
// parseCertChain parses the certificate chain from raw data, and returns it.
// If ok is true, the returned error, if any, is not critical.
func parseCertChain(chain []byte) (parsedCerts []*x509.Certificate, ok bool, err error) {
@@ -535,8 +540,7 @@ func parseCertChain(chain []byte) (parsedCerts []*x509.Certificate, ok bool, err
log.Info("tls: number of certs: %d", len(parsedCerts))
if !aghtls.CertificateHasIP(parsedCerts[0]) {
err = errors.Error(`certificate has no IP addresses` +
`, this may cause issues with DNS-over-TLS clients`)
err = errNoIPInCert
}
return parsedCerts, true, err

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"strings"
"time"
@@ -26,7 +27,7 @@ const (
// WHOIS - module context
type WHOIS struct {
clients *clientsContainer
ipChan chan net.IP
ipChan chan netip.Addr
// dialContext specifies the dial function for creating unencrypted TCP
// connections.
@@ -51,7 +52,7 @@ func initWHOIS(clients *clientsContainer) *WHOIS {
MaxCount: 10000,
}),
dialContext: customDialContext,
ipChan: make(chan net.IP, 255),
ipChan: make(chan netip.Addr, 255),
}
go w.workerLoop()
@@ -192,7 +193,7 @@ func (w *WHOIS) queryAll(ctx context.Context, target string) (string, error) {
}
// Request WHOIS information
func (w *WHOIS) process(ctx context.Context, ip net.IP) (wi *RuntimeClientWHOISInfo) {
func (w *WHOIS) process(ctx context.Context, ip netip.Addr) (wi *RuntimeClientWHOISInfo) {
resp, err := w.queryAll(ctx, ip.String())
if err != nil {
log.Debug("whois: error: %s IP:%s", err, ip)
@@ -220,24 +221,25 @@ func (w *WHOIS) process(ctx context.Context, ip net.IP) (wi *RuntimeClientWHOISI
}
// Begin - begin requesting WHOIS info
func (w *WHOIS) Begin(ip net.IP) {
func (w *WHOIS) Begin(ip netip.Addr) {
ipBytes := ip.AsSlice()
now := uint64(time.Now().Unix())
expire := w.ipAddrs.Get([]byte(ip))
expire := w.ipAddrs.Get(ipBytes)
if len(expire) != 0 {
exp := binary.BigEndian.Uint64(expire)
if exp > now {
return
}
// TTL expired
}
expire = make([]byte, 8)
binary.BigEndian.PutUint64(expire, now+whoisTTL)
_ = w.ipAddrs.Set([]byte(ip), expire)
_ = w.ipAddrs.Set(ipBytes, expire)
log.Debug("whois: adding %s", ip)
select {
case w.ipChan <- ip:
//
default:
log.Debug("whois: queue is full")
}