all: sync more

This commit is contained in:
Ainar Garipov
2022-06-02 17:55:48 +03:00
parent dcb043df5f
commit 4f5131f423
31 changed files with 1073 additions and 900 deletions

View File

@@ -4,8 +4,6 @@ import (
"bytes"
"fmt"
"net"
"os/exec"
"runtime"
"sort"
"strings"
"sync"
@@ -62,6 +60,16 @@ const (
ClientSourceHostsFile
)
// clientSourceConf is used to configure where the runtime clients will be
// obtained from.
type clientSourcesConf struct {
WHOIS bool `yaml:"whois"`
ARP bool `yaml:"arp"`
RDNS bool `yaml:"rdns"`
DHCP bool `yaml:"dhcp"`
HostsFile bool `yaml:"hosts"`
}
// RuntimeClient information
type RuntimeClient struct {
WHOISInfo *RuntimeClientWHOISInfo
@@ -99,6 +107,9 @@ type clientsContainer struct {
// hosts database.
etcHosts *aghnet.HostsContainer
// arpdb stores the neighbors retrieved from ARP.
arpdb aghnet.ARPDB
testing bool // if TRUE, this object is used for internal tests
}
@@ -109,6 +120,7 @@ func (clients *clientsContainer) Init(
objects []*clientObject,
dhcpServer *dhcpd.Server,
etcHosts *aghnet.HostsContainer,
arpdb aghnet.ARPDB,
) {
if clients.list != nil {
log.Fatal("clients.list != nil")
@@ -121,6 +133,7 @@ func (clients *clientsContainer) Init(
clients.dhcpServer = dhcpServer
clients.etcHosts = etcHosts
clients.arpdb = arpdb
clients.addFromConfig(objects)
if clients.testing {
@@ -132,14 +145,14 @@ func (clients *clientsContainer) Init(
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
}
go clients.handleHostsUpdates()
if clients.etcHosts != nil {
go clients.handleHostsUpdates()
}
}
func (clients *clientsContainer) handleHostsUpdates() {
if clients.etcHosts != nil {
for upd := range clients.etcHosts.Upd() {
clients.addFromHostsFile(upd)
}
for upd := range clients.etcHosts.Upd() {
clients.addFromHostsFile(upd)
}
}
@@ -156,7 +169,9 @@ func (clients *clientsContainer) Start() {
// Reload reloads runtime clients.
func (clients *clientsContainer) Reload() {
clients.addFromSystemARP()
if clients.arpdb != nil {
clients.addFromSystemARP()
}
}
type clientObject struct {
@@ -255,6 +270,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
}
func (clients *clientsContainer) periodicUpdate() {
defer log.OnPanic("clients container")
for {
clients.Reload()
time.Sleep(clientsUpdatePeriod)
@@ -733,6 +750,7 @@ func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clien
return false
}
rc.Host = host
rc.Source = src
} else {
rc = &RuntimeClient{
@@ -805,16 +823,18 @@ func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) {
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
// command.
func (clients *clientsContainer) addFromSystemARP() {
if runtime.GOOS == "windows" {
if err := clients.arpdb.Refresh(); err != nil {
log.Error("refreshing arp container: %s", err)
clients.arpdb = aghnet.EmptyARPDB{}
return
}
cmd := exec.Command("arp", "-a")
log.Tracef("executing %q %q", cmd.Path, cmd.Args)
data, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Debug("command %q has failed: %q code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
ns := clients.arpdb.Neighbors()
if len(ns) == 0 {
log.Debug("refreshing arp container: the update is empty")
return
}
@@ -823,36 +843,20 @@ func (clients *clientsContainer) addFromSystemARP() {
clients.rmHostsBySrc(ClientSourceARP)
n := 0
// TODO(a.garipov): Rewrite to use bufio.Scanner.
lines := strings.Split(string(data), "\n")
for _, ln := range lines {
lparen := strings.Index(ln, " (")
rparen := strings.Index(ln, ") ")
if lparen == -1 || rparen == -1 || lparen >= rparen {
continue
}
host := ln[:lparen]
ipStr := ln[lparen+2 : rparen]
ip := net.ParseIP(ipStr)
if netutil.ValidateDomainName(host) != nil || ip == nil {
continue
}
ok := clients.addHostLocked(ip, host, ClientSourceARP)
if ok {
n++
added := 0
for _, n := range ns {
if clients.addHostLocked(n.IP, n.Name, ClientSourceARP) {
added++
}
}
log.Debug("clients: added %d client aliases from 'arp -a' command output", n)
log.Debug("clients: added %d client aliases from arp neighborhood", added)
}
// updateFromDHCP adds the clients that have a non-empty hostname from the DHCP
// server.
func (clients *clientsContainer) updateFromDHCP(add bool) {
if clients.dhcpServer == nil {
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
return
}

View File

@@ -3,11 +3,12 @@ package home
import (
"net"
"os"
"runtime"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -16,7 +17,7 @@ func TestClients(t *testing.T) {
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
clients.Init(nil, nil, nil, nil)
t.Run("add_success", func(t *testing.T) {
c := &Client{
@@ -192,7 +193,7 @@ func TestClientsWHOIS(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil)
clients.Init(nil, nil, nil, nil)
whois := &RuntimeClientWHOISInfo{
Country: "AU",
Orgname: "Example Org",
@@ -251,7 +252,7 @@ func TestClientsAddExisting(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil)
clients.Init(nil, nil, nil, nil)
t.Run("simple", func(t *testing.T) {
ip := net.IP{1, 1, 1, 1}
@@ -271,12 +272,18 @@ func TestClientsAddExisting(t *testing.T) {
})
t.Run("complicated", func(t *testing.T) {
// TODO(a.garipov): Properly decouple the DHCP server from the client
// storage.
if runtime.GOOS == "windows" {
t.Skip("skipping dhcp test on windows")
}
var err error
ip := net.IP{1, 2, 3, 4}
// First, init a DHCP server with a single static lease.
config := dhcpd.ServerConfig{
config := &dhcpd.ServerConfig{
Enabled: true,
DBFilePath: "leases.db",
Conf4: dhcpd.V4ServerConf{
@@ -290,10 +297,9 @@ func TestClientsAddExisting(t *testing.T) {
clients.dhcpServer, err = dhcpd.Create(config)
require.NoError(t, err)
// TODO(e.burkov): leases.db isn't created on Windows so removing it
// causes an error. Split the test to make it run properly on different
// operating systems.
t.Cleanup(func() { _ = os.Remove("leases.db") })
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return os.Remove("leases.db")
})
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
@@ -325,7 +331,7 @@ func TestClientsCustomUpstream(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil)
clients.Init(nil, nil, nil, nil)
// Add client with upstreams.
ok, err := clients.Add(&Client{

View File

@@ -51,6 +51,13 @@ type osConfig struct {
RlimitNoFile uint64 `yaml:"rlimit_nofile"`
}
type clientsConfig struct {
// Sources defines the set of sources to fetch the runtime clients from.
Sources *clientSourcesConf `yaml:"runtime_sources"`
// Persistent are the configured clients.
Persistent []*clientObject `yaml:"persistent"`
}
// configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here
type configuration struct {
@@ -83,12 +90,12 @@ type configuration struct {
WhitelistFilters []filter `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
DHCP dhcpd.ServerConfig `yaml:"dhcp"`
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
// Clients contains the YAML representations of the persistent clients.
// This field is only used for reading and writing persistent client data.
// Keep this field sorted to ensure consistent ordering.
Clients []*clientObject `yaml:"clients"`
Clients *clientsConfig `yaml:"clients"`
logSettings `yaml:",inline"`
@@ -123,13 +130,9 @@ type dnsConfig struct {
// UpstreamTimeout is the timeout for querying upstream servers.
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
// LocalDomainName is the domain name used for known internal hosts.
// For example, a machine called "myhost" can be addressed as
// "myhost.lan" when LocalDomainName is "lan".
LocalDomainName string `yaml:"local_domain_name"`
// ResolveClients enables and disables resolving clients with RDNS.
ResolveClients bool `yaml:"resolve_clients"`
// PrivateNets is the set of IP networks for which the private reverse DNS
// resolver should be used.
PrivateNets []string `yaml:"private_networks"`
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
// locally-served networks should be resolved via private PTR resolvers.
@@ -199,8 +202,6 @@ var config = &configuration{
FilteringEnabled: true, // whether or not use filter lists
FiltersUpdateIntervalHours: 24,
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
LocalDomainName: "lan",
ResolveClients: true,
UsePrivateRDNS: true,
},
TLS: tlsConfigSettings{
@@ -208,6 +209,18 @@ var config = &configuration{
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
PortDNSOverQUIC: defaultPortQUIC,
},
DHCP: &dhcpd.ServerConfig{
LocalDomainName: "lan",
},
Clients: &clientsConfig{
Sources: &clientSourcesConf{
WHOIS: true,
ARP: true,
RDNS: true,
DHCP: true,
HostsFile: true,
},
},
logSettings: logSettings{
LogCompress: false,
LogLocalTime: false,
@@ -403,18 +416,16 @@ func (c *configuration) write() error {
s.WriteDiskConfig(&c)
dns := &config.DNS
dns.FilteringConfig = c
dns.LocalPTRResolvers,
dns.ResolveClients,
dns.UsePrivateRDNS = s.RDNSSettings()
dns.LocalPTRResolvers, config.Clients.Sources.RDNS, dns.UsePrivateRDNS = s.RDNSSettings()
}
if Context.dhcpServer != nil {
c := dhcpd.ServerConfig{}
Context.dhcpServer.WriteDiskConfig(&c)
c := &dhcpd.ServerConfig{}
Context.dhcpServer.WriteDiskConfig(c)
config.DHCP = c
}
config.Clients = Context.clients.forConfig()
config.Clients.Persistent = Context.clients.forConfig()
configFile := config.getConfigFilename()
log.Debug("Writing YAML file: %s", configFile)

View File

@@ -13,6 +13,7 @@ import (
"runtime"
"strings"
"time"
"unicode/utf8"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
@@ -359,6 +360,9 @@ func shutdownSrv(ctx context.Context, srv *http.Server) {
}
}
// PasswordMinRunes is the minimum length of user's password in runes.
const PasswordMinRunes = 8
// Apply new configuration, start DNS server, restart Web server
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
req, restartHTTP, err := decodeApplyConfigReq(r.Body)
@@ -368,6 +372,18 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
return
}
if utf8.RuneCountInString(req.Password) < PasswordMinRunes {
aghhttp.Error(
r,
w,
http.StatusUnprocessableEntity,
"password must be at least %d symbols long",
PasswordMinRunes,
)
return
}
err = aghnet.CheckPort("udp", req.DNS.IP, req.DNS.Port)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)

View File

@@ -77,13 +77,36 @@ func initDNSServer() (err error) {
filterConf.HTTPRegister = httpRegister
Context.dnsFilter = filtering.New(&filterConf, nil)
var privateNets netutil.SubnetSet
switch len(config.DNS.PrivateNets) {
case 0:
// Use an optimized locally-served matcher.
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
case 1:
var n *net.IPNet
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = n
default:
var nets []*net.IPNet
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = netutil.SliceSubnetSet(nets)
}
p := dnsforward.DNSCreateParams{
DNSFilter: Context.dnsFilter,
Stats: Context.stats,
QueryLog: Context.queryLog,
SubnetDetector: Context.subnetDetector,
Anonymizer: anonymizer,
LocalDomain: config.DNS.LocalDomainName,
DNSFilter: Context.dnsFilter,
Stats: Context.stats,
QueryLog: Context.queryLog,
PrivateNets: privateNets,
Anonymizer: anonymizer,
LocalDomain: config.DHCP.LocalDomainName,
}
if Context.dhcpServer != nil {
p.DHCPServer = Context.dhcpServer
@@ -112,8 +135,13 @@ func initDNSServer() (err error) {
return fmt.Errorf("dnsServer.Prepare: %w", err)
}
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
Context.whois = initWHOIS(&Context.clients)
if config.Clients.Sources.RDNS {
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
}
if config.Clients.Sources.WHOIS {
Context.whois = initWHOIS(&Context.clients)
}
Context.filters.Init()
return nil
@@ -130,10 +158,11 @@ func onDNSRequest(pctx *proxy.DNSContext) {
return
}
if config.DNS.ResolveClients && !ip.IsLoopback() {
srcs := config.Clients.Sources
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if !Context.subnetDetector.IsSpecialNetwork(ip) {
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
Context.whois.Begin(ip)
}
}
@@ -192,6 +221,10 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
newConf.TLSConfig = tlsConf.TLSConfig
newConf.TLSConfig.ServerName = tlsConf.ServerName
if tlsConf.PortHTTPS != 0 {
newConf.HTTPSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortHTTPS)
}
if tlsConf.PortDNSOverTLS != 0 {
newConf.TLSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortDNSOverTLS)
}
@@ -217,7 +250,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
newConf.FilterHandler = applyAdditionalFiltering
newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams
newConf.ResolveClients = dnsConf.ResolveClients
newConf.ResolveClients = config.Clients.Sources.RDNS
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers
newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration
@@ -365,10 +398,15 @@ func startDNSServer() error {
const topClientsNumber = 100 // the number of clients to get
for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) {
if config.DNS.ResolveClients && !ip.IsLoopback() {
if ip == nil {
continue
}
srcs := config.Clients.Sources
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if !Context.subnetDetector.IsSpecialNetwork(ip) {
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
Context.whois.Begin(ip)
}
}

View File

@@ -65,8 +65,6 @@ type homeContext struct {
updater *updater.Updater
subnetDetector *aghnet.SubnetDetector
// mux is our custom http.ServeMux.
mux *http.ServeMux
@@ -175,6 +173,11 @@ func setupContext(args options) {
os.Exit(0)
}
if !args.noEtcHosts && config.Clients.Sources.HostsFile {
err = setupHostsContainer()
fatalOnError(err)
}
}
Context.mux = http.NewServeMux()
@@ -182,7 +185,7 @@ func setupContext(args options) {
// logIfUnsupported logs a formatted warning if the error is one of the
// unsupported errors and returns nil. If err is nil, logIfUnsupported returns
// nil. Otherise, it returns err.
// nil. Otherwise, it returns err.
func logIfUnsupported(msg string, err error) (outErr error) {
if errors.As(err, new(*aghos.UnsupportedError)) {
log.Debug(msg, err)
@@ -287,13 +290,12 @@ func setupConfig(args options) (err error) {
ConfName: config.getConfigFilename(),
})
if !args.noEtcHosts {
if err = setupHostsContainer(); err != nil {
return err
}
var arpdb aghnet.ARPDB
if config.Clients.Sources.ARP {
arpdb = aghnet.NewARPDB()
}
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts)
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb)
if args.bindPort != 0 {
uc := aghalg.UniqChecker{}
@@ -466,9 +468,6 @@ func run(args options, clientBuildFS fs.FS) {
Context.web, err = initWeb(args, clientBuildFS)
fatalOnError(err)
Context.subnetDetector, err = aghnet.NewSubnetDetector()
fatalOnError(err)
if !Context.firstRun {
err = initDNSServer()
fatalOnError(err)
@@ -527,8 +526,10 @@ func checkPermissions() {
if err != nil {
if errors.Is(err, os.ErrPermission) {
log.Fatal(`Permission check failed.
AdGuard Home is not allowed to bind to privileged ports (for instance, port 53).
Please note, that this is crucial for a server to be able to use privileged ports.
You have two options:
1. Run AdGuard Home with root privileges
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:

View File

@@ -230,13 +230,19 @@ var helpArg = arg{
}
var noEtcHostsArg = arg{
description: "Do not use the OS-provided hosts.",
description: "Deprecated. Do not use the OS-provided hosts.",
longName: "no-etc-hosts",
shortName: "",
updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.noEtcHosts = true; return o, nil },
effect: nil,
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) },
effect: func(_ options, _ string) (f effect, err error) {
log.Info(
"warning: --no-etc-hosts flag is deprecated and will be removed in the future versions",
)
return nil, nil
},
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) },
}
var localFrontendArg = arg{

View File

@@ -167,7 +167,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w)
locUpstream := &aghtest.TestUpstream{
locUpstream := &aghtest.Upstream{
Reverse: map[string][]string{
"192.168.1.1": {"local.domain"},
"2a00:1450:400c:c06::93": {"ipv6.domain"},

View File

@@ -21,9 +21,11 @@ import (
)
// currentSchemaVersion is the current schema version.
const currentSchemaVersion = 12
const currentSchemaVersion = 14
// These aliases are provided for convenience.
//
// TODO(e.burkov): Remove any after updating to Go 1.18.
type (
any = interface{}
yarr = []any
@@ -85,6 +87,8 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
upgradeSchema9to10,
upgradeSchema10to11,
upgradeSchema11to12,
upgradeSchema12to13,
upgradeSchema13to14,
}
n := 0
@@ -690,6 +694,114 @@ func upgradeSchema11to12(diskConf yobj) (err error) {
return nil
}
// upgradeSchema12to13 performs the following changes:
//
// # BEFORE:
// 'dns':
// # …
// 'local_domain_name': 'lan'
//
// # AFTER:
// 'dhcp':
// # …
// 'local_domain_name': 'lan'
//
func upgradeSchema12to13(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 12 to 13")
diskConf["schema_version"] = 13
dnsVal, ok := diskConf["dns"]
if !ok {
return nil
}
var dns yobj
dns, ok = dnsVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of dns: %T", dnsVal)
}
dhcpVal, ok := diskConf["dhcp"]
if !ok {
return nil
}
var dhcp yobj
dhcp, ok = dhcpVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of dhcp: %T", dhcpVal)
}
const field = "local_domain_name"
dhcp[field] = dns[field]
delete(dns, field)
return nil
}
// upgradeSchema13to14 performs the following changes:
//
// # BEFORE:
// 'clients':
// - 'name': 'client-name'
// # …
//
// # AFTER:
// 'clients':
// 'persistent':
// - 'name': 'client-name'
// # …
// 'runtime_sources':
// 'whois': true
// 'arp': true
// 'rdns': true
// 'dhcp': true
// 'hosts': true
//
func upgradeSchema13to14(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 13 to 14")
diskConf["schema_version"] = 14
clientsVal, ok := diskConf["clients"]
if !ok {
clientsVal = yarr{}
}
var rdnsSrc bool
if dnsVal, dok := diskConf["dns"]; dok {
var dnsSettings yobj
dnsSettings, ok = dnsVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of dns: %T", dnsVal)
}
var rdnsSrcVal any
rdnsSrcVal, ok = dnsSettings["resolve_clients"]
if ok {
rdnsSrc, ok = rdnsSrcVal.(bool)
if !ok {
return fmt.Errorf("unexpected type of resolve_clients: %T", rdnsSrcVal)
}
delete(dnsSettings, "resolve_clients")
}
}
diskConf["clients"] = yobj{
"persistent": clientsVal,
"runtime_sources": &clientSourcesConf{
WHOIS: true,
ARP: true,
RDNS: rdnsSrc,
DHCP: true,
HostsFile: true,
},
}
return nil
}
// TODO(a.garipov): Replace with log.Output when we port it to our logging
// package.
func funcName() string {

View File

@@ -55,7 +55,7 @@ func TestUpgradeSchema2to3(t *testing.T) {
require.Len(t, v, 1)
require.Equal(t, "8.8.8.8:53", v[0])
default:
t.Fatalf("wrong type for bootsrap dns: %T", v)
t.Fatalf("wrong type for bootstrap dns: %T", v)
}
excludedEntries := []string{"bootstrap_dns"}
@@ -511,3 +511,131 @@ func TestUpgradeSchema11to12(t *testing.T) {
assert.Equal(t, 90*24*time.Hour, ivlVal.Duration)
})
}
func TestUpgradeSchema12to13(t *testing.T) {
const newSchemaVer = 13
testCases := []struct {
in yobj
want yobj
name string
}{{
in: yobj{},
want: yobj{"schema_version": newSchemaVer},
name: "no_dns",
}, {
in: yobj{"dns": yobj{}},
want: yobj{
"dns": yobj{},
"schema_version": newSchemaVer,
},
name: "no_dhcp",
}, {
in: yobj{
"dns": yobj{
"local_domain_name": "lan",
},
"dhcp": yobj{},
"schema_version": newSchemaVer - 1,
},
want: yobj{
"dns": yobj{},
"dhcp": yobj{
"local_domain_name": "lan",
},
"schema_version": newSchemaVer,
},
name: "good",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema12to13(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}
func TestUpgradeSchema13to14(t *testing.T) {
const newSchemaVer = 14
testClient := &clientObject{
Name: "agh-client",
IDs: []string{"id1"},
UseGlobalSettings: true,
}
testCases := []struct {
in yobj
want yobj
name string
}{{
in: yobj{},
want: yobj{
"schema_version": newSchemaVer,
// The clients field will be added anyway.
"clients": yobj{
"persistent": yarr{},
"runtime_sources": &clientSourcesConf{
WHOIS: true,
ARP: true,
RDNS: false,
DHCP: true,
HostsFile: true,
},
},
},
name: "no_clients",
}, {
in: yobj{
"clients": []*clientObject{testClient},
},
want: yobj{
"schema_version": newSchemaVer,
"clients": yobj{
"persistent": []*clientObject{testClient},
"runtime_sources": &clientSourcesConf{
WHOIS: true,
ARP: true,
RDNS: false,
DHCP: true,
HostsFile: true,
},
},
},
name: "no_dns",
}, {
in: yobj{
"clients": []*clientObject{testClient},
"dns": yobj{
"resolve_clients": true,
},
},
want: yobj{
"schema_version": newSchemaVer,
"clients": yobj{
"persistent": []*clientObject{testClient},
"runtime_sources": &clientSourcesConf{
WHOIS: true,
ARP: true,
RDNS: true,
DHCP: true,
HostsFile: true,
},
},
"dns": yobj{},
},
name: "good",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema13to14(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}

View File

@@ -34,14 +34,13 @@ const (
)
type webConfig struct {
clientFS fs.FS
clientBetaFS fs.FS
BindHost net.IP
BindPort int
BetaBindPort int
PortHTTPS int
firstRun bool
clientFS fs.FS
clientBetaFS fs.FS
// ReadTimeout is an option to pass to http.Server for setting an
// appropriate field.
@@ -54,6 +53,8 @@ type webConfig struct {
// WriteTimeout is an option to pass to http.Server for setting an
// appropriate field.
WriteTimeout time.Duration
firstRun bool
}
// HTTPSServer - HTTPS Server