all: sync with master; upd chlog

This commit is contained in:
Ainar Garipov
2023-06-07 20:04:01 +03:00
parent 7030c7c24c
commit c65700923a
76 changed files with 2998 additions and 1909 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/stringutil"
)
// Client contains information about persistent clients.
@@ -37,6 +38,19 @@ type Client struct {
IgnoreStatistics bool
}
// ShallowClone returns a deep copy of the client, except upstreamConfig,
// safeSearchConf, SafeSearch fields, because it's difficult to copy them.
func (c *Client) ShallowClone() (sh *Client) {
clone := *c
clone.IDs = stringutil.CloneSlice(c.IDs)
clone.Tags = stringutil.CloneSlice(c.Tags)
clone.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
clone.Upstreams = stringutil.CloneSlice(c.Upstreams)
return &clone
}
// closeUpstreams closes the client-specific upstream config of c if any.
func (c *Client) closeUpstreams() (err error) {
if c.upstreamConfig != nil {

View File

@@ -378,6 +378,7 @@ func (clients *clientsContainer) clientOrArtificial(
}, true
}
// Find returns a shallow copy of the client if there is one found.
func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -387,20 +388,18 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
return nil, false
}
c.IDs = stringutil.CloneSlice(c.IDs)
c.Tags = stringutil.CloneSlice(c.Tags)
c.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
c.Upstreams = stringutil.CloneSlice(c.Upstreams)
return c, true
return c.ShallowClone(), true
}
// shouldCountClient is a wrapper around Find to make it a valid client
// information finder for the statistics. If no information about the client
// is found, it returns true.
func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
for _, id := range ids {
client, ok := clients.Find(id)
client, ok := clients.findLocked(id)
if ok {
return !client.IgnoreStatistics
}
@@ -617,6 +616,15 @@ func (clients *clientsContainer) Add(c *Client) (ok bool, err error) {
}
}
clients.add(c)
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs, len(clients.list))
return true, nil
}
// add c to the indexes. clients.lock is expected to be locked.
func (clients *clientsContainer) add(c *Client) {
// update Name index
clients.list[c.Name] = c
@@ -624,10 +632,6 @@ func (clients *clientsContainer) Add(c *Client) (ok bool, err error) {
for _, id := range c.IDs {
clients.idIndex[id] = c
}
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs, len(clients.list))
return true, nil
}
// Del removes a client. ok is false if there is no such client.
@@ -645,86 +649,53 @@ func (clients *clientsContainer) Del(name string) (ok bool) {
log.Error("client container: removing client %s: %s", name, err)
}
clients.del(c)
return true
}
// del removes c from the indexes. clients.lock is expected to be locked.
func (clients *clientsContainer) del(c *Client) {
// update Name index
delete(clients.list, name)
delete(clients.list, c.Name)
// update ID index
for _, id := range c.IDs {
delete(clients.idIndex, id)
}
return true
}
// Update updates a client by its name.
func (clients *clientsContainer) Update(name string, c *Client) (err error) {
func (clients *clientsContainer) Update(prev, c *Client) (err error) {
err = clients.check(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
clients.lock.Lock()
defer clients.lock.Unlock()
prev, ok := clients.list[name]
if !ok {
return errors.Error("client not found")
}
// First, check the name index.
// Check the name index.
if prev.Name != c.Name {
_, ok = clients.list[c.Name]
_, ok := clients.list[c.Name]
if ok {
return errors.Error("client already exists")
}
}
// Second, update the ID index.
err = clients.updateIDIndex(prev, c.IDs)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
// Update name index.
if prev.Name != c.Name {
delete(clients.list, prev.Name)
clients.list[c.Name] = prev
}
// Update upstreams cache.
err = c.closeUpstreams()
if err != nil {
return err
}
*prev = *c
return nil
}
// updateIDIndex updates the ID index data for cli using the information from
// newIDs.
func (clients *clientsContainer) updateIDIndex(cli *Client, newIDs []string) (err error) {
if slices.Equal(cli.IDs, newIDs) {
return nil
}
for _, id := range newIDs {
existing, ok := clients.idIndex[id]
if ok && existing != cli {
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
// Check the ID index.
if !slices.Equal(prev.IDs, c.IDs) {
for _, id := range c.IDs {
existing, ok := clients.idIndex[id]
if ok && existing != prev {
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
}
}
}
// Update the IDs in the index.
for _, id := range cli.IDs {
delete(clients.idIndex, id)
}
for _, id := range newIDs {
clients.idIndex[id] = cli
}
clients.del(prev)
clients.add(c)
return nil
}

View File

@@ -98,22 +98,8 @@ func TestClients(t *testing.T) {
assert.False(t, ok)
})
t.Run("update_fail_name", func(t *testing.T) {
err := clients.Update("client3", &Client{
IDs: []string{"1.2.3.0"},
Name: "client3",
})
require.Error(t, err)
err = clients.Update("client3", &Client{
IDs: []string{"1.2.3.0"},
Name: "client2",
})
assert.Error(t, err)
})
t.Run("update_fail_ip", func(t *testing.T) {
err := clients.Update("client1", &Client{
err := clients.Update(&Client{Name: "client1"}, &Client{
IDs: []string{"2.2.2.2"},
Name: "client1",
})
@@ -129,7 +115,10 @@ func TestClients(t *testing.T) {
cliNewIP = netip.MustParseAddr(cliNew)
)
err := clients.Update("client1", &Client{
prev, ok := clients.list["client1"]
require.True(t, ok)
err := clients.Update(prev, &Client{
IDs: []string{cliNew},
Name: "client1",
})
@@ -138,7 +127,10 @@ func TestClients(t *testing.T) {
assert.Equal(t, clients.clientSource(cliOldIP), ClientSourceNone)
assert.Equal(t, clients.clientSource(cliNewIP), ClientSourcePersistent)
err = clients.Update("client1", &Client{
prev, ok = clients.list["client1"]
require.True(t, ok)
err = clients.Update(prev, &Client{
IDs: []string{cliNew},
Name: "client1-renamed",
UseOwnSettings: true,

View File

@@ -289,7 +289,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return
}
err = clients.Update(dj.Name, c)
err = clients.Update(prev, c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)

View File

@@ -399,19 +399,39 @@ func (c *configuration) getConfigFilename() string {
return configFile
}
// getLogSettings reads logging settings from the config file.
// we do it in a separate method in order to configure logger before the actual configuration is parsed and applied.
func getLogSettings() logSettings {
l := logSettings{}
// readLogSettings reads logging settings from the config file. We do it in a
// separate method in order to configure logger before the actual configuration
// is parsed and applied.
func readLogSettings() (ls *logSettings) {
ls = &logSettings{}
yamlFile, err := readConfigFile()
if err != nil {
return l
return ls
}
err = yaml.Unmarshal(yamlFile, &l)
err = yaml.Unmarshal(yamlFile, ls)
if err != nil {
log.Error("Couldn't get logging settings from the configuration: %s", err)
}
return l
return ls
}
// validateBindHosts returns error if any of binding hosts from configuration is
// not a valid IP address.
func validateBindHosts(conf *configuration) (err error) {
if !conf.BindHost.IsValid() {
return errors.Error("bind_host is not a valid ip address")
}
for i, addr := range conf.DNS.BindHosts {
if !addr.IsValid() {
return fmt.Errorf("dns.bind_hosts at index %d is not a valid ip address", i)
}
}
return nil
}
// parseConfig loads configuration from the YAML file
@@ -425,6 +445,13 @@ func parseConfig() (err error) {
config.fileData = nil
err = yaml.Unmarshal(fileData, &config)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
err = validateBindHosts(config)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}

View File

@@ -180,7 +180,7 @@ func registerControlHandlers() {
httpRegister(http.MethodGet, "/control/status", handleStatus)
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
Context.mux.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
Context.mux.HandleFunc("/control/version.json", postInstall(optionalAuth(handleVersionJSON)))
httpRegister(http.MethodPost, "/control/update", handleUpdate)
httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)

View File

@@ -26,15 +26,14 @@ type temporaryError interface {
Temporary() (ok bool)
}
// Get the latest available version from the Internet
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
// handleVersionJSON is the handler for the POST /control/version.json HTTP API.
//
// TODO(a.garipov): Find out if this API used with a GET method by anyone.
func handleVersionJSON(w http.ResponseWriter, r *http.Request) {
resp := &versionResponse{}
if Context.disableUpdate {
resp.Disabled = true
err := json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
return
}

View File

@@ -27,14 +27,17 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/slices"
"gopkg.in/natefinch/lumberjack.v2"
)
@@ -143,7 +146,9 @@ func Main(clientBuildFS fs.FS) {
run(opts, clientBuildFS)
}
func setupContext(opts options) {
// setupContext initializes [Context] fields. It also reads and upgrades
// config file if necessary.
func setupContext(opts options) (err error) {
setupContextFlags(opts)
Context.tlsRoots = aghtls.SystemRootCAs()
@@ -160,10 +165,15 @@ func setupContext(opts options) {
},
}
Context.mux = http.NewServeMux()
if !Context.firstRun {
// Do the upgrade if necessary.
err := upgradeConfig()
fatalOnError(err)
err = upgradeConfig()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
if err = parseConfig(); err != nil {
log.Error("parsing configuration file: %s", err)
@@ -179,11 +189,14 @@ func setupContext(opts options) {
if !opts.noEtcHosts && config.Clients.Sources.HostsFile {
err = setupHostsContainer()
fatalOnError(err)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
}
}
Context.mux = http.NewServeMux()
return nil
}
// setupContextFlags sets global flags and prints their status to the log.
@@ -285,25 +298,27 @@ func setupHostsContainer() (err error) {
return nil
}
func setupConfig(opts options) (err error) {
config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts
config.DNS.DnsfilterConf.ConfigModified = onConfigModified
config.DNS.DnsfilterConf.HTTPRegister = httpRegister
config.DNS.DnsfilterConf.DataDir = Context.getDataDir()
config.DNS.DnsfilterConf.Filters = slices.Clone(config.Filters)
config.DNS.DnsfilterConf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
config.DNS.DnsfilterConf.UserRules = slices.Clone(config.UserRules)
config.DNS.DnsfilterConf.HTTPClient = Context.client
config.DNS.DnsfilterConf.SafeSearchConf.CustomResolver = safeSearchResolver{}
config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefault(
config.DNS.DnsfilterConf.SafeSearchConf,
"default",
config.DNS.DnsfilterConf.SafeSearchCacheSize,
time.Minute*time.Duration(config.DNS.DnsfilterConf.CacheTime),
)
// setupOpts sets up command-line options.
func setupOpts(opts options) (err error) {
err = setupBindOpts(opts)
if err != nil {
return fmt.Errorf("initializing safesearch: %w", err)
// Don't wrap the error, because it's informative enough as is.
return err
}
if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
Context.pidFileName = opts.pidFile
}
return nil
}
// initContextClients initializes Context clients and related fields.
func initContextClients() (err error) {
err = setupDNSFilteringConf(config.DNS.DnsfilterConf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
//lint:ignore SA1019 Migration is not over.
@@ -338,8 +353,19 @@ func setupConfig(opts options) (err error) {
arpdb = aghnet.NewARPDB()
}
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb, config.DNS.DnsfilterConf)
Context.clients.Init(
config.Clients.Persistent,
Context.dhcpServer,
Context.etcHosts,
arpdb,
config.DNS.DnsfilterConf,
)
return nil
}
// setupBindOpts overrides bind host/port from the opts.
func setupBindOpts(opts options) (err error) {
if opts.bindPort != 0 {
config.BindPort = opts.bindPort
@@ -350,12 +376,83 @@ func setupConfig(opts options) (err error) {
}
}
// override bind host/port from the console
if opts.bindHost.IsValid() {
config.BindHost = opts.bindHost
}
if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
Context.pidFileName = opts.pidFile
return nil
}
// setupDNSFilteringConf sets up DNS filtering configuration settings.
func setupDNSFilteringConf(conf *filtering.Config) (err error) {
const (
dnsTimeout = 3 * time.Second
sbService = "safe browsing"
defaultSafeBrowsingServer = `https://family.adguard-dns.com/dns-query`
sbTXTSuffix = `sb.dns.adguard.com.`
pcService = "parental control"
defaultParentalServer = `https://family.adguard-dns.com/dns-query`
pcTXTSuffix = `pc.dns.adguard.com.`
)
conf.EtcHosts = Context.etcHosts
conf.ConfigModified = onConfigModified
conf.HTTPRegister = httpRegister
conf.DataDir = Context.getDataDir()
conf.Filters = slices.Clone(config.Filters)
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
conf.UserRules = slices.Clone(config.UserRules)
conf.HTTPClient = Context.client
cacheTime := time.Duration(conf.CacheTime) * time.Minute
upsOpts := &upstream.Options{
Timeout: dnsTimeout,
ServerIPAddrs: []net.IP{
{94, 140, 14, 15},
{94, 140, 15, 16},
net.ParseIP("2a10:50c0::bad1:ff"),
net.ParseIP("2a10:50c0::bad2:ff"),
},
}
sbUps, err := upstream.AddressToUpstream(defaultSafeBrowsingServer, upsOpts)
if err != nil {
return fmt.Errorf("converting safe browsing server: %w", err)
}
conf.SafeBrowsingChecker = hashprefix.New(&hashprefix.Config{
Upstream: sbUps,
ServiceName: sbService,
TXTSuffix: sbTXTSuffix,
CacheTime: cacheTime,
CacheSize: conf.SafeBrowsingCacheSize,
})
parUps, err := upstream.AddressToUpstream(defaultParentalServer, upsOpts)
if err != nil {
return fmt.Errorf("converting parental server: %w", err)
}
conf.ParentalControlChecker = hashprefix.New(&hashprefix.Config{
Upstream: parUps,
ServiceName: pcService,
TXTSuffix: pcTXTSuffix,
CacheTime: cacheTime,
CacheSize: conf.SafeBrowsingCacheSize,
})
conf.SafeSearchConf.CustomResolver = safeSearchResolver{}
conf.SafeSearch, err = safesearch.NewDefault(
conf.SafeSearchConf,
"default",
conf.SafeSearchCacheSize,
cacheTime,
)
if err != nil {
return fmt.Errorf("initializing safesearch: %w", err)
}
return nil
@@ -432,14 +529,16 @@ func fatalOnError(err error) {
// run configures and starts AdGuard Home.
func run(opts options, clientBuildFS fs.FS) {
// configure config filename
// Configure config filename.
initConfigFilename(opts)
// configure working dir and config path
initWorkingDir(opts)
// Configure working dir and config path.
err := initWorkingDir(opts)
fatalOnError(err)
// configure log level and output
configureLogger(opts)
// Configure log level and output.
err = configureLogger(opts)
fatalOnError(err)
// Print the first message after logger is configured.
log.Info(version.Full())
@@ -448,25 +547,29 @@ func run(opts options, clientBuildFS fs.FS) {
log.Info("AdGuard Home is running as a service")
}
setupContext(opts)
err := configureOS(config)
err = setupContext(opts)
fatalOnError(err)
// clients package uses filtering package's static data (filtering.BlockedSvcKnown()),
// so we have to initialize filtering's static data first,
// but also avoid relying on automatic Go init() function
err = configureOS(config)
fatalOnError(err)
// Clients package uses filtering package's static data
// (filtering.BlockedSvcKnown()), so we have to initialize filtering static
// data first, but also to avoid relying on automatic Go init() function.
filtering.InitModule()
err = setupConfig(opts)
err = initContextClients()
fatalOnError(err)
// TODO(e.burkov): This could be made earlier, probably as the option's
err = setupOpts(opts)
fatalOnError(err)
// TODO(e.burkov): This could be made earlier, probably as the option's
// effect.
cmdlineUpdate(opts)
if !Context.firstRun {
// Save the updated config
// Save the updated config.
err = config.write()
fatalOnError(err)
@@ -476,33 +579,15 @@ func run(opts options, clientBuildFS fs.FS) {
}
}
err = os.MkdirAll(Context.getDataDir(), 0o755)
if err != nil {
log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err)
}
dir := Context.getDataDir()
err = os.MkdirAll(dir, 0o755)
fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dir))
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
GLMode = opts.glinetMode
var rateLimiter *authRateLimiter
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
rateLimiter = newAuthRateLimiter(
time.Duration(config.AuthBlockMin)*time.Minute,
config.AuthAttempts,
)
} else {
log.Info("authratelimiter is disabled")
}
Context.auth = InitAuth(
sessFilename,
config.Users,
config.WebSessionTTLHours*60*60,
rateLimiter,
)
if Context.auth == nil {
log.Fatalf("Couldn't initialize Auth module")
}
config.Users = nil
// Init auth module.
Context.auth, err = initUsers()
fatalOnError(err)
Context.tls, err = newTLSManager(config.TLS)
if err != nil {
@@ -520,10 +605,10 @@ func run(opts options, clientBuildFS fs.FS) {
Context.tls.start()
go func() {
serr := startDNSServer()
if serr != nil {
sErr := startDNSServer()
if sErr != nil {
closeDNSServer()
fatalOnError(serr)
fatalOnError(sErr)
}
}()
@@ -537,10 +622,33 @@ func run(opts options, clientBuildFS fs.FS) {
Context.web.start()
// wait indefinitely for other go-routines to complete their job
// Wait indefinitely for other goroutines to complete their job.
select {}
}
// initUsers initializes context auth module. Clears config users field.
func initUsers() (auth *Auth, err error) {
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
var rateLimiter *authRateLimiter
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
blockDur := time.Duration(config.AuthBlockMin) * time.Minute
rateLimiter = newAuthRateLimiter(blockDur, config.AuthAttempts)
} else {
log.Info("authratelimiter is disabled")
}
sessionTTL := config.WebSessionTTLHours * 60 * 60
auth = InitAuth(sessFilename, config.Users, sessionTTL, rateLimiter)
if auth == nil {
return nil, errors.Error("initializing auth module failed")
}
config.Users = nil
return auth, nil
}
func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
var anonFunc aghnet.IPMutFunc
if c.DNS.AnonymizeClientIP {
@@ -613,22 +721,19 @@ func writePIDFile(fn string) bool {
return true
}
// initConfigFilename sets up context config file path. This file path can be
// overridden by command-line arguments, or is set to default.
func initConfigFilename(opts options) {
// config file path can be overridden by command-line arguments:
if opts.confFilename != "" {
Context.configFilename = opts.confFilename
} else {
// Default config file name
Context.configFilename = "AdGuardHome.yaml"
}
Context.configFilename = stringutil.Coalesce(opts.confFilename, "AdGuardHome.yaml")
}
// initWorkingDir initializes the workDir
// if no command-line arguments specified, we use the directory where our binary file is located
func initWorkingDir(opts options) {
// initWorkingDir initializes the workDir. If no command-line arguments are
// specified, the directory with the binary file is used.
func initWorkingDir(opts options) (err error) {
execPath, err := os.Executable()
if err != nil {
panic(err)
// Don't wrap the error, because it's informative enough as is.
return err
}
if opts.workDir != "" {
@@ -640,34 +745,20 @@ func initWorkingDir(opts options) {
workDir, err := filepath.EvalSymlinks(Context.workDir)
if err != nil {
panic(err)
// Don't wrap the error, because it's informative enough as is.
return err
}
Context.workDir = workDir
return nil
}
// configureLogger configures logger level and output
func configureLogger(opts options) {
ls := getLogSettings()
// configureLogger configures logger level and output.
func configureLogger(opts options) (err error) {
ls := getLogSettings(opts)
// command-line arguments can override config settings
if opts.verbose || config.Verbose {
ls.Verbose = true
}
if opts.logFile != "" {
ls.File = opts.logFile
} else if config.File != "" {
ls.File = config.File
}
// Handle default log settings overrides
ls.Compress = config.Compress
ls.LocalTime = config.LocalTime
ls.MaxBackups = config.MaxBackups
ls.MaxSize = config.MaxSize
ls.MaxAge = config.MaxAge
// log.SetLevel(log.INFO) - default
// Configure logger level.
if ls.Verbose {
log.SetLevel(log.DEBUG)
}
@@ -676,38 +767,63 @@ func configureLogger(opts options) {
// happen pretty quickly.
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if nothing
// else is configured. Otherwise, we'll simply lose the log output.
ls.File = configSyslog
}
// logs are written to stdout (default)
// Write logs to stdout by default.
if ls.File == "" {
return
return nil
}
if ls.File == configSyslog {
// Use syslog where it is possible and eventlog on Windows
err := aghos.ConfigureSyslog(serviceName)
// Use syslog where it is possible and eventlog on Windows.
err = aghos.ConfigureSyslog(serviceName)
if err != nil {
log.Fatalf("cannot initialize syslog: %s", err)
}
} else {
logFilePath := ls.File
if !filepath.IsAbs(logFilePath) {
logFilePath = filepath.Join(Context.workDir, logFilePath)
return fmt.Errorf("cannot initialize syslog: %w", err)
}
log.SetOutput(&lumberjack.Logger{
Filename: logFilePath,
Compress: ls.Compress, // disabled by default
LocalTime: ls.LocalTime,
MaxBackups: ls.MaxBackups,
MaxSize: ls.MaxSize, // megabytes
MaxAge: ls.MaxAge, // days
})
return nil
}
logFilePath := ls.File
if !filepath.IsAbs(logFilePath) {
logFilePath = filepath.Join(Context.workDir, logFilePath)
}
log.SetOutput(&lumberjack.Logger{
Filename: logFilePath,
Compress: ls.Compress,
LocalTime: ls.LocalTime,
MaxBackups: ls.MaxBackups,
MaxSize: ls.MaxSize,
MaxAge: ls.MaxAge,
})
return nil
}
// getLogSettings returns a log settings object properly initialized from opts.
func getLogSettings(opts options) (ls *logSettings) {
ls = readLogSettings()
// Command-line arguments can override config settings.
if opts.verbose || config.Verbose {
ls.Verbose = true
}
ls.File = stringutil.Coalesce(opts.logFile, config.File, ls.File)
// Handle default log settings overrides.
ls.Compress = config.Compress
ls.LocalTime = config.LocalTime
ls.MaxBackups = config.MaxBackups
ls.MaxSize = config.MaxSize
ls.MaxAge = config.MaxAge
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if
// nothing else is configured. Otherwise, we'll lose the log output.
ls.File = configSyslog
}
return ls
}
// cleanup stops and resets all the modules.

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"io/fs"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
@@ -84,14 +83,9 @@ func svcStatus(s service.Service) (status service.Status, err error) {
// On OpenWrt, the service utility may not exist. We use our service script
// directly in this case.
func svcAction(s service.Service, action string) (err error) {
if runtime.GOOS == "darwin" && action == "start" {
var exe string
if exe, err = os.Executable(); err != nil {
log.Error("starting service: getting executable path: %s", err)
} else if exe, err = filepath.EvalSymlinks(exe); err != nil {
log.Error("starting service: evaluating executable symlinks: %s", err)
} else if !strings.HasPrefix(exe, "/Applications/") {
log.Info("warning: service must be started from within the /Applications directory")
if action == "start" {
if err = aghos.PreCheckActionStart(); err != nil {
log.Error("starting service: %s", err)
}
}
@@ -99,8 +93,6 @@ func svcAction(s service.Service, action string) (err error) {
if err != nil && service.Platform() == "unix-systemv" &&
(action == "start" || action == "stop" || action == "restart") {
_, err = runInitdCommand(action)
return err
}
return err
@@ -224,6 +216,7 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
runOpts := opts
runOpts.serviceControlAction = "run"
svcConfig := &service.Config{
Name: serviceName,
DisplayName: serviceDisplayName,
@@ -233,35 +226,48 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
}
configureService(svcConfig)
prg := &program{
clientBuildFS: clientBuildFS,
opts: runOpts,
}
var s service.Service
if s, err = service.New(prg, svcConfig); err != nil {
s, err := service.New(&program{clientBuildFS: clientBuildFS, opts: runOpts}, svcConfig)
if err != nil {
log.Fatalf("service: initializing service: %s", err)
}
err = handleServiceCommand(s, action, opts)
if err != nil {
log.Fatalf("service: %s", err)
}
log.Printf(
"service: action %s has been done successfully on %s",
action,
service.ChosenSystem(),
)
}
// handleServiceCommand handles service command.
func handleServiceCommand(s service.Service, action string, opts options) (err error) {
switch action {
case "status":
handleServiceStatusCommand(s)
case "run":
if err = s.Run(); err != nil {
log.Fatalf("service: failed to run service: %s", err)
return fmt.Errorf("failed to run service: %w", err)
}
case "install":
initConfigFilename(opts)
initWorkingDir(opts)
if err = initWorkingDir(opts); err != nil {
return fmt.Errorf("failed to init working dir: %w", err)
}
handleServiceInstallCommand(s)
case "uninstall":
handleServiceUninstallCommand(s)
default:
if err = svcAction(s, action); err != nil {
log.Fatalf("service: executing action %q: %s", action, err)
return fmt.Errorf("executing action %q: %w", action, err)
}
}
log.Printf("service: action %s has been done successfully on %s", action, service.ChosenSystem())
return nil
}
// handleServiceStatusCommand handles service "status" command.

View File

@@ -172,9 +172,32 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
}
}()
tlsConf.CertificateChainData = []byte(tlsConf.CertificateChain)
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
err = loadCertificateChainData(tlsConf, status)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
err = loadPrivateKeyData(tlsConf, status)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
err = validateCertificates(
status,
tlsConf.CertificateChainData,
tlsConf.PrivateKeyData,
tlsConf.ServerName,
)
return errors.Annotate(err, "validating certificate pair: %w")
}
// loadCertificateChainData loads PEM-encoded certificates chain data to the
// TLS configuration.
func loadCertificateChainData(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) {
tlsConf.CertificateChainData = []byte(tlsConf.CertificateChain)
if tlsConf.CertificatePath != "" {
if tlsConf.CertificateChain != "" {
return errors.Error("certificate data and file can't be set together")
@@ -190,6 +213,13 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
status.ValidCert = true
}
return nil
}
// loadPrivateKeyData loads PEM-encoded private key data to the TLS
// configuration.
func loadPrivateKeyData(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) {
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
if tlsConf.PrivateKeyPath != "" {
if tlsConf.PrivateKey != "" {
return errors.Error("private key data and file can't be set together")
@@ -203,16 +233,6 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
status.ValidKey = true
}
err = validateCertificates(
status,
tlsConf.CertificateChainData,
tlsConf.PrivateKeyData,
tlsConf.ServerName,
)
if err != nil {
return fmt.Errorf("validating certificate pair: %w", err)
}
return nil
}

View File

@@ -41,7 +41,8 @@ func upgradeConfig() error {
err = yaml.Unmarshal(body, &diskConf)
if err != nil {
log.Printf("Couldn't parse config file: %s", err)
log.Printf("parsing config file for upgrade: %s", err)
return err
}
@@ -293,71 +294,61 @@ func upgradeSchema4to5(diskConf yobj) error {
return nil
}
// clients:
// ...
// upgradeSchema5to6 performs the following changes:
//
// ip: 127.0.0.1
// mac: ...
// # BEFORE:
// 'clients':
// ...
// 'ip': 127.0.0.1
// 'mac': ...
//
// ->
//
// clients:
// ...
//
// ids:
// - 127.0.0.1
// - ...
// # AFTER:
// 'clients':
// ...
// 'ids':
// - 127.0.0.1
// - ...
func upgradeSchema5to6(diskConf yobj) error {
log.Printf("%s(): called", funcName())
log.Printf("Upgrade yaml: 5 to 6")
diskConf["schema_version"] = 6
clients, ok := diskConf["clients"]
clientsVal, ok := diskConf["clients"]
if !ok {
return nil
}
switch arr := clients.(type) {
case []any:
for i := range arr {
switch c := arr[i].(type) {
case map[any]any:
var ipVal any
ipVal, ok = c["ip"]
ids := []string{}
if ok {
var ip string
ip, ok = ipVal.(string)
if !ok {
log.Fatalf("client.ip is not a string: %v", ipVal)
return nil
}
if len(ip) != 0 {
ids = append(ids, ip)
}
}
clients, ok := clientsVal.([]yobj)
if !ok {
return fmt.Errorf("unexpected type of clients: %T", clientsVal)
}
var macVal any
macVal, ok = c["mac"]
if ok {
var mac string
mac, ok = macVal.(string)
if !ok {
log.Fatalf("client.mac is not a string: %v", macVal)
return nil
}
if len(mac) != 0 {
ids = append(ids, mac)
}
}
for i := range clients {
c := clients[i]
var ids []string
c["ids"] = ids
default:
continue
if ipVal, hasIP := c["ip"]; hasIP {
var ip string
if ip, ok = ipVal.(string); !ok {
return fmt.Errorf("client.ip is not a string: %v", ipVal)
}
if ip != "" {
ids = append(ids, ip)
}
}
default:
return nil
if macVal, hasMac := c["mac"]; hasMac {
var mac string
if mac, ok = macVal.(string); !ok {
return fmt.Errorf("client.mac is not a string: %v", macVal)
}
if mac != "" {
ids = append(ids, mac)
}
}
c["ids"] = ids
}
return nil

View File

@@ -68,6 +68,95 @@ func TestUpgradeSchema2to3(t *testing.T) {
assertEqualExcept(t, oldDiskConf, diskConf, excludedEntries, excludedEntries)
}
func TestUpgradeSchema5to6(t *testing.T) {
const newSchemaVer = 6
testCases := []struct {
in yobj
want yobj
wantErr string
name string
}{{
in: yobj{
"clients": []yobj{},
},
want: yobj{
"clients": []yobj{},
"schema_version": newSchemaVer,
},
wantErr: "",
name: "no_clients",
}, {
in: yobj{
"clients": []yobj{{"ip": "127.0.0.1"}},
},
want: yobj{
"clients": []yobj{{
"ids": []string{"127.0.0.1"},
"ip": "127.0.0.1",
}},
"schema_version": newSchemaVer,
},
wantErr: "",
name: "client_ip",
}, {
in: yobj{
"clients": []yobj{{"mac": "mac"}},
},
want: yobj{
"clients": []yobj{{
"ids": []string{"mac"},
"mac": "mac",
}},
"schema_version": newSchemaVer,
},
wantErr: "",
name: "client_mac",
}, {
in: yobj{
"clients": []yobj{{"ip": "127.0.0.1", "mac": "mac"}},
},
want: yobj{
"clients": []yobj{{
"ids": []string{"127.0.0.1", "mac"},
"ip": "127.0.0.1",
"mac": "mac",
}},
"schema_version": newSchemaVer,
},
wantErr: "",
name: "client_ip_mac",
}, {
in: yobj{
"clients": []yobj{{"ip": 1, "mac": "mac"}},
},
want: yobj{
"clients": []yobj{{"ip": 1, "mac": "mac"}},
"schema_version": newSchemaVer,
},
wantErr: "client.ip is not a string: 1",
name: "inv_client_ip",
}, {
in: yobj{
"clients": []yobj{{"ip": "127.0.0.1", "mac": 1}},
},
want: yobj{
"clients": []yobj{{"ip": "127.0.0.1", "mac": 1}},
"schema_version": newSchemaVer,
},
wantErr: "client.mac is not a string: 1",
name: "inv_client_mac",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema5to6(tc.in)
testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equal(t, tc.want, tc.in)
})
}
}
func TestUpgradeSchema7to8(t *testing.T) {
const host = "1.2.3.4"
oldConf := yobj{