From ae840c9c969968b8900336587232aa15009d1b76 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Tue, 13 May 2025 14:42:33 +0300 Subject: [PATCH] Pull request 2405: AGDNS-2374-updater-slog Squashed commit of the following: commit 89c3df471964b674b7ddafeb22566e5be9b56a13 Author: Ainar Garipov Date: Mon May 12 18:59:39 2025 +0300 updater: imp log commit d78ba4368027ddcbb41c10fbf09d43fe0721dc4c Merge: 68410954c 187b759fc Author: Ainar Garipov Date: Mon May 12 18:53:33 2025 +0300 Merge branch 'master' into AGDNS-2374-updater-slog commit 68410954c80d76b2adafe4ed28fafdd6b6b6daae Author: Ainar Garipov Date: Wed Apr 30 15:54:30 2025 +0300 updater: imp docs commit 99a705218fb849bb59dee5b801c5279a501bcf98 Author: Ainar Garipov Date: Wed Apr 30 15:40:30 2025 +0300 updater: imp docs, logs commit 2a83ee3ebf9610a2703d99ec6a6b327a315f6cce Author: Ainar Garipov Date: Tue Apr 29 21:01:02 2025 +0300 updater: use slog --- internal/home/controlupdate.go | 4 +- internal/home/home.go | 48 +++-- internal/updater/check.go | 31 +-- internal/updater/check_test.go | 15 +- internal/updater/updater.go | 228 +++++++++++++++------- internal/updater/updater_internal_test.go | 22 ++- internal/updater/updater_test.go | 22 ++- 7 files changed, 244 insertions(+), 126 deletions(-) diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index 2974a1d1..4d390ec9 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -82,7 +82,7 @@ func (web *webAPI) requestVersionInfo( ) (err error) { updater := web.conf.updater for range 3 { - resp.VersionInfo, err = updater.VersionInfo(recheck) + resp.VersionInfo, err = updater.VersionInfo(ctx, recheck) if err == nil { return nil } @@ -133,7 +133,7 @@ func (web *webAPI) handleUpdate(w http.ResponseWriter, r *http.Request) { return } - err = updater.Update(false) + err = updater.Update(r.Context(), false) if err != nil { aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) diff --git a/internal/home/home.go b/internal/home/home.go index a4e847bf..052892f8 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -487,9 +487,14 @@ func checkPorts() (err error) { } // isUpdateEnabled returns true if the update is enabled for current -// configuration. It also logs the decision. customURL should be true if the +// configuration. It also logs the decision. isCustomURL should be true if the // updater is using a custom URL. -func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customURL bool) (ok bool) { +func isUpdateEnabled( + ctx context.Context, + l *slog.Logger, + opts *options, + isCustomURL bool, +) (ok bool) { if opts.disableUpdate { l.DebugContext(ctx, "updates are disabled by command-line option") @@ -500,13 +505,13 @@ func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customU case version.ChannelDevelopment, version.ChannelCandidate: - if customURL { + if isCustomURL { l.DebugContext(ctx, "updates are enabled because custom url is used") } else { l.DebugContext(ctx, "updates are disabled for development and candidate builds") } - return customURL + return isCustomURL default: l.DebugContext(ctx, "updates are enabled") @@ -514,7 +519,7 @@ func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customU } } -// initWeb initializes the web module. upd, baseLogger, and tlsMgr must not be +// initWeb initializes the web module. upd, baseLogger, and tlsMgr must not be // nil. func initWeb( ctx context.Context, @@ -523,7 +528,7 @@ func initWeb( upd *updater.Updater, baseLogger *slog.Logger, tlsMgr *tlsManager, - customURL bool, + isCustomUpdURL bool, ) (web *webAPI, err error) { logger := baseLogger.With(slogutil.KeyPrefix, "webapi") @@ -539,7 +544,7 @@ func initWeb( } } - disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, customURL) + disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, isCustomUpdURL) webConf := &webConfig{ updater: upd, @@ -645,11 +650,12 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH confPath := configFilePath() - upd, customURL := newUpdater(ctx, slogLogger, globalContext.workDir, confPath, execPath, config) + updLogger := slogLogger.With(slogutil.KeyPrefix, "updater") + upd, isCustomURL := newUpdater(ctx, updLogger, config, globalContext.workDir, confPath, execPath) // TODO(e.burkov): This could be made earlier, probably as the option's // effect. - cmdlineUpdate(ctx, slogLogger, opts, upd, tlsMgr) + cmdlineUpdate(ctx, updLogger, opts, upd, tlsMgr) if !globalContext.firstRun { // Save the updated config. @@ -671,7 +677,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH globalContext.auth, err = initUsers() fatalOnError(err) - web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL) + web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, isCustomURL) fatalOnError(err) globalContext.web = web @@ -714,16 +720,17 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH <-done } -// newUpdater creates a new AdGuard Home updater. customURL is true if the user -// has specified a custom version announcement URL. +// newUpdater creates a new AdGuard Home updater. l and conf must not be nil. +// workDir, confPath, and execPath must not be empty. isCustomURL is true if +// the user has specified a custom version announcement URL. func newUpdater( ctx context.Context, l *slog.Logger, + conf *configuration, workDir string, confPath string, execPath string, - config *configuration, -) (upd *updater.Updater, customURL bool) { +) (upd *updater.Updater, isCustomURL bool) { // envName is the name of the environment variable that can be used to // override the default version check URL. const envName = "ADGUARD_HOME_TEST_UPDATE_VERSION_URL" @@ -735,14 +742,14 @@ func newUpdater( case version.Channel() == version.ChannelRelease: // Only enable custom version URL for development builds. l.DebugContext(ctx, "custom version url is disabled for release builds") - case !config.UnsafeUseCustomUpdateIndexURL: + case !conf.UnsafeUseCustomUpdateIndexURL: l.DebugContext(ctx, "custom version url is disabled in config") default: versionURL, _ = url.Parse(customURLStr) } err := urlutil.ValidateHTTPURL(versionURL) - if customURL = err == nil; !customURL { + if isCustomURL = err == nil; !isCustomURL { l.DebugContext(ctx, "parsing custom version url", slogutil.KeyError, err) versionURL = updater.DefaultVersionURL() @@ -751,7 +758,8 @@ func newUpdater( l.DebugContext(ctx, "creating updater", "config_path", confPath) return updater.NewUpdater(&updater.Config{ - Client: config.Filtering.HTTPClient, + Client: conf.Filtering.HTTPClient, + Logger: l, Version: version.Version(), Channel: version.Channel(), GOARCH: runtime.GOARCH, @@ -762,7 +770,7 @@ func newUpdater( ConfName: confPath, ExecPath: execPath, VersionCheckURL: versionURL, - }), customURL + }), isCustomURL } // checkPermissions checks and migrates permissions of the files and directories @@ -1083,7 +1091,7 @@ func cmdlineUpdate( l.InfoContext(ctx, "performing update via cli") - info, err := upd.VersionInfo(true) + info, err := upd.VersionInfo(ctx, true) if err != nil { l.ErrorContext(ctx, "getting version info", slogutil.KeyError, err) @@ -1096,7 +1104,7 @@ func cmdlineUpdate( os.Exit(osutil.ExitCodeSuccess) } - err = upd.Update(globalContext.firstRun) + err = upd.Update(ctx, globalContext.firstRun) fatalOnError(err) err = restartService() diff --git a/internal/updater/check.go b/internal/updater/check.go index 2a3e2cfe..bd6325ab 100644 --- a/internal/updater/check.go +++ b/internal/updater/check.go @@ -1,6 +1,7 @@ package updater import ( + "context" "encoding/json" "fmt" "io" @@ -12,7 +13,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/ioutil" - "github.com/AdguardTeam/golibs/log" "github.com/c2h5oh/datasize" ) @@ -35,7 +35,7 @@ const maxVersionRespSize datasize.ByteSize = 64 * datasize.KB // VersionInfo downloads the latest version information. If forceRecheck is // false and there are cached results, those results are returned. -func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) { +func (u *Updater) VersionInfo(ctx context.Context, forceRecheck bool) (vi VersionInfo, err error) { u.mu.Lock() defer u.mu.Unlock() @@ -45,11 +45,17 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) { return u.prevCheckResult, u.prevCheckError } - var resp *http.Response vcu := u.versionCheckURL - resp, err = u.client.Get(vcu) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, vcu, nil) if err != nil { - return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err) + return VersionInfo{}, fmt.Errorf("constructing request to %s: %w", vcu, err) + } + + u.logger.DebugContext(ctx, "requesting version data", "url", vcu) + + resp, err := u.client.Do(req) + if err != nil { + return VersionInfo{}, fmt.Errorf("requesting %s: %w", vcu, err) } defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }() @@ -59,16 +65,16 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) { // ReadCloser. body, err := io.ReadAll(r) if err != nil { - return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err) + return VersionInfo{}, fmt.Errorf("reading response from %s: %w", vcu, err) } u.prevCheckTime = now - u.prevCheckResult, u.prevCheckError = u.parseVersionResponse(body) + u.prevCheckResult, u.prevCheckError = u.parseVersionResponse(ctx, body) return u.prevCheckResult, u.prevCheckError } -func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) { +func (u *Updater) parseVersionResponse(ctx context.Context, data []byte) (VersionInfo, error) { info := VersionInfo{ CanAutoUpdate: aghalg.NBFalse, } @@ -92,7 +98,7 @@ func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) { info.Announcement = versionJSON["announcement"] info.AnnouncementURL = versionJSON["announcement_url"] - packageURL, key, found := u.downloadURL(versionJSON) + packageURL, key, found := u.downloadURL(ctx, versionJSON) if !found { return info, fmt.Errorf("version.json: no package URL: key %q not found in object", key) } @@ -108,7 +114,10 @@ func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) { // downloadURL returns the download URL for current build as well as its key in // versionObj. If the key is not found, it additionally prints an informative // log message. -func (u *Updater) downloadURL(versionObj map[string]string) (dlURL, key string, ok bool) { +func (u *Updater) downloadURL( + ctx context.Context, + versionObj map[string]string, +) (dlURL, key string, ok bool) { if u.goarch == "arm" && u.goarm != "" { key = fmt.Sprintf("download_%s_%sv%s", u.goos, u.goarch, u.goarm) } else if isMIPS(u.goarch) && u.gomips != "" { @@ -124,7 +133,7 @@ func (u *Updater) downloadURL(versionObj map[string]string) (dlURL, key string, keys := slices.Sorted(maps.Keys(versionObj)) - log.Error("updater: key %q not found; got keys %q", key, keys) + u.logger.ErrorContext(ctx, "key not found", "missing", key, "got", keys) return "", key, false } diff --git a/internal/updater/check_test.go b/internal/updater/check_test.go index 5a7c0f5d..4da1c876 100644 --- a/internal/updater/check_test.go +++ b/internal/updater/check_test.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/AdGuardHome/internal/version" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -58,6 +59,7 @@ func TestUpdater_VersionInfo(t *testing.T) { u := updater.NewUpdater(&updater.Config{ Client: srv.Client(), + Logger: testLogger, Version: "v0.103.0-beta.1", Channel: version.ChannelBeta, GOARCH: "arm", @@ -65,7 +67,8 @@ func TestUpdater_VersionInfo(t *testing.T) { VersionCheckURL: fakeURL, }) - info, err := u.VersionInfo(false) + ctx := testutil.ContextWithTimeout(t, testTimeout) + info, err := u.VersionInfo(ctx, false) require.NoError(t, err) assert.Equal(t, counter, 1) @@ -75,14 +78,14 @@ func TestUpdater_VersionInfo(t *testing.T) { assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate) t.Run("cache_check", func(t *testing.T) { - _, err = u.VersionInfo(false) + _, err = u.VersionInfo(testutil.ContextWithTimeout(t, testTimeout), false) require.NoError(t, err) assert.Equal(t, counter, 1) }) t.Run("force_check", func(t *testing.T) { - _, err = u.VersionInfo(true) + _, err = u.VersionInfo(testutil.ContextWithTimeout(t, testTimeout), true) require.NoError(t, err) assert.Equal(t, counter, 2) @@ -91,7 +94,7 @@ func TestUpdater_VersionInfo(t *testing.T) { t.Run("api_fail", func(t *testing.T) { srv.Close() - _, err = u.VersionInfo(true) + _, err = u.VersionInfo(testutil.ContextWithTimeout(t, testTimeout), true) var urlErr *url.Error assert.ErrorAs(t, err, &urlErr) }) @@ -130,6 +133,7 @@ func TestUpdater_VersionInfo_others(t *testing.T) { for _, tc := range testCases { u := updater.NewUpdater(&updater.Config{ Client: fakeClient, + Logger: testLogger, Version: "v0.103.0-beta.1", Channel: version.ChannelBeta, GOOS: "linux", @@ -139,7 +143,8 @@ func TestUpdater_VersionInfo_others(t *testing.T) { VersionCheckURL: fakeURL, }) - info, err := u.VersionInfo(false) + ctx := testutil.ContextWithTimeout(t, testTimeout) + info, err := u.VersionInfo(ctx, false) require.NoError(t, err) assert.Equal(t, "v0.103.0-beta.2", info.NewVersion) diff --git a/internal/updater/updater.go b/internal/updater/updater.go index 50869f35..93138230 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -5,9 +5,11 @@ import ( "archive/tar" "archive/zip" "compress/gzip" + "context" "fmt" "io" "io/fs" + "log/slog" "net/http" "net/url" "os" @@ -22,13 +24,14 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/ioutil" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil/urlutil" ) // Updater is the AdGuard Home updater. type Updater struct { client *http.Client + logger *slog.Logger version string channel string @@ -75,27 +78,48 @@ func DefaultVersionURL() *url.URL { // Config is the AdGuard Home updater configuration. type Config struct { + // Client is used to perform HTTP requests. It must not be nil. Client *http.Client + // Logger is used for logging the update process. It must not be nil. + Logger *slog.Logger + // VersionCheckURL is URL to the latest version announcement. It must not // be nil, see [DefaultVersionURL]. VersionCheckURL *url.URL + // Version is the current AdGuard Home version. It must not be empty. Version string - Channel string - GOARCH string - GOOS string - GOARM string - GOMIPS string - // ConfName is the name of the current configuration file. Typically, - // "AdGuardHome.yaml". + // Channel is the current AdGuard Home update channel. It must be a valid + // channel, see [version.ChannelBeta] and the related constants. + Channel string + + // GOARCH is the current CPU architecture. It must not be empty and must be + // one of the supported architectures. + GOARCH string + + // GOOS is the current operating system. It must not be empty and must be + // one of the supported OSs. + GOOS string + + // GOARM is the current ARM variant, if any. It must either be empty or be + // a valid and supported GOARM value. + GOARM string + + // GOMIPS is the current MIPS variant, if any. It must either be empty or + // be a valid and supported GOMIPS value. + GOMIPS string + + // ConfName is the name of the current configuration file. It must not be + // empty. ConfName string - // WorkDir is the working directory that is used for temporary files. + // WorkDir is the working directory that is used for temporary files. It + // must not be empty. WorkDir string - // ExecPath is path to the executable file. + // ExecPath is path to the executable file. It must not be empty. ExecPath string } @@ -103,6 +127,7 @@ type Config struct { func NewUpdater(conf *Config) *Updater { return &Updater{ client: conf.Client, + logger: conf.Logger, version: conf.Version, channel: conf.Channel, @@ -122,49 +147,49 @@ func NewUpdater(conf *Config) *Updater { // Update performs the auto-update. It returns an error if the update failed. // If firstRun is true, it assumes the configuration file doesn't exist. -func (u *Updater) Update(firstRun bool) (err error) { +func (u *Updater) Update(ctx context.Context, firstRun bool) (err error) { u.mu.Lock() defer u.mu.Unlock() - log.Info("updater: updating") + u.logger.InfoContext(ctx, "staring update", "first_run", firstRun) defer func() { if err != nil { - log.Info("updater: failed") + u.logger.ErrorContext(ctx, "update failed", slogutil.KeyError, err) } else { - log.Info("updater: finished successfully") + u.logger.InfoContext(ctx, "update finished") } }() - err = u.prepare() + err = u.prepare(ctx) if err != nil { return fmt.Errorf("preparing: %w", err) } - defer u.clean() + defer u.clean(ctx) - err = u.downloadPackageFile() + err = u.downloadPackageFile(ctx) if err != nil { return fmt.Errorf("downloading package file: %w", err) } - err = u.unpack() + err = u.unpack(ctx) if err != nil { return fmt.Errorf("unpacking: %w", err) } if !firstRun { - err = u.check() + err = u.check(ctx) if err != nil { return fmt.Errorf("checking config: %w", err) } } - err = u.backup(firstRun) + err = u.backup(ctx, firstRun) if err != nil { return fmt.Errorf("making backup: %w", err) } - err = u.replace() + err = u.replace(ctx) if err != nil { return fmt.Errorf("replacing: %w", err) } @@ -181,7 +206,7 @@ func (u *Updater) NewVersion() (nv string) { } // prepare fills all necessary fields in Updater object. -func (u *Updater) prepare() (err error) { +func (u *Updater) prepare(ctx context.Context) (err error) { u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion)) _, pkgNameOnly := filepath.Split(u.packageURL) @@ -200,11 +225,12 @@ func (u *Updater) prepare() (err error) { u.backupExeName = filepath.Join(u.backupDir, filepath.Base(u.execPath)) u.updateExeName = filepath.Join(u.updateDir, updateExeName) - log.Debug( - "updater: updating from %s to %s using url: %s", - version.Version(), - u.newVersion, - u.packageURL, + u.logger.InfoContext( + ctx, + "updating", + "from", version.Version(), + "to", u.newVersion, + "package_url", u.packageURL, ) u.currentExeName = u.execPath @@ -217,23 +243,20 @@ func (u *Updater) prepare() (err error) { } // unpack extracts the files from the downloaded archive. -func (u *Updater) unpack() error { - var err error +func (u *Updater) unpack(ctx context.Context) (err error) { _, pkgNameOnly := filepath.Split(u.packageURL) - log.Debug("updater: unpacking package") + u.logger.InfoContext(ctx, "unpacking package", "package_name", pkgNameOnly) if strings.HasSuffix(pkgNameOnly, ".zip") { - u.unpackedFiles, err = zipFileUnpack(u.packageName, u.updateDir) + u.unpackedFiles, err = u.unpackZip(ctx, u.packageName, u.updateDir) if err != nil { return fmt.Errorf(".zip unpack failed: %w", err) } - } else if strings.HasSuffix(pkgNameOnly, ".tar.gz") { - u.unpackedFiles, err = tarGzFileUnpack(u.packageName, u.updateDir) + u.unpackedFiles, err = u.unpackTarGz(ctx, u.packageName, u.updateDir) if err != nil { return fmt.Errorf(".tar.gz unpack failed: %w", err) } - } else { return fmt.Errorf("unknown package extension") } @@ -243,8 +266,8 @@ func (u *Updater) unpack() error { // check returns an error if the configuration file couldn't be used with the // version of AdGuard Home just downloaded. -func (u *Updater) check() (err error) { - log.Debug("updater: checking configuration") +func (u *Updater) check(ctx context.Context) (err error) { + u.logger.InfoContext(ctx, "checking configuration") err = copyFile(u.confName, filepath.Join(u.updateDir, "AdGuardHome.yaml"), aghos.DefaultPermFile) if err != nil { @@ -268,8 +291,9 @@ func (u *Updater) check() (err error) { // backup makes a backup of the current configuration and supporting files. It // ignores the configuration file if firstRun is true. -func (u *Updater) backup(firstRun bool) (err error) { - log.Debug("updater: backing up current configuration") +func (u *Updater) backup(ctx context.Context, firstRun bool) (err error) { + u.logger.InfoContext(ctx, "backing up current configuration") + _ = os.Mkdir(u.backupDir, aghos.DefaultPermDir) if !firstRun { err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"), aghos.DefaultPermFile) @@ -279,7 +303,7 @@ func (u *Updater) backup(firstRun bool) (err error) { } wd := u.workDir - err = copySupportingFiles(u.unpackedFiles, wd, u.backupDir) + err = u.copySupportingFiles(ctx, u.unpackedFiles, wd, u.backupDir) if err != nil { return fmt.Errorf("copySupportingFiles(%s, %s) failed: %w", wd, u.backupDir, err) } @@ -289,13 +313,18 @@ func (u *Updater) backup(firstRun bool) (err error) { // replace moves the current executable with the updated one and also copies the // supporting files. -func (u *Updater) replace() error { - err := copySupportingFiles(u.unpackedFiles, u.updateDir, u.workDir) +func (u *Updater) replace(ctx context.Context) (err error) { + err = u.copySupportingFiles(ctx, u.unpackedFiles, u.updateDir, u.workDir) if err != nil { return fmt.Errorf("copySupportingFiles(%s, %s) failed: %w", u.updateDir, u.workDir, err) } - log.Debug("updater: renaming: %s to %s", u.currentExeName, u.backupExeName) + u.logger.InfoContext( + ctx, + "backing up current executable", + "from", u.currentExeName, + "to", u.backupExeName, + ) err = os.Rename(u.currentExeName, u.backupExeName) if err != nil { return err @@ -311,14 +340,22 @@ func (u *Updater) replace() error { return err } - log.Debug("updater: renamed: %s to %s", u.updateExeName, u.currentExeName) + u.logger.InfoContext( + ctx, + "replacing current executable", + "from", u.updateExeName, + "to", u.currentExeName, + ) return nil } // clean removes the temporary directory itself and all it's contents. -func (u *Updater) clean() { - _ = os.RemoveAll(u.updateDir) +func (u *Updater) clean(ctx context.Context) { + err := os.RemoveAll(u.updateDir) + if err != nil { + u.logger.WarnContext(ctx, "removing update dir", slogutil.KeyError, err) + } } // MaxPackageFileSize is a maximum package file length in bytes. The largest @@ -327,34 +364,52 @@ func (u *Updater) clean() { const MaxPackageFileSize = 32 * 1024 * 1024 // Download package file and save it to disk -func (u *Updater) downloadPackageFile() (err error) { - var resp *http.Response - resp, err = u.client.Get(u.packageURL) +func (u *Updater) downloadPackageFile(ctx context.Context) (err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.packageURL, nil) if err != nil { - return fmt.Errorf("http request failed: %w", err) + return fmt.Errorf("constructing package request: %w", err) + } + + resp, err := u.client.Do(req) + if err != nil { + return fmt.Errorf("requesting package: %w", err) } defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }() r := ioutil.LimitReader(resp.Body, MaxPackageFileSize) - log.Debug("updater: reading http body") + u.logger.InfoContext(ctx, "reading http body") + // This use of ReadAll is now safe, because we limited body's Reader. body, err := io.ReadAll(r) if err != nil { return fmt.Errorf("io.ReadAll() failed: %w", err) } - _ = os.Mkdir(u.updateDir, aghos.DefaultPermDir) + err = os.Mkdir(u.updateDir, aghos.DefaultPermDir) + if err != nil { + // TODO(a.garipov): Consider returning this error. + u.logger.WarnContext(ctx, "creating update dir", slogutil.KeyError, err) + } + + u.logger.InfoContext(ctx, "saving package", "to", u.packageName) - log.Debug("updater: saving package to file") err = os.WriteFile(u.packageName, body, aghos.DefaultPermFile) if err != nil { return fmt.Errorf("writing package file: %w", err) } + return nil } -func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name string, err error) { +// unpackTarGzFile unpacks one file from a .tar.gz archive into outDir. All +// arguments must not be empty. +func (u *Updater) unpackTarGzFile( + ctx context.Context, + outDir string, + tr *tar.Reader, + hdr *tar.Header, +) (name string, err error) { name = filepath.Base(hdr.Name) if name == "" { return "", nil @@ -377,13 +432,18 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st return "", fmt.Errorf("creating directory %q: %w", outName, err) } - log.Debug("updater: created directory %q", outName) + u.logger.InfoContext(ctx, "created directory", "name", outName) return "", nil } if hdr.Typeflag != tar.TypeReg { - log.Info("updater: %s: unknown file type %d, skipping", name, hdr.Typeflag) + u.logger.WarnContext( + ctx, + "unknown file type; skipping", + "file_name", name, + "type", hdr.Typeflag, + ) return "", nil } @@ -400,16 +460,19 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st return "", fmt.Errorf("io.Copy(): %w", err) } - log.Debug("updater: created file %q", outName) + u.logger.InfoContext(ctx, "created file", "name", outName) return name, nil } -// Unpack all files from .tar.gz file to the specified directory -// Existing files are overwritten -// All files are created inside outDir, subdirectories are not created -// Return the list of files (not directories) written -func tarGzFileUnpack(tarfile, outDir string) (files []string, err error) { +// unpackTarGz unpack all files from a .tar.gz archive to outDir. Existing +// files are overwritten. All files are created inside outDir. files are the +// list of created files. +func (u *Updater) unpackTarGz( + ctx context.Context, + tarfile string, + outDir string, +) (files []string, err error) { f, err := os.Open(tarfile) if err != nil { return nil, fmt.Errorf("os.Open(): %w", err) @@ -437,7 +500,7 @@ func tarGzFileUnpack(tarfile, outDir string) (files []string, err error) { } var name string - name, err = tarGzFileUnpackOne(outDir, tarReader, hdr) + name, err = u.unpackTarGzFile(ctx, outDir, tarReader, hdr) if name != "" { files = append(files, name) @@ -447,7 +510,13 @@ func tarGzFileUnpack(tarfile, outDir string) (files []string, err error) { return files, err } -func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) { +// unpackZipFile unpacks one file from a .zip archive into outDir. All +// arguments must not be empty. +func (u *Updater) unpackZipFile( + ctx context.Context, + outDir string, + zf *zip.File, +) (name string, err error) { var rc io.ReadCloser rc, err = zf.Open() if err != nil { @@ -466,7 +535,8 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) { if name == "AdGuardHome" { // Top-level AdGuardHome/. Skip it. // - // TODO(a.garipov): See the similar todo in tarGzFileUnpack. + // TODO(a.garipov): See the similar TODO in + // [Updater.unpackTarGzFile]. return "", nil } @@ -475,7 +545,7 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) { return "", fmt.Errorf("creating directory %q: %w", outputName, err) } - log.Debug("updater: created directory %q", outputName) + u.logger.InfoContext(ctx, "created directory", "name", outputName) return "", nil } @@ -492,16 +562,19 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) { return "", fmt.Errorf("io.Copy(): %w", err) } - log.Debug("updater: created file %q", outputName) + u.logger.InfoContext(ctx, "created file", "name", outputName) return name, nil } -// Unpack all files from .zip file to the specified directory -// Existing files are overwritten -// All files are created inside 'outDir', subdirectories are not created -// Return the list of files (not directories) written -func zipFileUnpack(zipfile, outDir string) (files []string, err error) { +// unpackZip unpack all files from a .zip archive to outDir. Existing files are +// overwritten. All files are created inside outDir. files are the list of +// created files. +func (u *Updater) unpackZip( + ctx context.Context, + zipfile string, + outDir string, +) (files []string, err error) { zrc, err := zip.OpenReader(zipfile) if err != nil { return nil, fmt.Errorf("zip.OpenReader(): %w", err) @@ -510,7 +583,7 @@ func zipFileUnpack(zipfile, outDir string) (files []string, err error) { for _, zf := range zrc.File { var name string - name, err = zipFileUnpackOne(outDir, zf) + name, err = u.unpackZipFile(ctx, outDir, zf) if err != nil { break } @@ -543,7 +616,12 @@ func copyFile(src, dst string, perm fs.FileMode) (err error) { // copySupportingFiles copies each file specified in files from srcdir to // dstdir. If a file specified as a path, only the name of the file is used. // It skips AdGuardHome, AdGuardHome.exe, and AdGuardHome.yaml. -func copySupportingFiles(files []string, srcdir, dstdir string) error { +func (u *Updater) copySupportingFiles( + ctx context.Context, + files []string, + srcdir string, + dstdir string, +) (err error) { for _, f := range files { _, name := filepath.Split(f) if name == "AdGuardHome" || name == "AdGuardHome.exe" || name == "AdGuardHome.yaml" { @@ -553,12 +631,12 @@ func copySupportingFiles(files []string, srcdir, dstdir string) error { src := filepath.Join(srcdir, name) dst := filepath.Join(dstdir, name) - err := copyFile(src, dst, aghos.DefaultPermFile) + err = copyFile(src, dst, aghos.DefaultPermFile) if err != nil && !errors.Is(err, os.ErrNotExist) { return err } - log.Debug("updater: copied: %q to %q", src, dst) + u.logger.InfoContext(ctx, "copied", "from", src, "to", dst) } return nil diff --git a/internal/updater/updater_internal_test.go b/internal/updater/updater_internal_test.go index 67c16dc1..3d96a8ff 100644 --- a/internal/updater/updater_internal_test.go +++ b/internal/updater/updater_internal_test.go @@ -1,12 +1,16 @@ package updater import ( + "context" "net/url" "os" "path/filepath" "testing" + "time" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -55,6 +59,7 @@ func TestUpdater_internal(t *testing.T) { u := NewUpdater(&Config{ Client: fakeClient, + Logger: slogutil.NewDiscardLogger(), GOOS: tc.os, Version: "v0.103.0", ExecPath: exePath, @@ -68,13 +73,13 @@ func TestUpdater_internal(t *testing.T) { u.newVersion = "v0.103.1" u.packageURL = fakeURL.String() - require.NoError(t, u.prepare()) - require.NoError(t, u.downloadPackageFile()) - require.NoError(t, u.unpack()) - require.NoError(t, u.backup(false)) - require.NoError(t, u.replace()) + require.NoError(t, u.prepare(newCtx(t))) + require.NoError(t, u.downloadPackageFile(newCtx(t))) + require.NoError(t, u.unpack(newCtx(t))) + require.NoError(t, u.backup(newCtx(t), false)) + require.NoError(t, u.replace(newCtx(t))) - u.clean() + u.clean(newCtx(t)) require.True(t, t.Run("backup", func(t *testing.T) { var d []byte @@ -113,3 +118,8 @@ func TestUpdater_internal(t *testing.T) { })) } } + +// newCtx is a helper that returns a new context with a timeout. +func newCtx(tb testing.TB) (ctx context.Context) { + return testutil.ContextWithTimeout(tb, 1*time.Second) +} diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go index 735d9c99..dfef0b10 100644 --- a/internal/updater/updater_test.go +++ b/internal/updater/updater_test.go @@ -10,17 +10,21 @@ import ( "path/filepath" "runtime" "testing" + "time" "github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/AdGuardHome/internal/version" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) -} +// testTimeout is the common timeout for tests. +const testTimeout = 1 * time.Second + +// testLogger is the common logger for tests. +var testLogger = slogutil.NewDiscardLogger() func TestUpdater_Update(t *testing.T) { const jsonData = `{ @@ -73,6 +77,7 @@ func TestUpdater_Update(t *testing.T) { u := updater.NewUpdater(&updater.Config{ Client: srv.Client(), + Logger: testLogger, GOARCH: "amd64", GOOS: "linux", Version: "v0.103.0", @@ -82,10 +87,12 @@ func TestUpdater_Update(t *testing.T) { VersionCheckURL: versionCheckURL, }) - _, err = u.VersionInfo(false) + ctx := testutil.ContextWithTimeout(t, testTimeout) + _, err = u.VersionInfo(ctx, false) require.NoError(t, err) - err = u.Update(true) + ctx = testutil.ContextWithTimeout(t, testTimeout) + err = u.Update(ctx, true) require.NoError(t, err) // check backup files @@ -124,14 +131,15 @@ func TestUpdater_Update(t *testing.T) { t.Skip("skipping config check test on windows") } - err = u.Update(false) + err = u.Update(testutil.ContextWithTimeout(t, testTimeout), false) assert.NoError(t, err) }) t.Run("api_fail", func(t *testing.T) { srv.Close() - err = u.Update(true) + err = u.Update(testutil.ContextWithTimeout(t, testTimeout), true) + var urlErr *url.Error assert.ErrorAs(t, err, &urlErr) })