all: sync with master; upd chlog
This commit is contained in:
@@ -30,32 +30,30 @@ import (
|
||||
const dataDir = "data"
|
||||
|
||||
// logSettings are the logging settings part of the configuration file.
|
||||
//
|
||||
// TODO(a.garipov): Put them into a separate object.
|
||||
type logSettings struct {
|
||||
// File is the path to the log file. If empty, logs are written to stdout.
|
||||
// If "syslog", logs are written to syslog.
|
||||
File string `yaml:"log_file"`
|
||||
File string `yaml:"file"`
|
||||
|
||||
// MaxBackups is the maximum number of old log files to retain.
|
||||
//
|
||||
// NOTE: MaxAge may still cause them to get deleted.
|
||||
MaxBackups int `yaml:"log_max_backups"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
|
||||
// MaxSize is the maximum size of the log file before it gets rotated, in
|
||||
// megabytes. The default value is 100 MB.
|
||||
MaxSize int `yaml:"log_max_size"`
|
||||
MaxSize int `yaml:"max_size"`
|
||||
|
||||
// MaxAge is the maximum duration for retaining old log files, in days.
|
||||
MaxAge int `yaml:"log_max_age"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
|
||||
// Compress determines, if the rotated log files should be compressed using
|
||||
// gzip.
|
||||
Compress bool `yaml:"log_compress"`
|
||||
Compress bool `yaml:"compress"`
|
||||
|
||||
// LocalTime determines, if the time used for formatting the timestamps in
|
||||
// is the computer's local time.
|
||||
LocalTime bool `yaml:"log_localtime"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
|
||||
// Verbose determines, if verbose (aka debug) logging is enabled.
|
||||
Verbose bool `yaml:"verbose"`
|
||||
@@ -142,7 +140,8 @@ type configuration struct {
|
||||
// Keep this field sorted to ensure consistent ordering.
|
||||
Clients *clientsConfig `yaml:"clients"`
|
||||
|
||||
logSettings `yaml:",inline"`
|
||||
// Log is a block with log configuration settings.
|
||||
Log logSettings `yaml:"log"`
|
||||
|
||||
OSConfig *osConfig `yaml:"os"`
|
||||
|
||||
@@ -241,6 +240,7 @@ type tlsConfigSettings struct {
|
||||
|
||||
type queryLogConfig struct {
|
||||
// Ignored is the list of host names, which should not be written to log.
|
||||
// "." is considered to be the root domain.
|
||||
Ignored []string `yaml:"ignored"`
|
||||
|
||||
// Interval is the interval for query log's files rotation.
|
||||
@@ -390,7 +390,7 @@ var config = &configuration{
|
||||
HostsFile: true,
|
||||
},
|
||||
},
|
||||
logSettings: logSettings{
|
||||
Log: logSettings{
|
||||
Compress: false,
|
||||
LocalTime: false,
|
||||
MaxBackups: 0,
|
||||
@@ -421,19 +421,19 @@ func (c *configuration) getConfigFilename() string {
|
||||
// separate method in order to configure logger before the actual configuration
|
||||
// is parsed and applied.
|
||||
func readLogSettings() (ls *logSettings) {
|
||||
ls = &logSettings{}
|
||||
conf := &configuration{}
|
||||
|
||||
yamlFile, err := readConfigFile()
|
||||
if err != nil {
|
||||
return ls
|
||||
return &logSettings{}
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(yamlFile, ls)
|
||||
err = yaml.Unmarshal(yamlFile, conf)
|
||||
if err != nil {
|
||||
log.Error("Couldn't get logging settings from the configuration: %s", err)
|
||||
}
|
||||
|
||||
return ls
|
||||
return &conf.Log
|
||||
}
|
||||
|
||||
// validateBindHosts returns error if any of binding hosts from configuration is
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -167,30 +168,77 @@ func initDNSServer(
|
||||
return fmt.Errorf("dnsServer.Prepare: %w", err)
|
||||
}
|
||||
|
||||
if config.Clients.Sources.RDNS {
|
||||
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
|
||||
}
|
||||
|
||||
initRDNS()
|
||||
initWHOIS()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
// defaultQueueSize is the size of queue of IPs for rDNS and WHOIS
|
||||
// processing.
|
||||
defaultQueueSize = 255
|
||||
|
||||
// defaultCacheSize is the maximum size of the cache for rDNS and WHOIS
|
||||
// processing. It must be greater than zero.
|
||||
defaultCacheSize = 10_000
|
||||
|
||||
// defaultIPTTL is the Time to Live duration for IP addresses cached by
|
||||
// rDNS and WHOIS.
|
||||
defaultIPTTL = 1 * time.Hour
|
||||
)
|
||||
|
||||
// initRDNS initializes the rDNS.
|
||||
func initRDNS() {
|
||||
Context.rdnsCh = make(chan netip.Addr, defaultQueueSize)
|
||||
|
||||
// TODO(s.chzhen): Add ability to disable it on dns server configuration
|
||||
// update in [dnsforward] package.
|
||||
r := rdns.New(&rdns.Config{
|
||||
Exchanger: Context.dnsServer,
|
||||
CacheSize: defaultCacheSize,
|
||||
CacheTTL: defaultIPTTL,
|
||||
})
|
||||
|
||||
go processRDNS(r)
|
||||
}
|
||||
|
||||
// processRDNS processes reverse DNS lookup queries. It is intended to be used
|
||||
// as a goroutine.
|
||||
func processRDNS(r rdns.Interface) {
|
||||
defer log.OnPanic("rdns")
|
||||
|
||||
for ip := range Context.rdnsCh {
|
||||
ok := Context.dnsServer.ShouldResolveClient(ip)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
host, changed := r.Process(ip)
|
||||
if host == "" || !changed {
|
||||
continue
|
||||
}
|
||||
|
||||
ok = Context.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug(
|
||||
"dns: can't set rdns info for client %q: already set with higher priority source",
|
||||
ip,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// initWHOIS initializes the WHOIS.
|
||||
//
|
||||
// TODO(s.chzhen): Consider making configurable.
|
||||
func initWHOIS() {
|
||||
const (
|
||||
// defaultQueueSize is the size of queue of IPs for WHOIS processing.
|
||||
defaultQueueSize = 255
|
||||
|
||||
// defaultTimeout is the timeout for WHOIS requests.
|
||||
defaultTimeout = 5 * time.Second
|
||||
|
||||
// defaultCacheSize is the maximum size of the cache. If it's zero,
|
||||
// cache size is unlimited.
|
||||
defaultCacheSize = 10_000
|
||||
|
||||
// defaultMaxConnReadSize is an upper limit in bytes for reading from
|
||||
// net.Conn.
|
||||
defaultMaxConnReadSize = 64 * 1024
|
||||
@@ -200,9 +248,6 @@ func initWHOIS() {
|
||||
|
||||
// defaultMaxInfoLen is the maximum length of whois.Info fields.
|
||||
defaultMaxInfoLen = 250
|
||||
|
||||
// defaultIPTTL is the Time to Live duration for cached IP addresses.
|
||||
defaultIPTTL = 1 * time.Hour
|
||||
)
|
||||
|
||||
Context.whoisCh = make(chan netip.Addr, defaultQueueSize)
|
||||
@@ -274,11 +319,7 @@ func onDNSRequest(pctx *proxy.DNSContext) {
|
||||
return
|
||||
}
|
||||
|
||||
srcs := config.Clients.Sources
|
||||
if srcs.RDNS && !ip.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
|
||||
Context.rdnsCh <- ip
|
||||
Context.whoisCh <- ip
|
||||
}
|
||||
|
||||
@@ -517,11 +558,7 @@ func startDNSServer() error {
|
||||
|
||||
const topClientsNumber = 100 // the number of clients to get
|
||||
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
|
||||
srcs := config.Clients.Sources
|
||||
if srcs.RDNS && !ip.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
|
||||
Context.rdnsCh <- ip
|
||||
Context.whoisCh <- ip
|
||||
}
|
||||
|
||||
|
||||
@@ -56,7 +56,6 @@ type homeContext struct {
|
||||
stats stats.Interface // statistics module
|
||||
queryLog querylog.QueryLog // query log module
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
dhcpServer dhcpd.Interface // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters *filtering.DNSFilter // DNS filtering module
|
||||
@@ -83,6 +82,9 @@ type homeContext struct {
|
||||
client *http.Client
|
||||
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
|
||||
|
||||
// rdnsCh is the channel for receiving IPs for rDNS processing.
|
||||
rdnsCh chan netip.Addr
|
||||
|
||||
// whoisCh is the channel for receiving IPs for WHOIS processing.
|
||||
whoisCh chan netip.Addr
|
||||
|
||||
@@ -468,7 +470,7 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||
ServiceName: pcService,
|
||||
TXTSuffix: pcTXTSuffix,
|
||||
CacheTime: cacheTime,
|
||||
CacheSize: conf.SafeBrowsingCacheSize,
|
||||
CacheSize: conf.ParentalCacheSize,
|
||||
})
|
||||
|
||||
conf.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
@@ -829,20 +831,21 @@ func configureLogger(opts options) (err error) {
|
||||
// getLogSettings returns a log settings object properly initialized from opts.
|
||||
func getLogSettings(opts options) (ls *logSettings) {
|
||||
ls = readLogSettings()
|
||||
configLogSettings := config.Log
|
||||
|
||||
// Command-line arguments can override config settings.
|
||||
if opts.verbose || config.Verbose {
|
||||
if opts.verbose || configLogSettings.Verbose {
|
||||
ls.Verbose = true
|
||||
}
|
||||
|
||||
ls.File = stringutil.Coalesce(opts.logFile, config.File, ls.File)
|
||||
ls.File = stringutil.Coalesce(opts.logFile, configLogSettings.File, ls.File)
|
||||
|
||||
// Handle default log settings overrides.
|
||||
ls.Compress = config.Compress
|
||||
ls.LocalTime = config.LocalTime
|
||||
ls.MaxBackups = config.MaxBackups
|
||||
ls.MaxSize = config.MaxSize
|
||||
ls.MaxAge = config.MaxAge
|
||||
ls.Compress = configLogSettings.Compress
|
||||
ls.LocalTime = configLogSettings.LocalTime
|
||||
ls.MaxBackups = configLogSettings.MaxBackups
|
||||
ls.MaxSize = configLogSettings.MaxSize
|
||||
ls.MaxAge = configLogSettings.MaxAge
|
||||
|
||||
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// RDNS resolves clients' addresses to enrich their metadata.
|
||||
type RDNS struct {
|
||||
exchanger dnsforward.RDNSExchanger
|
||||
clients *clientsContainer
|
||||
|
||||
// ipCh used to pass client's IP to rDNS workerLoop.
|
||||
ipCh chan netip.Addr
|
||||
|
||||
// ipCache caches the IP addresses to be resolved by rDNS. The resolved
|
||||
// address stays here while it's inside clients. After leaving clients the
|
||||
// address will be resolved once again. If the address couldn't be
|
||||
// resolved, cache prevents further attempts to resolve it for some time.
|
||||
ipCache cache.Cache
|
||||
|
||||
// usePrivate stores the state of current private reverse-DNS resolving
|
||||
// settings.
|
||||
usePrivate atomic.Bool
|
||||
}
|
||||
|
||||
// Default AdGuard Home reverse DNS values.
|
||||
const (
|
||||
revDNSCacheSize = 10000
|
||||
|
||||
// TODO(e.burkov): Make these values configurable.
|
||||
revDNSCacheTTL = 24 * 60 * 60
|
||||
revDNSFailureCacheTTL = 1 * 60 * 60
|
||||
|
||||
revDNSQueueSize = 256
|
||||
)
|
||||
|
||||
// NewRDNS creates and returns initialized RDNS.
|
||||
func NewRDNS(
|
||||
exchanger dnsforward.RDNSExchanger,
|
||||
clients *clientsContainer,
|
||||
usePrivate bool,
|
||||
) (rDNS *RDNS) {
|
||||
rDNS = &RDNS{
|
||||
exchanger: exchanger,
|
||||
clients: clients,
|
||||
ipCache: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: revDNSCacheSize,
|
||||
}),
|
||||
ipCh: make(chan netip.Addr, revDNSQueueSize),
|
||||
}
|
||||
|
||||
rDNS.usePrivate.Store(usePrivate)
|
||||
|
||||
go rDNS.workerLoop()
|
||||
|
||||
return rDNS
|
||||
}
|
||||
|
||||
// ensurePrivateCache ensures that the state of the RDNS cache is consistent
|
||||
// with the current private client RDNS resolving settings.
|
||||
//
|
||||
// TODO(e.burkov): Clearing cache each time this value changed is not a perfect
|
||||
// approach since only unresolved locally-served addresses should be removed.
|
||||
// Implement when improving the cache.
|
||||
func (r *RDNS) ensurePrivateCache() {
|
||||
usePrivate := r.exchanger.ResolvesPrivatePTR()
|
||||
if r.usePrivate.CompareAndSwap(!usePrivate, usePrivate) {
|
||||
r.ipCache.Clear()
|
||||
}
|
||||
}
|
||||
|
||||
// isCached returns true if ip is already cached and not expired yet. It also
|
||||
// caches it otherwise.
|
||||
func (r *RDNS) isCached(ip netip.Addr) (ok bool) {
|
||||
ipBytes := ip.AsSlice()
|
||||
now := uint64(time.Now().Unix())
|
||||
if expire := r.ipCache.Get(ipBytes); len(expire) != 0 {
|
||||
return binary.BigEndian.Uint64(expire) > now
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// cache caches the ip address for ttl seconds.
|
||||
func (r *RDNS) cache(ip netip.Addr, ttl uint64) {
|
||||
ipData := ip.AsSlice()
|
||||
|
||||
ttlData := [8]byte{}
|
||||
binary.BigEndian.PutUint64(ttlData[:], uint64(time.Now().Unix())+ttl)
|
||||
|
||||
r.ipCache.Set(ipData, ttlData[:])
|
||||
}
|
||||
|
||||
// Begin adds the ip to the resolving queue if it is not cached or already
|
||||
// resolved.
|
||||
func (r *RDNS) Begin(ip netip.Addr) {
|
||||
r.ensurePrivateCache()
|
||||
|
||||
if r.isCached(ip) || r.clients.clientSource(ip) > ClientSourceRDNS {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case r.ipCh <- ip:
|
||||
log.Debug("rdns: %q added to queue", ip)
|
||||
default:
|
||||
log.Debug("rdns: queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// workerLoop handles incoming IP addresses from ipChan and adds it into
|
||||
// clients.
|
||||
func (r *RDNS) workerLoop() {
|
||||
defer log.OnPanic("rdns")
|
||||
|
||||
for ip := range r.ipCh {
|
||||
ttl := uint64(revDNSCacheTTL)
|
||||
|
||||
host, err := r.exchanger.Exchange(ip.AsSlice())
|
||||
if err != nil {
|
||||
log.Debug("rdns: resolving %q: %s", ip, err)
|
||||
if errors.Is(err, dnsforward.ErrRDNSFailed) {
|
||||
// Cache failure for a less time.
|
||||
ttl = revDNSFailureCacheTTL
|
||||
}
|
||||
}
|
||||
|
||||
r.cache(ip, ttl)
|
||||
|
||||
if host != "" {
|
||||
_ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,264 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRDNS_Begin(t *testing.T) {
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
w := &bytes.Buffer{}
|
||||
aghtest.ReplaceLogWriter(t, w)
|
||||
|
||||
ip1234, ip1235 := netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5")
|
||||
|
||||
testCases := []struct {
|
||||
cliIDIndex map[string]*Client
|
||||
customChan chan netip.Addr
|
||||
name string
|
||||
wantLog string
|
||||
ip netip.Addr
|
||||
wantCacheHit int
|
||||
wantCacheMiss int
|
||||
}{{
|
||||
cliIDIndex: map[string]*Client{},
|
||||
customChan: nil,
|
||||
name: "cached",
|
||||
wantLog: "",
|
||||
ip: ip1234,
|
||||
wantCacheHit: 1,
|
||||
wantCacheMiss: 0,
|
||||
}, {
|
||||
cliIDIndex: map[string]*Client{},
|
||||
customChan: nil,
|
||||
name: "not_cached",
|
||||
wantLog: "rdns: queue is full",
|
||||
ip: ip1235,
|
||||
wantCacheHit: 0,
|
||||
wantCacheMiss: 1,
|
||||
}, {
|
||||
cliIDIndex: map[string]*Client{"1.2.3.5": {}},
|
||||
customChan: nil,
|
||||
name: "already_in_clients",
|
||||
wantLog: "",
|
||||
ip: ip1235,
|
||||
wantCacheHit: 0,
|
||||
wantCacheMiss: 1,
|
||||
}, {
|
||||
cliIDIndex: map[string]*Client{},
|
||||
customChan: make(chan netip.Addr, 1),
|
||||
name: "add_to_queue",
|
||||
wantLog: `rdns: "1.2.3.5" added to queue`,
|
||||
ip: ip1235,
|
||||
wantCacheHit: 0,
|
||||
wantCacheMiss: 1,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
w.Reset()
|
||||
|
||||
ipCache := cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: revDNSCacheSize,
|
||||
})
|
||||
ttl := make([]byte, binary.Size(uint64(0)))
|
||||
binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix()))
|
||||
|
||||
rdns := &RDNS{
|
||||
ipCache: ipCache,
|
||||
exchanger: &rDNSExchanger{
|
||||
ex: aghtest.NewErrorUpstream(),
|
||||
},
|
||||
clients: &clientsContainer{
|
||||
list: map[string]*Client{},
|
||||
idIndex: tc.cliIDIndex,
|
||||
ipToRC: map[netip.Addr]*RuntimeClient{},
|
||||
allTags: stringutil.NewSet(),
|
||||
},
|
||||
}
|
||||
ipCache.Clear()
|
||||
ipCache.Set(net.IP{1, 2, 3, 4}, ttl)
|
||||
|
||||
if tc.customChan != nil {
|
||||
rdns.ipCh = tc.customChan
|
||||
defer close(tc.customChan)
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rdns.Begin(tc.ip)
|
||||
assert.Equal(t, tc.wantCacheHit, ipCache.Stats().Hit)
|
||||
assert.Equal(t, tc.wantCacheMiss, ipCache.Stats().Miss)
|
||||
assert.Contains(t, w.String(), tc.wantLog)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests.
|
||||
type rDNSExchanger struct {
|
||||
ex upstream.Upstream
|
||||
usePrivate bool
|
||||
}
|
||||
|
||||
// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger.
|
||||
func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) {
|
||||
rev, err := netutil.IPToReversedAddr(ip)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reversing ip: %w", err)
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: dns.Fqdn(rev),
|
||||
Qclass: dns.ClassINET,
|
||||
Qtype: dns.TypePTR,
|
||||
}},
|
||||
}
|
||||
|
||||
resp, err := e.ex.Exchange(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(resp.Answer) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return resp.Answer[0].Header().Name, nil
|
||||
}
|
||||
|
||||
// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger.
|
||||
func (e *rDNSExchanger) ResolvesPrivatePTR() (ok bool) {
|
||||
return e.usePrivate
|
||||
}
|
||||
|
||||
func TestRDNS_ensurePrivateCache(t *testing.T) {
|
||||
data := []byte{1, 2, 3, 4}
|
||||
|
||||
ipCache := cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: revDNSCacheSize,
|
||||
})
|
||||
|
||||
ex := &rDNSExchanger{
|
||||
ex: aghtest.NewErrorUpstream(),
|
||||
}
|
||||
|
||||
rdns := &RDNS{
|
||||
ipCache: ipCache,
|
||||
exchanger: ex,
|
||||
}
|
||||
|
||||
rdns.ipCache.Set(data, data)
|
||||
require.NotZero(t, rdns.ipCache.Stats().Count)
|
||||
|
||||
ex.usePrivate = !ex.usePrivate
|
||||
|
||||
rdns.ensurePrivateCache()
|
||||
require.Zero(t, rdns.ipCache.Stats().Count)
|
||||
}
|
||||
|
||||
func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
w := &bytes.Buffer{}
|
||||
aghtest.ReplaceLogWriter(t, w)
|
||||
|
||||
localIP := netip.MustParseAddr("192.168.1.1")
|
||||
revIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
|
||||
require.NoError(t, err)
|
||||
|
||||
locUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "local.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return aghalg.Coalesce(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, revIPv4, "local.domain"),
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, revIPv6, "ipv6.domain"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
), nil
|
||||
},
|
||||
}
|
||||
|
||||
errUpstream := aghtest.NewErrorUpstream()
|
||||
|
||||
testCases := []struct {
|
||||
ups upstream.Upstream
|
||||
cliIP netip.Addr
|
||||
wantLog string
|
||||
name string
|
||||
wantClientSource clientSource
|
||||
}{{
|
||||
ups: locUpstream,
|
||||
cliIP: localIP,
|
||||
wantLog: "",
|
||||
name: "all_good",
|
||||
wantClientSource: ClientSourceRDNS,
|
||||
}, {
|
||||
ups: errUpstream,
|
||||
cliIP: netip.MustParseAddr("192.168.1.2"),
|
||||
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
|
||||
name: "resolve_error",
|
||||
wantClientSource: ClientSourceNone,
|
||||
}, {
|
||||
ups: locUpstream,
|
||||
cliIP: netip.MustParseAddr("2a00:1450:400c:c06::93"),
|
||||
wantLog: "",
|
||||
name: "ipv6_good",
|
||||
wantClientSource: ClientSourceRDNS,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
w.Reset()
|
||||
|
||||
cc := newClientsContainer(t)
|
||||
ch := make(chan netip.Addr)
|
||||
rdns := &RDNS{
|
||||
exchanger: &rDNSExchanger{
|
||||
ex: tc.ups,
|
||||
},
|
||||
clients: cc,
|
||||
ipCh: ch,
|
||||
ipCache: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: revDNSCacheSize,
|
||||
}),
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
rdns.workerLoop()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
ch <- tc.cliIP
|
||||
close(ch)
|
||||
wg.Wait()
|
||||
|
||||
if tc.wantLog != "" {
|
||||
assert.Contains(t, w.String(), tc.wantLog)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.wantClientSource, cc.clientSource(tc.cliIP))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
)
|
||||
|
||||
// currentSchemaVersion is the current schema version.
|
||||
const currentSchemaVersion = 23
|
||||
const currentSchemaVersion = 24
|
||||
|
||||
// These aliases are provided for convenience.
|
||||
type (
|
||||
@@ -98,6 +98,7 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
|
||||
upgradeSchema20to21,
|
||||
upgradeSchema21to22,
|
||||
upgradeSchema22to23,
|
||||
upgradeSchema23to24,
|
||||
}
|
||||
|
||||
n := 0
|
||||
@@ -1325,6 +1326,110 @@ func upgradeSchema22to23(diskConf yobj) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema23to24 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'log_file': ""
|
||||
// 'log_max_backups': 0
|
||||
// 'log_max_size': 100
|
||||
// 'log_max_age': 3
|
||||
// 'log_compress': false
|
||||
// 'log_localtime': false
|
||||
// 'verbose': false
|
||||
//
|
||||
// # AFTER:
|
||||
// 'log':
|
||||
// 'file': ""
|
||||
// 'max_backups': 0
|
||||
// 'max_size': 100
|
||||
// 'max_age': 3
|
||||
// 'compress': false
|
||||
// 'local_time': false
|
||||
// 'verbose': false
|
||||
func upgradeSchema23to24(diskConf yobj) (err error) {
|
||||
log.Printf("Upgrade yaml: 23 to 24")
|
||||
diskConf["schema_version"] = 24
|
||||
|
||||
logObj := yobj{}
|
||||
err = coalesceError(
|
||||
moveField[string](diskConf, logObj, "log_file", "file"),
|
||||
moveField[int](diskConf, logObj, "log_max_backups", "max_backups"),
|
||||
moveField[int](diskConf, logObj, "log_max_size", "max_size"),
|
||||
moveField[int](diskConf, logObj, "log_max_age", "max_age"),
|
||||
moveField[bool](diskConf, logObj, "log_compress", "compress"),
|
||||
moveField[bool](diskConf, logObj, "log_localtime", "local_time"),
|
||||
moveField[bool](diskConf, logObj, "verbose", "verbose"),
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
if len(logObj) != 0 {
|
||||
diskConf["log"] = logObj
|
||||
}
|
||||
|
||||
delete(diskConf, "log_file")
|
||||
delete(diskConf, "log_max_backups")
|
||||
delete(diskConf, "log_max_size")
|
||||
delete(diskConf, "log_max_age")
|
||||
delete(diskConf, "log_compress")
|
||||
delete(diskConf, "log_localtime")
|
||||
delete(diskConf, "verbose")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// moveField gets field value for key from diskConf, and then set this value
|
||||
// in newConf for newKey.
|
||||
func moveField[T any](diskConf, newConf yobj, key, newKey string) (err error) {
|
||||
ok, newVal, err := fieldValue[T](diskConf, key)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
switch v := newVal.(type) {
|
||||
case int, bool, string:
|
||||
newConf[newKey] = v
|
||||
default:
|
||||
return fmt.Errorf("invalid type of %s: %T", key, newVal)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fieldValue returns the value of type T for key in diskConf object.
|
||||
func fieldValue[T any](diskConf yobj, key string) (ok bool, field any, err error) {
|
||||
fieldVal, ok := diskConf[key]
|
||||
if !ok {
|
||||
return false, new(T), nil
|
||||
}
|
||||
|
||||
f, ok := fieldVal.(T)
|
||||
if !ok {
|
||||
return false, nil, fmt.Errorf("unexpected type of %s: %T", key, fieldVal)
|
||||
}
|
||||
|
||||
return true, f, nil
|
||||
}
|
||||
|
||||
// coalesceError returns the first non-nil error. It is named after function
|
||||
// COALESCE in SQL. If all errors are nil, it returns nil.
|
||||
//
|
||||
// TODO(a.garipov): Consider a similar helper to group errors together to show
|
||||
// as many errors as possible.
|
||||
//
|
||||
// TODO(a.garipov): Think of ways to merge with [aghalg.Coalesce].
|
||||
func coalesceError(errors ...error) (res error) {
|
||||
for _, err := range errors {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Replace with log.Output when we port it to our logging
|
||||
// package.
|
||||
func funcName() string {
|
||||
|
||||
@@ -1306,3 +1306,76 @@ func TestUpgradeSchema22to23(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeSchema23to24(t *testing.T) {
|
||||
const newSchemaVer = 24
|
||||
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "empty",
|
||||
in: yobj{},
|
||||
want: yobj{
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "ok",
|
||||
in: yobj{
|
||||
"log_file": "/test/path.log",
|
||||
"log_max_backups": 1,
|
||||
"log_max_size": 2,
|
||||
"log_max_age": 3,
|
||||
"log_compress": true,
|
||||
"log_localtime": true,
|
||||
"verbose": true,
|
||||
},
|
||||
want: yobj{
|
||||
"log": yobj{
|
||||
"file": "/test/path.log",
|
||||
"max_backups": 1,
|
||||
"max_size": 2,
|
||||
"max_age": 3,
|
||||
"compress": true,
|
||||
"local_time": true,
|
||||
"verbose": true,
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "invalid",
|
||||
in: yobj{
|
||||
"log_file": "/test/path.log",
|
||||
"log_max_backups": 1,
|
||||
"log_max_size": 2,
|
||||
"log_max_age": 3,
|
||||
"log_compress": "",
|
||||
"log_localtime": true,
|
||||
"verbose": true,
|
||||
},
|
||||
want: yobj{
|
||||
"log_file": "/test/path.log",
|
||||
"log_max_backups": 1,
|
||||
"log_max_size": 2,
|
||||
"log_max_age": 3,
|
||||
"log_compress": "",
|
||||
"log_localtime": true,
|
||||
"verbose": true,
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
wantErrMsg: "unexpected type of log_compress: string",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema23to24(tc.in)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user