cherry-pick: home: rm unnecessary locking in update; refactor

Merge in DNS/adguard-home from 4499-rm-unnecessary-locking to master

Squashed commit of the following:

commit 6d70472506dd0fd69225454c73d9f7f6a208b76b
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Apr 25 17:26:54 2022 +0300

    home: rm unnecessary locking in update; refactor
This commit is contained in:
Ainar Garipov
2022-04-25 18:41:39 +03:00
committed by Ainar Garipov
parent 1547f9d35e
commit 2898a49d86
8 changed files with 82 additions and 71 deletions

View File

@@ -0,0 +1,10 @@
//go:build darwin || freebsd || openbsd
// +build darwin freebsd openbsd
package aghnet
import "github.com/AdguardTeam/AdGuardHome/internal/aghos"
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}

View File

@@ -23,10 +23,6 @@ type hardwarePortInfo struct {
static bool static bool
} }
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(ifaceName string) (bool, error) { func ifaceHasStaticIP(ifaceName string) (bool, error) {
portInfo, err := getCurrentHardwarePortInfo(ifaceName) portInfo, err := getCurrentHardwarePortInfo(ifaceName)
if err != nil { if err != nil {

View File

@@ -13,10 +13,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
) )
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
const rcConfFilename = "etc/rc.conf" const rcConfFilename = "etc/rc.conf"

View File

@@ -13,10 +13,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
) )
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
filename := fmt.Sprintf("etc/hostname.%s", ifaceName) filename := fmt.Sprintf("etc/hostname.%s", ifaceName)

View File

