all: sync with master
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"golang.org/x/exp/maps"
|
||||
"github.com/c2h5oh/datasize"
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Make configurable.
|
||||
@@ -28,8 +29,9 @@ type VersionInfo struct {
|
||||
CanAutoUpdate aghalg.NullBool `json:"can_autoupdate,omitempty"`
|
||||
}
|
||||
|
||||
// MaxResponseSize is responses on server's requests maximum length in bytes.
|
||||
const MaxResponseSize = 64 * 1024
|
||||
// maxVersionRespSize is the maximum length in bytes for version information
|
||||
// response.
|
||||
const maxVersionRespSize datasize.ByteSize = 64 * datasize.KB
|
||||
|
||||
// VersionInfo downloads the latest version information. If forceRecheck is
|
||||
// false and there are cached results, those results are returned.
|
||||
@@ -51,7 +53,7 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||
|
||||
r := ioutil.LimitReader(resp.Body, MaxResponseSize)
|
||||
r := ioutil.LimitReader(resp.Body, maxVersionRespSize.Bytes())
|
||||
|
||||
// This use of ReadAll is safe, because we just limited the appropriate
|
||||
// ReadCloser.
|
||||
@@ -120,8 +122,8 @@ func (u *Updater) downloadURL(versionObj map[string]string) (dlURL, key string,
|
||||
return dlURL, key, true
|
||||
}
|
||||
|
||||
keys := maps.Keys(versionObj)
|
||||
slices.Sort(keys)
|
||||
keys := slices.Sorted(maps.Keys(versionObj))
|
||||
|
||||
log.Error("updater: key %q not found; got keys %q", key, keys)
|
||||
|
||||
return "", key, false
|
||||
|
||||
@@ -51,9 +51,11 @@ func TestUpdater_VersionInfo(t *testing.T) {
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
fakeURL, err := url.JoinPath(srv.URL, "adguardhome", version.ChannelBeta, "version.json")
|
||||
srvURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
fakeURL := srvURL.JoinPath("adguardhome", version.ChannelBeta, "version.json")
|
||||
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
Client: srv.Client(),
|
||||
Version: "v0.103.0-beta.1",
|
||||
@@ -134,7 +136,7 @@ func TestUpdater_VersionInfo_others(t *testing.T) {
|
||||
GOARCH: tc.arch,
|
||||
GOARM: tc.arm,
|
||||
GOMIPS: tc.mips,
|
||||
VersionCheckURL: fakeURL.String(),
|
||||
VersionCheckURL: fakeURL,
|
||||
})
|
||||
|
||||
info, err := u.VersionInfo(false)
|
||||
|
||||
@@ -9,8 +9,10 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -21,6 +23,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
)
|
||||
|
||||
// Updater is the AdGuard Home updater.
|
||||
@@ -61,10 +64,23 @@ type Updater struct {
|
||||
prevCheckResult VersionInfo
|
||||
}
|
||||
|
||||
// DefaultVersionURL returns the default URL for the version announcement.
|
||||
func DefaultVersionURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: urlutil.SchemeHTTPS,
|
||||
Host: "static.adtidy.org",
|
||||
Path: path.Join("adguardhome", version.Channel(), "version.json"),
|
||||
}
|
||||
}
|
||||
|
||||
// Config is the AdGuard Home updater configuration.
|
||||
type Config struct {
|
||||
Client *http.Client
|
||||
|
||||
// VersionCheckURL is URL to the latest version announcement. It must not
|
||||
// be nil, see [DefaultVersionURL].
|
||||
VersionCheckURL *url.URL
|
||||
|
||||
Version string
|
||||
Channel string
|
||||
GOARCH string
|
||||
@@ -81,12 +97,9 @@ type Config struct {
|
||||
|
||||
// ExecPath is path to the executable file.
|
||||
ExecPath string
|
||||
|
||||
// VersionCheckURL is url to the latest version announcement.
|
||||
VersionCheckURL string
|
||||
}
|
||||
|
||||
// NewUpdater creates a new Updater.
|
||||
// NewUpdater creates a new Updater. conf must not be nil.
|
||||
func NewUpdater(conf *Config) *Updater {
|
||||
return &Updater{
|
||||
client: conf.Client,
|
||||
@@ -101,7 +114,7 @@ func NewUpdater(conf *Config) *Updater {
|
||||
confName: conf.ConfName,
|
||||
workDir: conf.WorkDir,
|
||||
execPath: conf.ExecPath,
|
||||
versionCheckURL: conf.VersionCheckURL,
|
||||
versionCheckURL: conf.VersionCheckURL.String(),
|
||||
|
||||
mu: &sync.RWMutex{},
|
||||
}
|
||||
@@ -167,14 +180,6 @@ func (u *Updater) NewVersion() (nv string) {
|
||||
return u.newVersion
|
||||
}
|
||||
|
||||
// VersionCheckURL returns the version check URL.
|
||||
func (u *Updater) VersionCheckURL() (vcu string) {
|
||||
u.mu.RLock()
|
||||
defer u.mu.RUnlock()
|
||||
|
||||
return u.versionCheckURL
|
||||
}
|
||||
|
||||
// prepare fills all necessary fields in Updater object.
|
||||
func (u *Updater) prepare() (err error) {
|
||||
u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion))
|
||||
@@ -265,7 +270,7 @@ func (u *Updater) check() (err error) {
|
||||
// ignores the configuration file if firstRun is true.
|
||||
func (u *Updater) backup(firstRun bool) (err error) {
|
||||
log.Debug("updater: backing up current configuration")
|
||||
_ = aghos.Mkdir(u.backupDir, aghos.DefaultPermDir)
|
||||
_ = os.Mkdir(u.backupDir, aghos.DefaultPermDir)
|
||||
if !firstRun {
|
||||
err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"), aghos.DefaultPermFile)
|
||||
if err != nil {
|
||||
@@ -339,10 +344,10 @@ func (u *Updater) downloadPackageFile() (err error) {
|
||||
return fmt.Errorf("io.ReadAll() failed: %w", err)
|
||||
}
|
||||
|
||||
_ = aghos.Mkdir(u.updateDir, aghos.DefaultPermDir)
|
||||
_ = os.Mkdir(u.updateDir, aghos.DefaultPermDir)
|
||||
|
||||
log.Debug("updater: saving package to file")
|
||||
err = aghos.WriteFile(u.packageName, body, aghos.DefaultPermFile)
|
||||
err = os.WriteFile(u.packageName, body, aghos.DefaultPermFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing package file: %w", err)
|
||||
}
|
||||
@@ -355,7 +360,7 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
return "", nil
|
||||
}
|
||||
|
||||
outputName := filepath.Join(outDir, name)
|
||||
outName := filepath.Join(outDir, name)
|
||||
|
||||
if hdr.Typeflag == tar.TypeDir {
|
||||
if name == "AdGuardHome" {
|
||||
@@ -367,12 +372,12 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
return "", nil
|
||||
}
|
||||
|
||||
err = aghos.Mkdir(outputName, os.FileMode(hdr.Mode&0o755))
|
||||
err = os.Mkdir(outName, os.FileMode(hdr.Mode&0o755))
|
||||
if err != nil && !errors.Is(err, os.ErrExist) {
|
||||
return "", fmt.Errorf("creating directory %q: %w", outputName, err)
|
||||
return "", fmt.Errorf("creating directory %q: %w", outName, err)
|
||||
}
|
||||
|
||||
log.Debug("updater: created directory %q", outputName)
|
||||
log.Debug("updater: created directory %q", outName)
|
||||
|
||||
return "", nil
|
||||
}
|
||||
@@ -384,13 +389,9 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
}
|
||||
|
||||
var wc io.WriteCloser
|
||||
wc, err = aghos.OpenFile(
|
||||
outputName,
|
||||
os.O_WRONLY|os.O_CREATE|os.O_TRUNC,
|
||||
os.FileMode(hdr.Mode&0o755),
|
||||
)
|
||||
wc, err = os.OpenFile(outName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(hdr.Mode)&0o755)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("os.OpenFile(%s): %w", outputName, err)
|
||||
return "", fmt.Errorf("os.OpenFile(%s): %w", outName, err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, wc.Close()) }()
|
||||
|
||||
@@ -399,7 +400,7 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
return "", fmt.Errorf("io.Copy(): %w", err)
|
||||
}
|
||||
|
||||
log.Debug("updater: created file %q", outputName)
|
||||
log.Debug("updater: created file %q", outName)
|
||||
|
||||
return name, nil
|
||||
}
|
||||
@@ -469,7 +470,7 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
err = aghos.Mkdir(outputName, fi.Mode())
|
||||
err = os.Mkdir(outputName, fi.Mode())
|
||||
if err != nil && !errors.Is(err, os.ErrExist) {
|
||||
return "", fmt.Errorf("creating directory %q: %w", outputName, err)
|
||||
}
|
||||
@@ -480,7 +481,7 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||
}
|
||||
|
||||
var wc io.WriteCloser
|
||||
wc, err = aghos.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
|
||||
wc, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("os.OpenFile(): %w", err)
|
||||
}
|
||||
@@ -530,7 +531,7 @@ func copyFile(src, dst string, perm fs.FileMode) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = aghos.WriteFile(dst, d, perm)
|
||||
err = os.WriteFile(dst, d, perm)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -45,7 +46,7 @@ func TestUpdater_internal(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
exePath := filepath.Join(wd, tc.exeName)
|
||||
|
||||
// start server for returning package file
|
||||
// Start server for returning package file.
|
||||
pkgData, err := os.ReadFile(filepath.Join("testdata", tc.archiveName))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -59,6 +60,9 @@ func TestUpdater_internal(t *testing.T) {
|
||||
ExecPath: exePath,
|
||||
WorkDir: wd,
|
||||
ConfName: yamlPath,
|
||||
// TODO(e.burkov): Rewrite the test to use a fake version check
|
||||
// URL with a fake URLs for the package files.
|
||||
VersionCheckURL: &url.URL{},
|
||||
})
|
||||
|
||||
u.newVersion = "v0.103.1"
|
||||
@@ -72,36 +76,40 @@ func TestUpdater_internal(t *testing.T) {
|
||||
|
||||
u.clean()
|
||||
|
||||
// check backup files
|
||||
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
|
||||
require.NoError(t, err)
|
||||
require.True(t, t.Run("backup", func(t *testing.T) {
|
||||
var d []byte
|
||||
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
|
||||
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", tc.exeName))
|
||||
require.NoError(t, err)
|
||||
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", tc.exeName))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.exeName, string(d))
|
||||
assert.Equal(t, tc.exeName, string(d))
|
||||
}))
|
||||
|
||||
// check updated files
|
||||
d, err = os.ReadFile(exePath)
|
||||
require.NoError(t, err)
|
||||
require.True(t, t.Run("updated", func(t *testing.T) {
|
||||
var d []byte
|
||||
d, err = os.ReadFile(exePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "1", string(d))
|
||||
assert.Equal(t, "1", string(d))
|
||||
|
||||
d, err = os.ReadFile(readmePath)
|
||||
require.NoError(t, err)
|
||||
d, err = os.ReadFile(readmePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "2", string(d))
|
||||
assert.Equal(t, "2", string(d))
|
||||
|
||||
d, err = os.ReadFile(licensePath)
|
||||
require.NoError(t, err)
|
||||
d, err = os.ReadFile(licensePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "3", string(d))
|
||||
assert.Equal(t, "3", string(d))
|
||||
|
||||
d, err = os.ReadFile(yamlPath)
|
||||
require.NoError(t, err)
|
||||
d, err = os.ReadFile(yamlPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,7 +65,10 @@ func TestUpdater_Update(t *testing.T) {
|
||||
srv := httptest.NewServer(mux)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
versionCheckURL, err := url.JoinPath(srv.URL, versionPath)
|
||||
srvURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
versionCheckURL := srvURL.JoinPath(versionPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
|
||||
Reference in New Issue
Block a user