Pull request 2405: AGDNS-2374-updater-slog
Squashed commit of the following:
commit 89c3df471964b674b7ddafeb22566e5be9b56a13
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Mon May 12 18:59:39 2025 +0300
updater: imp log
commit d78ba4368027ddcbb41c10fbf09d43fe0721dc4c
Merge: 68410954c 187b759fc
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Mon May 12 18:53:33 2025 +0300
Merge branch 'master' into AGDNS-2374-updater-slog
commit 68410954c80d76b2adafe4ed28fafdd6b6b6daae
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Wed Apr 30 15:54:30 2025 +0300
updater: imp docs
commit 99a705218fb849bb59dee5b801c5279a501bcf98
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Wed Apr 30 15:40:30 2025 +0300
updater: imp docs, logs
commit 2a83ee3ebf9610a2703d99ec6a6b327a315f6cce
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date: Tue Apr 29 21:01:02 2025 +0300
updater: use slog
This commit is contained in:
@@ -82,7 +82,7 @@ func (web *webAPI) requestVersionInfo(
|
||||
) (err error) {
|
||||
updater := web.conf.updater
|
||||
for range 3 {
|
||||
resp.VersionInfo, err = updater.VersionInfo(recheck)
|
||||
resp.VersionInfo, err = updater.VersionInfo(ctx, recheck)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -133,7 +133,7 @@ func (web *webAPI) handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = updater.Update(false)
|
||||
err = updater.Update(r.Context(), false)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
|
||||
@@ -487,9 +487,14 @@ func checkPorts() (err error) {
|
||||
}
|
||||
|
||||
// isUpdateEnabled returns true if the update is enabled for current
|
||||
// configuration. It also logs the decision. customURL should be true if the
|
||||
// configuration. It also logs the decision. isCustomURL should be true if the
|
||||
// updater is using a custom URL.
|
||||
func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customURL bool) (ok bool) {
|
||||
func isUpdateEnabled(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
opts *options,
|
||||
isCustomURL bool,
|
||||
) (ok bool) {
|
||||
if opts.disableUpdate {
|
||||
l.DebugContext(ctx, "updates are disabled by command-line option")
|
||||
|
||||
@@ -500,13 +505,13 @@ func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customU
|
||||
case
|
||||
version.ChannelDevelopment,
|
||||
version.ChannelCandidate:
|
||||
if customURL {
|
||||
if isCustomURL {
|
||||
l.DebugContext(ctx, "updates are enabled because custom url is used")
|
||||
} else {
|
||||
l.DebugContext(ctx, "updates are disabled for development and candidate builds")
|
||||
}
|
||||
|
||||
return customURL
|
||||
return isCustomURL
|
||||
default:
|
||||
l.DebugContext(ctx, "updates are enabled")
|
||||
|
||||
@@ -514,7 +519,7 @@ func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customU
|
||||
}
|
||||
}
|
||||
|
||||
// initWeb initializes the web module. upd, baseLogger, and tlsMgr must not be
|
||||
// initWeb initializes the web module. upd, baseLogger, and tlsMgr must not be
|
||||
// nil.
|
||||
func initWeb(
|
||||
ctx context.Context,
|
||||
@@ -523,7 +528,7 @@ func initWeb(
|
||||
upd *updater.Updater,
|
||||
baseLogger *slog.Logger,
|
||||
tlsMgr *tlsManager,
|
||||
customURL bool,
|
||||
isCustomUpdURL bool,
|
||||
) (web *webAPI, err error) {
|
||||
logger := baseLogger.With(slogutil.KeyPrefix, "webapi")
|
||||
|
||||
@@ -539,7 +544,7 @@ func initWeb(
|
||||
}
|
||||
}
|
||||
|
||||
disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, customURL)
|
||||
disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, isCustomUpdURL)
|
||||
|
||||
webConf := &webConfig{
|
||||
updater: upd,
|
||||
@@ -645,11 +650,12 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
|
||||
|
||||
confPath := configFilePath()
|
||||
|
||||
upd, customURL := newUpdater(ctx, slogLogger, globalContext.workDir, confPath, execPath, config)
|
||||
updLogger := slogLogger.With(slogutil.KeyPrefix, "updater")
|
||||
upd, isCustomURL := newUpdater(ctx, updLogger, config, globalContext.workDir, confPath, execPath)
|
||||
|
||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||
// effect.
|
||||
cmdlineUpdate(ctx, slogLogger, opts, upd, tlsMgr)
|
||||
cmdlineUpdate(ctx, updLogger, opts, upd, tlsMgr)
|
||||
|
||||
if !globalContext.firstRun {
|
||||
// Save the updated config.
|
||||
@@ -671,7 +677,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
|
||||
globalContext.auth, err = initUsers()
|
||||
fatalOnError(err)
|
||||
|
||||
web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
|
||||
web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, isCustomURL)
|
||||
fatalOnError(err)
|
||||
|
||||
globalContext.web = web
|
||||
@@ -714,16 +720,17 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
|
||||
<-done
|
||||
}
|
||||
|
||||
// newUpdater creates a new AdGuard Home updater. customURL is true if the user
|
||||
// has specified a custom version announcement URL.
|
||||
// newUpdater creates a new AdGuard Home updater. l and conf must not be nil.
|
||||
// workDir, confPath, and execPath must not be empty. isCustomURL is true if
|
||||
// the user has specified a custom version announcement URL.
|
||||
func newUpdater(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
conf *configuration,
|
||||
workDir string,
|
||||
confPath string,
|
||||
execPath string,
|
||||
config *configuration,
|
||||
) (upd *updater.Updater, customURL bool) {
|
||||
) (upd *updater.Updater, isCustomURL bool) {
|
||||
// envName is the name of the environment variable that can be used to
|
||||
// override the default version check URL.
|
||||
const envName = "ADGUARD_HOME_TEST_UPDATE_VERSION_URL"
|
||||
@@ -735,14 +742,14 @@ func newUpdater(
|
||||
case version.Channel() == version.ChannelRelease:
|
||||
// Only enable custom version URL for development builds.
|
||||
l.DebugContext(ctx, "custom version url is disabled for release builds")
|
||||
case !config.UnsafeUseCustomUpdateIndexURL:
|
||||
case !conf.UnsafeUseCustomUpdateIndexURL:
|
||||
l.DebugContext(ctx, "custom version url is disabled in config")
|
||||
default:
|
||||
versionURL, _ = url.Parse(customURLStr)
|
||||
}
|
||||
|
||||
err := urlutil.ValidateHTTPURL(versionURL)
|
||||
if customURL = err == nil; !customURL {
|
||||
if isCustomURL = err == nil; !isCustomURL {
|
||||
l.DebugContext(ctx, "parsing custom version url", slogutil.KeyError, err)
|
||||
|
||||
versionURL = updater.DefaultVersionURL()
|
||||
@@ -751,7 +758,8 @@ func newUpdater(
|
||||
l.DebugContext(ctx, "creating updater", "config_path", confPath)
|
||||
|
||||
return updater.NewUpdater(&updater.Config{
|
||||
Client: config.Filtering.HTTPClient,
|
||||
Client: conf.Filtering.HTTPClient,
|
||||
Logger: l,
|
||||
Version: version.Version(),
|
||||
Channel: version.Channel(),
|
||||
GOARCH: runtime.GOARCH,
|
||||
@@ -762,7 +770,7 @@ func newUpdater(
|
||||
ConfName: confPath,
|
||||
ExecPath: execPath,
|
||||
VersionCheckURL: versionURL,
|
||||
}), customURL
|
||||
}), isCustomURL
|
||||
}
|
||||
|
||||
// checkPermissions checks and migrates permissions of the files and directories
|
||||
@@ -1083,7 +1091,7 @@ func cmdlineUpdate(
|
||||
|
||||
l.InfoContext(ctx, "performing update via cli")
|
||||
|
||||
info, err := upd.VersionInfo(true)
|
||||
info, err := upd.VersionInfo(ctx, true)
|
||||
if err != nil {
|
||||
l.ErrorContext(ctx, "getting version info", slogutil.KeyError, err)
|
||||
|
||||
@@ -1096,7 +1104,7 @@ func cmdlineUpdate(
|
||||
os.Exit(osutil.ExitCodeSuccess)
|
||||
}
|
||||
|
||||
err = upd.Update(globalContext.firstRun)
|
||||
err = upd.Update(ctx, globalContext.firstRun)
|
||||
fatalOnError(err)
|
||||
|
||||
err = restartService()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -12,7 +13,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/c2h5oh/datasize"
|
||||
)
|
||||
|
||||
@@ -35,7 +35,7 @@ const maxVersionRespSize datasize.ByteSize = 64 * datasize.KB
|
||||
|
||||
// VersionInfo downloads the latest version information. If forceRecheck is
|
||||
// false and there are cached results, those results are returned.
|
||||
func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
|
||||
func (u *Updater) VersionInfo(ctx context.Context, forceRecheck bool) (vi VersionInfo, err error) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
@@ -45,11 +45,17 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
|
||||
return u.prevCheckResult, u.prevCheckError
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
vcu := u.versionCheckURL
|
||||
resp, err = u.client.Get(vcu)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, vcu, nil)
|
||||
if err != nil {
|
||||
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
|
||||
return VersionInfo{}, fmt.Errorf("constructing request to %s: %w", vcu, err)
|
||||
}
|
||||
|
||||
u.logger.DebugContext(ctx, "requesting version data", "url", vcu)
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return VersionInfo{}, fmt.Errorf("requesting %s: %w", vcu, err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||
|
||||
@@ -59,16 +65,16 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
|
||||
// ReadCloser.
|
||||
body, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
|
||||
return VersionInfo{}, fmt.Errorf("reading response from %s: %w", vcu, err)
|
||||
}
|
||||
|
||||
u.prevCheckTime = now
|
||||
u.prevCheckResult, u.prevCheckError = u.parseVersionResponse(body)
|
||||
u.prevCheckResult, u.prevCheckError = u.parseVersionResponse(ctx, body)
|
||||
|
||||
return u.prevCheckResult, u.prevCheckError
|
||||
}
|
||||
|
||||
func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) {
|
||||
func (u *Updater) parseVersionResponse(ctx context.Context, data []byte) (VersionInfo, error) {
|
||||
info := VersionInfo{
|
||||
CanAutoUpdate: aghalg.NBFalse,
|
||||
}
|
||||
@@ -92,7 +98,7 @@ func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) {
|
||||
info.Announcement = versionJSON["announcement"]
|
||||
info.AnnouncementURL = versionJSON["announcement_url"]
|
||||
|
||||
packageURL, key, found := u.downloadURL(versionJSON)
|
||||
packageURL, key, found := u.downloadURL(ctx, versionJSON)
|
||||
if !found {
|
||||
return info, fmt.Errorf("version.json: no package URL: key %q not found in object", key)
|
||||
}
|
||||
@@ -108,7 +114,10 @@ func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) {
|
||||
// downloadURL returns the download URL for current build as well as its key in
|
||||
// versionObj. If the key is not found, it additionally prints an informative
|
||||
// log message.
|
||||
func (u *Updater) downloadURL(versionObj map[string]string) (dlURL, key string, ok bool) {
|
||||
func (u *Updater) downloadURL(
|
||||
ctx context.Context,
|
||||
versionObj map[string]string,
|
||||
) (dlURL, key string, ok bool) {
|
||||
if u.goarch == "arm" && u.goarm != "" {
|
||||
key = fmt.Sprintf("download_%s_%sv%s", u.goos, u.goarch, u.goarm)
|
||||
} else if isMIPS(u.goarch) && u.gomips != "" {
|
||||
@@ -124,7 +133,7 @@ func (u *Updater) downloadURL(versionObj map[string]string) (dlURL, key string,
|
||||
|
||||
keys := slices.Sorted(maps.Keys(versionObj))
|
||||
|
||||
log.Error("updater: key %q not found; got keys %q", key, keys)
|
||||
u.logger.ErrorContext(ctx, "key not found", "missing", key, "got", keys)
|
||||
|
||||
return "", key, false
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -58,6 +59,7 @@ func TestUpdater_VersionInfo(t *testing.T) {
|
||||
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
Client: srv.Client(),
|
||||
Logger: testLogger,
|
||||
Version: "v0.103.0-beta.1",
|
||||
Channel: version.ChannelBeta,
|
||||
GOARCH: "arm",
|
||||
@@ -65,7 +67,8 @@ func TestUpdater_VersionInfo(t *testing.T) {
|
||||
VersionCheckURL: fakeURL,
|
||||
})
|
||||
|
||||
info, err := u.VersionInfo(false)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
info, err := u.VersionInfo(ctx, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, counter, 1)
|
||||
@@ -75,14 +78,14 @@ func TestUpdater_VersionInfo(t *testing.T) {
|
||||
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
|
||||
|
||||
t.Run("cache_check", func(t *testing.T) {
|
||||
_, err = u.VersionInfo(false)
|
||||
_, err = u.VersionInfo(testutil.ContextWithTimeout(t, testTimeout), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, counter, 1)
|
||||
})
|
||||
|
||||
t.Run("force_check", func(t *testing.T) {
|
||||
_, err = u.VersionInfo(true)
|
||||
_, err = u.VersionInfo(testutil.ContextWithTimeout(t, testTimeout), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, counter, 2)
|
||||
@@ -91,7 +94,7 @@ func TestUpdater_VersionInfo(t *testing.T) {
|
||||
t.Run("api_fail", func(t *testing.T) {
|
||||
srv.Close()
|
||||
|
||||
_, err = u.VersionInfo(true)
|
||||
_, err = u.VersionInfo(testutil.ContextWithTimeout(t, testTimeout), true)
|
||||
var urlErr *url.Error
|
||||
assert.ErrorAs(t, err, &urlErr)
|
||||
})
|
||||
@@ -130,6 +133,7 @@ func TestUpdater_VersionInfo_others(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
Client: fakeClient,
|
||||
Logger: testLogger,
|
||||
Version: "v0.103.0-beta.1",
|
||||
Channel: version.ChannelBeta,
|
||||
GOOS: "linux",
|
||||
@@ -139,7 +143,8 @@ func TestUpdater_VersionInfo_others(t *testing.T) {
|
||||
VersionCheckURL: fakeURL,
|
||||
})
|
||||
|
||||
info, err := u.VersionInfo(false)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
info, err := u.VersionInfo(ctx, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -22,13 +24,14 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
)
|
||||
|
||||
// Updater is the AdGuard Home updater.
|
||||
type Updater struct {
|
||||
client *http.Client
|
||||
logger *slog.Logger
|
||||
|
||||
version string
|
||||
channel string
|
||||
@@ -75,27 +78,48 @@ func DefaultVersionURL() *url.URL {
|
||||
|
||||
// Config is the AdGuard Home updater configuration.
|
||||
type Config struct {
|
||||
// Client is used to perform HTTP requests. It must not be nil.
|
||||
Client *http.Client
|
||||
|
||||
// Logger is used for logging the update process. It must not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// VersionCheckURL is URL to the latest version announcement. It must not
|
||||
// be nil, see [DefaultVersionURL].
|
||||
VersionCheckURL *url.URL
|
||||
|
||||
// Version is the current AdGuard Home version. It must not be empty.
|
||||
Version string
|
||||
Channel string
|
||||
GOARCH string
|
||||
GOOS string
|
||||
GOARM string
|
||||
GOMIPS string
|
||||
|
||||
// ConfName is the name of the current configuration file. Typically,
|
||||
// "AdGuardHome.yaml".
|
||||
// Channel is the current AdGuard Home update channel. It must be a valid
|
||||
// channel, see [version.ChannelBeta] and the related constants.
|
||||
Channel string
|
||||
|
||||
// GOARCH is the current CPU architecture. It must not be empty and must be
|
||||
// one of the supported architectures.
|
||||
GOARCH string
|
||||
|
||||
// GOOS is the current operating system. It must not be empty and must be
|
||||
// one of the supported OSs.
|
||||
GOOS string
|
||||
|
||||
// GOARM is the current ARM variant, if any. It must either be empty or be
|
||||
// a valid and supported GOARM value.
|
||||
GOARM string
|
||||
|
||||
// GOMIPS is the current MIPS variant, if any. It must either be empty or
|
||||
// be a valid and supported GOMIPS value.
|
||||
GOMIPS string
|
||||
|
||||
// ConfName is the name of the current configuration file. It must not be
|
||||
// empty.
|
||||
ConfName string
|
||||
|
||||
// WorkDir is the working directory that is used for temporary files.
|
||||
// WorkDir is the working directory that is used for temporary files. It
|
||||
// must not be empty.
|
||||
WorkDir string
|
||||
|
||||
// ExecPath is path to the executable file.
|
||||
// ExecPath is path to the executable file. It must not be empty.
|
||||
ExecPath string
|
||||
}
|
||||
|
||||
@@ -103,6 +127,7 @@ type Config struct {
|
||||
func NewUpdater(conf *Config) *Updater {
|
||||
return &Updater{
|
||||
client: conf.Client,
|
||||
logger: conf.Logger,
|
||||
|
||||
version: conf.Version,
|
||||
channel: conf.Channel,
|
||||
@@ -122,49 +147,49 @@ func NewUpdater(conf *Config) *Updater {
|
||||
|
||||
// Update performs the auto-update. It returns an error if the update failed.
|
||||
// If firstRun is true, it assumes the configuration file doesn't exist.
|
||||
func (u *Updater) Update(firstRun bool) (err error) {
|
||||
func (u *Updater) Update(ctx context.Context, firstRun bool) (err error) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
log.Info("updater: updating")
|
||||
u.logger.InfoContext(ctx, "staring update", "first_run", firstRun)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.Info("updater: failed")
|
||||
u.logger.ErrorContext(ctx, "update failed", slogutil.KeyError, err)
|
||||
} else {
|
||||
log.Info("updater: finished successfully")
|
||||
u.logger.InfoContext(ctx, "update finished")
|
||||
}
|
||||
}()
|
||||
|
||||
err = u.prepare()
|
||||
err = u.prepare(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing: %w", err)
|
||||
}
|
||||
|
||||
defer u.clean()
|
||||
defer u.clean(ctx)
|
||||
|
||||
err = u.downloadPackageFile()
|
||||
err = u.downloadPackageFile(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("downloading package file: %w", err)
|
||||
}
|
||||
|
||||
err = u.unpack()
|
||||
err = u.unpack(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unpacking: %w", err)
|
||||
}
|
||||
|
||||
if !firstRun {
|
||||
err = u.check()
|
||||
err = u.check(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = u.backup(firstRun)
|
||||
err = u.backup(ctx, firstRun)
|
||||
if err != nil {
|
||||
return fmt.Errorf("making backup: %w", err)
|
||||
}
|
||||
|
||||
err = u.replace()
|
||||
err = u.replace(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing: %w", err)
|
||||
}
|
||||
@@ -181,7 +206,7 @@ func (u *Updater) NewVersion() (nv string) {
|
||||
}
|
||||
|
||||
// prepare fills all necessary fields in Updater object.
|
||||
func (u *Updater) prepare() (err error) {
|
||||
func (u *Updater) prepare(ctx context.Context) (err error) {
|
||||
u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion))
|
||||
|
||||
_, pkgNameOnly := filepath.Split(u.packageURL)
|
||||
@@ -200,11 +225,12 @@ func (u *Updater) prepare() (err error) {
|
||||
u.backupExeName = filepath.Join(u.backupDir, filepath.Base(u.execPath))
|
||||
u.updateExeName = filepath.Join(u.updateDir, updateExeName)
|
||||
|
||||
log.Debug(
|
||||
"updater: updating from %s to %s using url: %s",
|
||||
version.Version(),
|
||||
u.newVersion,
|
||||
u.packageURL,
|
||||
u.logger.InfoContext(
|
||||
ctx,
|
||||
"updating",
|
||||
"from", version.Version(),
|
||||
"to", u.newVersion,
|
||||
"package_url", u.packageURL,
|
||||
)
|
||||
|
||||
u.currentExeName = u.execPath
|
||||
@@ -217,23 +243,20 @@ func (u *Updater) prepare() (err error) {
|
||||
}
|
||||
|
||||
// unpack extracts the files from the downloaded archive.
|
||||
func (u *Updater) unpack() error {
|
||||
var err error
|
||||
func (u *Updater) unpack(ctx context.Context) (err error) {
|
||||
_, pkgNameOnly := filepath.Split(u.packageURL)
|
||||
|
||||
log.Debug("updater: unpacking package")
|
||||
u.logger.InfoContext(ctx, "unpacking package", "package_name", pkgNameOnly)
|
||||
if strings.HasSuffix(pkgNameOnly, ".zip") {
|
||||
u.unpackedFiles, err = zipFileUnpack(u.packageName, u.updateDir)
|
||||
u.unpackedFiles, err = u.unpackZip(ctx, u.packageName, u.updateDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf(".zip unpack failed: %w", err)
|
||||
}
|
||||
|
||||
} else if strings.HasSuffix(pkgNameOnly, ".tar.gz") {
|
||||
u.unpackedFiles, err = tarGzFileUnpack(u.packageName, u.updateDir)
|
||||
u.unpackedFiles, err = u.unpackTarGz(ctx, u.packageName, u.updateDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf(".tar.gz unpack failed: %w", err)
|
||||
}
|
||||
|
||||
} else {
|
||||
return fmt.Errorf("unknown package extension")
|
||||
}
|
||||
@@ -243,8 +266,8 @@ func (u *Updater) unpack() error {
|
||||
|
||||
// check returns an error if the configuration file couldn't be used with the
|
||||
// version of AdGuard Home just downloaded.
|
||||
func (u *Updater) check() (err error) {
|
||||
log.Debug("updater: checking configuration")
|
||||
func (u *Updater) check(ctx context.Context) (err error) {
|
||||
u.logger.InfoContext(ctx, "checking configuration")
|
||||
|
||||
err = copyFile(u.confName, filepath.Join(u.updateDir, "AdGuardHome.yaml"), aghos.DefaultPermFile)
|
||||
if err != nil {
|
||||
@@ -268,8 +291,9 @@ func (u *Updater) check() (err error) {
|
||||
|
||||
// backup makes a backup of the current configuration and supporting files. It
|
||||
// ignores the configuration file if firstRun is true.
|
||||
func (u *Updater) backup(firstRun bool) (err error) {
|
||||
log.Debug("updater: backing up current configuration")
|
||||
func (u *Updater) backup(ctx context.Context, firstRun bool) (err error) {
|
||||
u.logger.InfoContext(ctx, "backing up current configuration")
|
||||
|
||||
_ = os.Mkdir(u.backupDir, aghos.DefaultPermDir)
|
||||
if !firstRun {
|
||||
err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"), aghos.DefaultPermFile)
|
||||
@@ -279,7 +303,7 @@ func (u *Updater) backup(firstRun bool) (err error) {
|
||||
}
|
||||
|
||||
wd := u.workDir
|
||||
err = copySupportingFiles(u.unpackedFiles, wd, u.backupDir)
|
||||
err = u.copySupportingFiles(ctx, u.unpackedFiles, wd, u.backupDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %w", wd, u.backupDir, err)
|
||||
}
|
||||
@@ -289,13 +313,18 @@ func (u *Updater) backup(firstRun bool) (err error) {
|
||||
|
||||
// replace moves the current executable with the updated one and also copies the
|
||||
// supporting files.
|
||||
func (u *Updater) replace() error {
|
||||
err := copySupportingFiles(u.unpackedFiles, u.updateDir, u.workDir)
|
||||
func (u *Updater) replace(ctx context.Context) (err error) {
|
||||
err = u.copySupportingFiles(ctx, u.unpackedFiles, u.updateDir, u.workDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %w", u.updateDir, u.workDir, err)
|
||||
}
|
||||
|
||||
log.Debug("updater: renaming: %s to %s", u.currentExeName, u.backupExeName)
|
||||
u.logger.InfoContext(
|
||||
ctx,
|
||||
"backing up current executable",
|
||||
"from", u.currentExeName,
|
||||
"to", u.backupExeName,
|
||||
)
|
||||
err = os.Rename(u.currentExeName, u.backupExeName)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -311,14 +340,22 @@ func (u *Updater) replace() error {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("updater: renamed: %s to %s", u.updateExeName, u.currentExeName)
|
||||
u.logger.InfoContext(
|
||||
ctx,
|
||||
"replacing current executable",
|
||||
"from", u.updateExeName,
|
||||
"to", u.currentExeName,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clean removes the temporary directory itself and all it's contents.
|
||||
func (u *Updater) clean() {
|
||||
_ = os.RemoveAll(u.updateDir)
|
||||
func (u *Updater) clean(ctx context.Context) {
|
||||
err := os.RemoveAll(u.updateDir)
|
||||
if err != nil {
|
||||
u.logger.WarnContext(ctx, "removing update dir", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
// MaxPackageFileSize is a maximum package file length in bytes. The largest
|
||||
@@ -327,34 +364,52 @@ func (u *Updater) clean() {
|
||||
const MaxPackageFileSize = 32 * 1024 * 1024
|
||||
|
||||
// Download package file and save it to disk
|
||||
func (u *Updater) downloadPackageFile() (err error) {
|
||||
var resp *http.Response
|
||||
resp, err = u.client.Get(u.packageURL)
|
||||
func (u *Updater) downloadPackageFile(ctx context.Context) (err error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.packageURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http request failed: %w", err)
|
||||
return fmt.Errorf("constructing package request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("requesting package: %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||
|
||||
r := ioutil.LimitReader(resp.Body, MaxPackageFileSize)
|
||||
|
||||
log.Debug("updater: reading http body")
|
||||
u.logger.InfoContext(ctx, "reading http body")
|
||||
|
||||
// This use of ReadAll is now safe, because we limited body's Reader.
|
||||
body, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("io.ReadAll() failed: %w", err)
|
||||
}
|
||||
|
||||
_ = os.Mkdir(u.updateDir, aghos.DefaultPermDir)
|
||||
err = os.Mkdir(u.updateDir, aghos.DefaultPermDir)
|
||||
if err != nil {
|
||||
// TODO(a.garipov): Consider returning this error.
|
||||
u.logger.WarnContext(ctx, "creating update dir", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
u.logger.InfoContext(ctx, "saving package", "to", u.packageName)
|
||||
|
||||
log.Debug("updater: saving package to file")
|
||||
err = os.WriteFile(u.packageName, body, aghos.DefaultPermFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing package file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name string, err error) {
|
||||
// unpackTarGzFile unpacks one file from a .tar.gz archive into outDir. All
|
||||
// arguments must not be empty.
|
||||
func (u *Updater) unpackTarGzFile(
|
||||
ctx context.Context,
|
||||
outDir string,
|
||||
tr *tar.Reader,
|
||||
hdr *tar.Header,
|
||||
) (name string, err error) {
|
||||
name = filepath.Base(hdr.Name)
|
||||
if name == "" {
|
||||
return "", nil
|
||||
@@ -377,13 +432,18 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
return "", fmt.Errorf("creating directory %q: %w", outName, err)
|
||||
}
|
||||
|
||||
log.Debug("updater: created directory %q", outName)
|
||||
u.logger.InfoContext(ctx, "created directory", "name", outName)
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if hdr.Typeflag != tar.TypeReg {
|
||||
log.Info("updater: %s: unknown file type %d, skipping", name, hdr.Typeflag)
|
||||
u.logger.WarnContext(
|
||||
ctx,
|
||||
"unknown file type; skipping",
|
||||
"file_name", name,
|
||||
"type", hdr.Typeflag,
|
||||
)
|
||||
|
||||
return "", nil
|
||||
}
|
||||
@@ -400,16 +460,19 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
return "", fmt.Errorf("io.Copy(): %w", err)
|
||||
}
|
||||
|
||||
log.Debug("updater: created file %q", outName)
|
||||
u.logger.InfoContext(ctx, "created file", "name", outName)
|
||||
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// Unpack all files from .tar.gz file to the specified directory
|
||||
// Existing files are overwritten
|
||||
// All files are created inside outDir, subdirectories are not created
|
||||
// Return the list of files (not directories) written
|
||||
func tarGzFileUnpack(tarfile, outDir string) (files []string, err error) {
|
||||
// unpackTarGz unpack all files from a .tar.gz archive to outDir. Existing
|
||||
// files are overwritten. All files are created inside outDir. files are the
|
||||
// list of created files.
|
||||
func (u *Updater) unpackTarGz(
|
||||
ctx context.Context,
|
||||
tarfile string,
|
||||
outDir string,
|
||||
) (files []string, err error) {
|
||||
f, err := os.Open(tarfile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("os.Open(): %w", err)
|
||||
@@ -437,7 +500,7 @@ func tarGzFileUnpack(tarfile, outDir string) (files []string, err error) {
|
||||
}
|
||||
|
||||
var name string
|
||||
name, err = tarGzFileUnpackOne(outDir, tarReader, hdr)
|
||||
name, err = u.unpackTarGzFile(ctx, outDir, tarReader, hdr)
|
||||
|
||||
if name != "" {
|
||||
files = append(files, name)
|
||||
@@ -447,7 +510,13 @@ func tarGzFileUnpack(tarfile, outDir string) (files []string, err error) {
|
||||
return files, err
|
||||
}
|
||||
|
||||
func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||
// unpackZipFile unpacks one file from a .zip archive into outDir. All
|
||||
// arguments must not be empty.
|
||||
func (u *Updater) unpackZipFile(
|
||||
ctx context.Context,
|
||||
outDir string,
|
||||
zf *zip.File,
|
||||
) (name string, err error) {
|
||||
var rc io.ReadCloser
|
||||
rc, err = zf.Open()
|
||||
if err != nil {
|
||||
@@ -466,7 +535,8 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||
if name == "AdGuardHome" {
|
||||
// Top-level AdGuardHome/. Skip it.
|
||||
//
|
||||
// TODO(a.garipov): See the similar todo in tarGzFileUnpack.
|
||||
// TODO(a.garipov): See the similar TODO in
|
||||
// [Updater.unpackTarGzFile].
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -475,7 +545,7 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||
return "", fmt.Errorf("creating directory %q: %w", outputName, err)
|
||||
}
|
||||
|
||||
log.Debug("updater: created directory %q", outputName)
|
||||
u.logger.InfoContext(ctx, "created directory", "name", outputName)
|
||||
|
||||
return "", nil
|
||||
}
|
||||
@@ -492,16 +562,19 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||
return "", fmt.Errorf("io.Copy(): %w", err)
|
||||
}
|
||||
|
||||
log.Debug("updater: created file %q", outputName)
|
||||
u.logger.InfoContext(ctx, "created file", "name", outputName)
|
||||
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// Unpack all files from .zip file to the specified directory
|
||||
// Existing files are overwritten
|
||||
// All files are created inside 'outDir', subdirectories are not created
|
||||
// Return the list of files (not directories) written
|
||||
func zipFileUnpack(zipfile, outDir string) (files []string, err error) {
|
||||
// unpackZip unpack all files from a .zip archive to outDir. Existing files are
|
||||
// overwritten. All files are created inside outDir. files are the list of
|
||||
// created files.
|
||||
func (u *Updater) unpackZip(
|
||||
ctx context.Context,
|
||||
zipfile string,
|
||||
outDir string,
|
||||
) (files []string, err error) {
|
||||
zrc, err := zip.OpenReader(zipfile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("zip.OpenReader(): %w", err)
|
||||
@@ -510,7 +583,7 @@ func zipFileUnpack(zipfile, outDir string) (files []string, err error) {
|
||||
|
||||
for _, zf := range zrc.File {
|
||||
var name string
|
||||
name, err = zipFileUnpackOne(outDir, zf)
|
||||
name, err = u.unpackZipFile(ctx, outDir, zf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
@@ -543,7 +616,12 @@ func copyFile(src, dst string, perm fs.FileMode) (err error) {
|
||||
// copySupportingFiles copies each file specified in files from srcdir to
|
||||
// dstdir. If a file specified as a path, only the name of the file is used.
|
||||
// It skips AdGuardHome, AdGuardHome.exe, and AdGuardHome.yaml.
|
||||
func copySupportingFiles(files []string, srcdir, dstdir string) error {
|
||||
func (u *Updater) copySupportingFiles(
|
||||
ctx context.Context,
|
||||
files []string,
|
||||
srcdir string,
|
||||
dstdir string,
|
||||
) (err error) {
|
||||
for _, f := range files {
|
||||
_, name := filepath.Split(f)
|
||||
if name == "AdGuardHome" || name == "AdGuardHome.exe" || name == "AdGuardHome.yaml" {
|
||||
@@ -553,12 +631,12 @@ func copySupportingFiles(files []string, srcdir, dstdir string) error {
|
||||
src := filepath.Join(srcdir, name)
|
||||
dst := filepath.Join(dstdir, name)
|
||||
|
||||
err := copyFile(src, dst, aghos.DefaultPermFile)
|
||||
err = copyFile(src, dst, aghos.DefaultPermFile)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("updater: copied: %q to %q", src, dst)
|
||||
u.logger.InfoContext(ctx, "copied", "from", src, "to", dst)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -55,6 +59,7 @@ func TestUpdater_internal(t *testing.T) {
|
||||
|
||||
u := NewUpdater(&Config{
|
||||
Client: fakeClient,
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
GOOS: tc.os,
|
||||
Version: "v0.103.0",
|
||||
ExecPath: exePath,
|
||||
@@ -68,13 +73,13 @@ func TestUpdater_internal(t *testing.T) {
|
||||
u.newVersion = "v0.103.1"
|
||||
u.packageURL = fakeURL.String()
|
||||
|
||||
require.NoError(t, u.prepare())
|
||||
require.NoError(t, u.downloadPackageFile())
|
||||
require.NoError(t, u.unpack())
|
||||
require.NoError(t, u.backup(false))
|
||||
require.NoError(t, u.replace())
|
||||
require.NoError(t, u.prepare(newCtx(t)))
|
||||
require.NoError(t, u.downloadPackageFile(newCtx(t)))
|
||||
require.NoError(t, u.unpack(newCtx(t)))
|
||||
require.NoError(t, u.backup(newCtx(t), false))
|
||||
require.NoError(t, u.replace(newCtx(t)))
|
||||
|
||||
u.clean()
|
||||
u.clean(newCtx(t))
|
||||
|
||||
require.True(t, t.Run("backup", func(t *testing.T) {
|
||||
var d []byte
|
||||
@@ -113,3 +118,8 @@ func TestUpdater_internal(t *testing.T) {
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
// newCtx is a helper that returns a new context with a timeout.
|
||||
func newCtx(tb testing.TB) (ctx context.Context) {
|
||||
return testutil.ContextWithTimeout(tb, 1*time.Second)
|
||||
}
|
||||
|
||||
@@ -10,17 +10,21 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
// testLogger is the common logger for tests.
|
||||
var testLogger = slogutil.NewDiscardLogger()
|
||||
|
||||
func TestUpdater_Update(t *testing.T) {
|
||||
const jsonData = `{
|
||||
@@ -73,6 +77,7 @@ func TestUpdater_Update(t *testing.T) {
|
||||
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
Client: srv.Client(),
|
||||
Logger: testLogger,
|
||||
GOARCH: "amd64",
|
||||
GOOS: "linux",
|
||||
Version: "v0.103.0",
|
||||
@@ -82,10 +87,12 @@ func TestUpdater_Update(t *testing.T) {
|
||||
VersionCheckURL: versionCheckURL,
|
||||
})
|
||||
|
||||
_, err = u.VersionInfo(false)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
_, err = u.VersionInfo(ctx, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = u.Update(true)
|
||||
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
||||
err = u.Update(ctx, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// check backup files
|
||||
@@ -124,14 +131,15 @@ func TestUpdater_Update(t *testing.T) {
|
||||
t.Skip("skipping config check test on windows")
|
||||
}
|
||||
|
||||
err = u.Update(false)
|
||||
err = u.Update(testutil.ContextWithTimeout(t, testTimeout), false)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("api_fail", func(t *testing.T) {
|
||||
srv.Close()
|
||||
|
||||
err = u.Update(true)
|
||||
err = u.Update(testutil.ContextWithTimeout(t, testTimeout), true)
|
||||
|
||||
var urlErr *url.Error
|
||||
assert.ErrorAs(t, err, &urlErr)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user