@@ -1,5 +1,5 @@
//go:build !(linux || darwin || freebsd || openbsd) //go:build windows
// +build !linux,!darwin,!freebsd,!openbsd // +build windows
package aghnet package aghnet
@@ -14,7 +14,7 @@ import (
) )
func canBindPrivilegedPorts() (can bool, err error) { func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights() return true, nil
} }
func ifaceHasStaticIP(string) (ok bool, err error) { func ifaceHasStaticIP(string) (ok bool, err error) {

View File

@@ -3,6 +3,7 @@ package home
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"os" "os"
"os/exec" "os/exec"
@@ -27,12 +28,16 @@ type temporaryError interface {
// Get the latest available version from the Internet // Get the latest available version from the Internet
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
resp := &versionResponse{} resp := &versionResponse{}
if Context.disableUpdate { if Context.disableUpdate {
// w.Header().Set("Content-Type", "application/json")
resp.Disabled = true resp.Disabled = true
_ = json.NewEncoder(w).Encode(resp) err := json.NewEncoder(w).Encode(resp)
// TODO(e.burkov): Add error handling and deal with headers. if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
}
return return
} }
@@ -44,30 +49,48 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
if r.ContentLength != 0 { if r.ContentLength != 0 {
err = json.NewDecoder(r.Body).Decode(req) err = json.NewDecoder(r.Body).Decode(req)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "JSON parse: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "parsing request: %s", err)
return return
} }
} }
err = requestVersionInfo(resp, req.Recheck)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.Error(r, w, http.StatusBadGateway, "%s", err)
return
}
err = resp.setAllowedToAutoUpdate()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
}
}
// requestVersionInfo sets the VersionInfo field of resp if it can reach the
// update server.
func requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
for i := 0; i != 3; i++ { for i := 0; i != 3; i++ {
func() { resp.VersionInfo, err = Context.updater.VersionInfo(recheck)
Context.controlLock.Lock()
defer Context.controlLock.Unlock()
resp.VersionInfo, err = Context.updater.VersionInfo(req.Recheck)
}()
if err != nil { if err != nil {
var terr temporaryError var terr temporaryError
if errors.As(err, &terr) && terr.Temporary() { if errors.As(err, &terr) && terr.Temporary() {
// Temporary network error. This case may happen while // Temporary network error. This case may happen while we're
// we're restarting our DNS server. Log and sleep for // restarting our DNS server. Log and sleep for some time.
// some time.
// //
// See https://github.com/AdguardTeam/AdGuardHome/issues/934. // See https://github.com/AdguardTeam/AdGuardHome/issues/934.
d := time.Duration(i) * time.Second d := time.Duration(i) * time.Second
log.Info("temp net error: %q; sleeping for %s and retrying", err, d) log.Info("update: temp net error: %q; sleeping for %s and retrying", err, d)
time.Sleep(d) time.Sleep(d)
continue continue
@@ -76,29 +99,14 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
break break
} }
if err != nil { if err != nil {
vcu := Context.updater.VersionCheckURL() vcu := Context.updater.VersionCheckURL()
// TODO(a.garipov): Figure out the purpose of %T verb.
aghhttp.Error(
r,
w,
http.StatusBadGateway,
"Couldn't get version check json from %s: %T %s\n",
vcu,
err,
err,
)
return return fmt.Errorf("getting version info from %s: %s", vcu, err)
} }
resp.confirmAutoUpdate() return nil
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
} }
// handleUpdate performs an update to the latest available version procedure. // handleUpdate performs an update to the latest available version procedure.
@@ -132,31 +140,37 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
// versionResponse is the response for /control/version.json endpoint. // versionResponse is the response for /control/version.json endpoint.
type versionResponse struct { type versionResponse struct {
Disabled bool `json:"disabled"`
updater.VersionInfo updater.VersionInfo
Disabled bool `json:"disabled"`
} }
// confirmAutoUpdate checks the real possibility of auto update. // setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
func (vr *versionResponse) confirmAutoUpdate() { // allowed to perform an automatic update by the OS.
if vr.CanAutoUpdate != nil && *vr.CanAutoUpdate { func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
canUpdate := true if vr.CanAutoUpdate == nil || !*vr.CanAutoUpdate {
return
var tlsConf *tlsConfigSettings
if runtime.GOOS != "windows" {
tlsConf = &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
}
if tlsConf != nil &&
((tlsConf.Enabled && (tlsConf.PortHTTPS < 1024 ||
tlsConf.PortDNSOverTLS < 1024 ||
tlsConf.PortDNSOverQUIC < 1024)) ||
config.BindPort < 1024 ||
config.DNS.Port < 1024) {
canUpdate, _ = aghnet.CanBindPrivilegedPorts()
}
vr.CanAutoUpdate = &canUpdate
} }
tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
canUpdate := true
if tlsConfUsesPrivilegedPorts(tlsConf) || config.BindPort < 1024 || config.DNS.Port < 1024 {
canUpdate, err = aghnet.CanBindPrivilegedPorts()
if err != nil {
return fmt.Errorf("checking ability to bind privileged ports: %w", err)
}
}
vr.CanAutoUpdate = &canUpdate
return nil
}
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
// indicates that privileged ports are used.
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
} }
// finishUpdate completes an update procedure. // finishUpdate completes an update procedure.

View File

@@ -17,11 +17,11 @@ const versionCheckPeriod = 8 * time.Hour
// VersionInfo contains information about a new version. // VersionInfo contains information about a new version.
type VersionInfo struct { type VersionInfo struct {
CanAutoUpdate *bool `json:"can_autoupdate,omitempty"`
NewVersion string `json:"new_version,omitempty"` NewVersion string `json:"new_version,omitempty"`
Announcement string `json:"announcement,omitempty"` Announcement string `json:"announcement,omitempty"`
AnnouncementURL string `json:"announcement_url,omitempty"` AnnouncementURL string `json:"announcement_url,omitempty"`
SelfUpdateMinVersion string `json:"-"` SelfUpdateMinVersion string `json:"-"`
CanAutoUpdate *bool `json:"can_autoupdate,omitempty"`
} }
// MaxResponseSize is responses on server's requests maximum length in bytes. // MaxResponseSize is responses on server's requests maximum length in bytes.

View File

@@ -136,7 +136,6 @@ underscores() {
-e '_freebsd.go'\ -e '_freebsd.go'\
-e '_linux.go'\ -e '_linux.go'\
-e '_little.go'\ -e '_little.go'\
-e '_nolinux.go'\
-e '_openbsd.go'\ -e '_openbsd.go'\
-e '_others.go'\ -e '_others.go'\
-e '_test.go'\ -e '_test.go'\