all: sync with master; upd chlog
This commit is contained in:
@@ -530,14 +530,14 @@ func validateBlockingMode(mode BlockingMode, blockingIPv4, blockingIPv6 net.IP)
|
||||
// prepareInternalProxy initializes the DNS proxy that is used for internal DNS
|
||||
// queries, such as public clients PTR resolving and updater hostname resolving.
|
||||
func (s *Server) prepareInternalProxy() (err error) {
|
||||
srvConf := s.conf
|
||||
conf := &proxy.Config{
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: 4096,
|
||||
UpstreamConfig: s.conf.UpstreamConfig,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
MaxGoroutines: int(s.conf.MaxGoroutines),
|
||||
}
|
||||
|
||||
srvConf := s.conf
|
||||
setProxyUpstreamMode(
|
||||
conf,
|
||||
srvConf.AllServers,
|
||||
|
||||
@@ -2,6 +2,7 @@ package filtering
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
@@ -97,14 +99,15 @@ func (d *DNSFilter) filterSetProperties(
|
||||
filt.URL,
|
||||
)
|
||||
|
||||
defer func(oldURL, oldName string, oldEnabled bool, oldUpdated time.Time) {
|
||||
defer func(oldURL, oldName string, oldEnabled bool, oldUpdated time.Time, oldRulesCount int) {
|
||||
if err != nil {
|
||||
filt.URL = oldURL
|
||||
filt.Name = oldName
|
||||
filt.Enabled = oldEnabled
|
||||
filt.LastUpdated = oldUpdated
|
||||
filt.RulesCount = oldRulesCount
|
||||
}
|
||||
}(filt.URL, filt.Name, filt.Enabled, filt.LastUpdated)
|
||||
}(filt.URL, filt.Name, filt.Enabled, filt.LastUpdated, filt.RulesCount)
|
||||
|
||||
filt.Name = newList.Name
|
||||
|
||||
@@ -134,8 +137,8 @@ func (d *DNSFilter) filterSetProperties(
|
||||
// TODO(e.burkov): The validation of the contents of the new URL is
|
||||
// currently skipped if the rule list is disabled. This makes it
|
||||
// possible to set a bad rules source, but the validation should still
|
||||
// kick in when the filter is enabled. Consider making changing this
|
||||
// behavior to be stricter.
|
||||
// kick in when the filter is enabled. Consider changing this behavior
|
||||
// to be stricter.
|
||||
filt.unload()
|
||||
}
|
||||
|
||||
@@ -269,10 +272,10 @@ func (d *DNSFilter) periodicallyRefreshFilters() {
|
||||
// already going on.
|
||||
//
|
||||
// TODO(e.burkov): Get rid of the concurrency pattern which requires the
|
||||
// sync.Mutex.TryLock.
|
||||
// [sync.Mutex.TryLock].
|
||||
func (d *DNSFilter) tryRefreshFilters(block, allow, force bool) (updated int, isNetworkErr, ok bool) {
|
||||
if ok = d.refreshLock.TryLock(); !ok {
|
||||
return 0, false, ok
|
||||
return 0, false, false
|
||||
}
|
||||
defer d.refreshLock.Unlock()
|
||||
|
||||
@@ -427,52 +430,124 @@ func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) {
|
||||
return updNum, false
|
||||
}
|
||||
|
||||
// Allows printable UTF-8 text with CR, LF, TAB characters
|
||||
func isPrintableText(data []byte, len int) bool {
|
||||
for i := 0; i < len; i++ {
|
||||
c := data[i]
|
||||
// isPrintableText returns true if data is printable UTF-8 text with CR, LF, TAB
|
||||
// characters.
|
||||
//
|
||||
// TODO(e.burkov): Investigate the purpose of this and improve the
|
||||
// implementation. Perhaps, use something from the unicode package.
|
||||
func isPrintableText(data string) (ok bool) {
|
||||
for _, c := range []byte(data) {
|
||||
if (c >= ' ' && c != 0x7f) || c == '\n' || c == '\r' || c == '\t' {
|
||||
continue
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
|
||||
func (d *DNSFilter) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||
rulesCount := 0
|
||||
name := ""
|
||||
seenTitle := false
|
||||
r := bufio.NewReader(file)
|
||||
checksum := uint32(0)
|
||||
// scanLinesWithBreak is essentially a [bufio.ScanLines] which keeps trailing
|
||||
// line breaks.
|
||||
func scanLinesWithBreak(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line))
|
||||
if i := bytes.IndexByte(data, '\n'); i >= 0 {
|
||||
return i + 1, data[0 : i+1], nil
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
//
|
||||
} else if line[0] == '!' {
|
||||
m := d.filterTitleRegexp.FindAllStringSubmatch(line, -1)
|
||||
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
|
||||
name = m[0][1]
|
||||
seenTitle = true
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
|
||||
} else if line[0] == '#' {
|
||||
//
|
||||
} else {
|
||||
rulesCount++
|
||||
// Request more data.
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
// parseFilter copies filter's content from src to dst and returns the number of
|
||||
// rules, name, number of bytes written, checksum, and title of the parsed list.
|
||||
// dst must not be nil.
|
||||
func (d *DNSFilter) parseFilter(
|
||||
src io.Reader,
|
||||
dst io.Writer,
|
||||
) (rulesNum, written int, checksum uint32, title string, err error) {
|
||||
scanner := bufio.NewScanner(src)
|
||||
scanner.Split(scanLinesWithBreak)
|
||||
|
||||
titleFound := false
|
||||
for n := 0; scanner.Scan(); written += n {
|
||||
line := scanner.Text()
|
||||
var isRule bool
|
||||
var likelyTitle string
|
||||
isRule, likelyTitle, err = d.parseFilterLine(line, !titleFound, written == 0)
|
||||
if err != nil {
|
||||
return 0, written, 0, "", err
|
||||
}
|
||||
|
||||
if isRule {
|
||||
rulesNum++
|
||||
} else if likelyTitle != "" {
|
||||
title, titleFound = likelyTitle, true
|
||||
}
|
||||
|
||||
checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line))
|
||||
|
||||
n, err = dst.Write([]byte(line))
|
||||
if err != nil {
|
||||
break
|
||||
return 0, written, 0, "", fmt.Errorf("writing filter line: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return rulesCount, checksum, name
|
||||
if err = scanner.Err(); err != nil {
|
||||
return 0, written, 0, "", fmt.Errorf("scanning filter contents: %w", err)
|
||||
}
|
||||
|
||||
return rulesNum, written, checksum, title, nil
|
||||
}
|
||||
|
||||
// parseFilterLine returns true if the passed line is a rule. line is
|
||||
// considered a rule if it's not a comment and contains no title.
|
||||
func (d *DNSFilter) parseFilterLine(
|
||||
line string,
|
||||
lookForTitle bool,
|
||||
testHTML bool,
|
||||
) (isRule bool, title string, err error) {
|
||||
if !isPrintableText(line) {
|
||||
return false, "", errors.Error("filter contains non-printable characters")
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || line[0] == '#' {
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
if testHTML && isHTML(line) {
|
||||
return false, "", errors.Error("data is HTML, not plain text")
|
||||
}
|
||||
|
||||
if line[0] == '!' && lookForTitle {
|
||||
match := d.filterTitleRegexp.FindStringSubmatch(line)
|
||||
if len(match) > 1 {
|
||||
title = match[1]
|
||||
}
|
||||
|
||||
return false, title, nil
|
||||
}
|
||||
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
// isHTML returns true if the line contains HTML tags instead of plain text.
|
||||
// line shouldn have no leading space symbols.
|
||||
//
|
||||
// TODO(ameshkov): It actually gives too much false-positives. Perhaps, just
|
||||
// check if trimmed string begins with angle bracket.
|
||||
func isHTML(line string) (ok bool) {
|
||||
line = strings.ToLower(line)
|
||||
|
||||
return strings.HasPrefix(line, "<html") || strings.HasPrefix(line, "<!doctype")
|
||||
}
|
||||
|
||||
// Perform upgrade on a filter and update LastUpdated value
|
||||
@@ -485,57 +560,10 @@ func (d *DNSFilter) update(filter *FilterYAML) (bool, error) {
|
||||
log.Error("os.Chtimes(): %v", e)
|
||||
}
|
||||
}
|
||||
|
||||
return b, err
|
||||
}
|
||||
|
||||
func (d *DNSFilter) read(reader io.Reader, tmpFile *os.File, filter *FilterYAML) (int, error) {
|
||||
htmlTest := true
|
||||
firstChunk := make([]byte, 4*1024)
|
||||
firstChunkLen := 0
|
||||
buf := make([]byte, 64*1024)
|
||||
total := 0
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
total += n
|
||||
|
||||
if htmlTest {
|
||||
num := len(firstChunk) - firstChunkLen
|
||||
if n < num {
|
||||
num = n
|
||||
}
|
||||
copied := copy(firstChunk[firstChunkLen:], buf[:num])
|
||||
firstChunkLen += copied
|
||||
|
||||
if firstChunkLen == len(firstChunk) || err == io.EOF {
|
||||
if !isPrintableText(firstChunk, firstChunkLen) {
|
||||
return total, fmt.Errorf("data contains non-printable characters")
|
||||
}
|
||||
|
||||
s := strings.ToLower(string(firstChunk))
|
||||
if strings.Contains(s, "<html") || strings.Contains(s, "<!doctype") {
|
||||
return total, fmt.Errorf("data is HTML, not plain text")
|
||||
}
|
||||
|
||||
htmlTest = false
|
||||
firstChunk = nil
|
||||
}
|
||||
}
|
||||
|
||||
_, err2 := tmpFile.Write(buf[:n])
|
||||
if err2 != nil {
|
||||
return total, err2
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
return total, nil
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// finalizeUpdate closes and gets rid of temporary file f with filter's content
|
||||
// according to updated. It also saves new values of flt's name, rules number
|
||||
// and checksum if sucсeeded.
|
||||
@@ -552,7 +580,8 @@ func (d *DNSFilter) finalizeUpdate(
|
||||
// Close the file before renaming it because it's required on Windows.
|
||||
//
|
||||
// See https://github.com/adguardTeam/adGuardHome/issues/1553.
|
||||
if err = file.Close(); err != nil {
|
||||
err = file.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing temporary file: %w", err)
|
||||
}
|
||||
|
||||
@@ -564,38 +593,18 @@ func (d *DNSFilter) finalizeUpdate(
|
||||
|
||||
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path(d.DataDir))
|
||||
|
||||
if err = os.Rename(tmpFileName, flt.Path(d.DataDir)); err != nil {
|
||||
// Don't use renamio or maybe packages, since those will require loading the
|
||||
// whole filter content to the memory on Windows.
|
||||
err = os.Rename(tmpFileName, flt.Path(d.DataDir))
|
||||
if err != nil {
|
||||
return errors.WithDeferred(err, os.Remove(tmpFileName))
|
||||
}
|
||||
|
||||
flt.Name = stringutil.Coalesce(flt.Name, name)
|
||||
flt.checksum = cs
|
||||
flt.RulesCount = rnum
|
||||
flt.Name, flt.checksum, flt.RulesCount = aghalg.Coalesce(flt.Name, name), cs, rnum
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processUpdate copies filter's content from src to dst and returns the name,
|
||||
// rules number, and checksum for it. It also returns the number of bytes read
|
||||
// from src.
|
||||
func (d *DNSFilter) processUpdate(
|
||||
src io.Reader,
|
||||
dst *os.File,
|
||||
flt *FilterYAML,
|
||||
) (name string, rnum int, cs uint32, n int, err error) {
|
||||
if n, err = d.read(src, dst, flt); err != nil {
|
||||
return "", 0, 0, 0, err
|
||||
}
|
||||
|
||||
if _, err = dst.Seek(0, io.SeekStart); err != nil {
|
||||
return "", 0, 0, 0, err
|
||||
}
|
||||
|
||||
rnum, cs, name = d.parseFilterContents(dst)
|
||||
|
||||
return name, rnum, cs, n, nil
|
||||
}
|
||||
|
||||
// updateIntl updates the flt rewriting it's actual file. It returns true if
|
||||
// the actual update has been performed.
|
||||
func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
@@ -612,31 +621,21 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
}
|
||||
defer func() {
|
||||
err = errors.WithDeferred(err, d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
|
||||
ok = ok && err == nil
|
||||
if ok {
|
||||
if ok && err == nil {
|
||||
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
|
||||
}
|
||||
}()
|
||||
|
||||
// Change the default 0o600 permission to something more acceptable by
|
||||
// end users.
|
||||
// Change the default 0o600 permission to something more acceptable by end
|
||||
// users.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/3198.
|
||||
if err = tmpFile.Chmod(0o644); err != nil {
|
||||
return false, fmt.Errorf("changing file mode: %w", err)
|
||||
}
|
||||
|
||||
var r io.Reader
|
||||
if filepath.IsAbs(flt.URL) {
|
||||
var file io.ReadCloser
|
||||
file, err = os.Open(flt.URL)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, file.Close()) }()
|
||||
|
||||
r = file
|
||||
} else {
|
||||
var rc io.ReadCloser
|
||||
if !filepath.IsAbs(flt.URL) {
|
||||
var resp *http.Response
|
||||
resp, err = d.HTTPClient.Get(flt.URL)
|
||||
if err != nil {
|
||||
@@ -649,24 +648,30 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Printf("got status code %d from %s, skip", resp.StatusCode, flt.URL)
|
||||
|
||||
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
|
||||
return false, fmt.Errorf("got status code %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
r = resp.Body
|
||||
rc = resp.Body
|
||||
} else {
|
||||
rc, err = os.Open(flt.URL)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, rc.Close()) }()
|
||||
}
|
||||
|
||||
name, rnum, cs, n, err = d.processUpdate(r, tmpFile, flt)
|
||||
rnum, n, cs, name, err = d.parseFilter(rc, tmpFile)
|
||||
|
||||
return cs != flt.checksum, err
|
||||
return cs != flt.checksum && err == nil, err
|
||||
}
|
||||
|
||||
// loads filter contents from the file in dataDir
|
||||
func (d *DNSFilter) load(filter *FilterYAML) (err error) {
|
||||
filterFilePath := filter.Path(d.DataDir)
|
||||
func (d *DNSFilter) load(flt *FilterYAML) (err error) {
|
||||
fileName := flt.Path(d.DataDir)
|
||||
|
||||
log.Tracef("filtering: loading filter %d from %s", filter.ID, filterFilePath)
|
||||
log.Debug("filtering: loading filter %d from %s", flt.ID, fileName)
|
||||
|
||||
file, err := os.Open(filterFilePath)
|
||||
file, err := os.Open(fileName)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// Do nothing, file doesn't exist.
|
||||
return nil
|
||||
@@ -680,13 +685,14 @@ func (d *DNSFilter) load(filter *FilterYAML) (err error) {
|
||||
return fmt.Errorf("getting filter file stat: %w", err)
|
||||
}
|
||||
|
||||
log.Tracef("filtering: File %s, id %d, length %d", filterFilePath, filter.ID, st.Size())
|
||||
log.Debug("filtering: file %s, id %d, length %d", fileName, flt.ID, st.Size())
|
||||
|
||||
rulesCount, checksum, _ := d.parseFilterContents(file)
|
||||
rulesCount, _, checksum, _, err := d.parseFilter(file, io.Discard)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing filter file: %w", err)
|
||||
}
|
||||
|
||||
filter.RulesCount = rulesCount
|
||||
filter.checksum = checksum
|
||||
filter.LastUpdated = st.ModTime()
|
||||
flt.RulesCount, flt.checksum, flt.LastUpdated = rulesCount, checksum, st.ModTime()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,33 +4,23 @@ import (
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// serveFiltersLocally is a helper that concurrently listens on a free port to
|
||||
// respond with fltContent. It also gracefully closes the listener when the
|
||||
// test under t finishes.
|
||||
func serveFiltersLocally(t *testing.T, fltContent []byte) (ipp netip.AddrPort) {
|
||||
// serveHTTPLocally starts a new HTTP server, that handles its index with h. It
|
||||
// also gracefully closes the listener when the test under t finishes.
|
||||
func serveHTTPLocally(t *testing.T, h http.Handler) (urlStr string) {
|
||||
t.Helper()
|
||||
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
pt := testutil.PanicT{}
|
||||
|
||||
n, werr := w.Write(fltContent)
|
||||
require.NoError(pt, werr)
|
||||
require.Equal(pt, len(fltContent), n)
|
||||
})
|
||||
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -38,9 +28,26 @@ func serveFiltersLocally(t *testing.T, fltContent []byte) (ipp netip.AddrPort) {
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
addr := l.Addr()
|
||||
require.IsType(t, new(net.TCPAddr), addr)
|
||||
require.IsType(t, (*net.TCPAddr)(nil), addr)
|
||||
|
||||
return netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port))
|
||||
return (&url.URL{
|
||||
Scheme: aghhttp.SchemeHTTP,
|
||||
Host: addr.String(),
|
||||
}).String()
|
||||
}
|
||||
|
||||
// serveFiltersLocally is a helper that concurrently listens on a free port to
|
||||
// respond with fltContent.
|
||||
func serveFiltersLocally(t *testing.T, fltContent []byte) (urlStr string) {
|
||||
t.Helper()
|
||||
|
||||
return serveHTTPLocally(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
pt := testutil.PanicT{}
|
||||
|
||||
n, werr := w.Write(fltContent)
|
||||
require.NoError(pt, werr)
|
||||
require.Equal(pt, len(fltContent), n)
|
||||
}))
|
||||
}
|
||||
|
||||
func TestFilters(t *testing.T) {
|
||||
@@ -65,10 +72,7 @@ func TestFilters(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
f := &FilterYAML{
|
||||
URL: (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr.String(),
|
||||
}).String(),
|
||||
URL: addr,
|
||||
}
|
||||
|
||||
updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) {
|
||||
@@ -103,11 +107,7 @@ func TestFilters(t *testing.T) {
|
||||
anotherContent := []byte(`||example.com^`)
|
||||
oldURL := f.URL
|
||||
|
||||
ipp := serveFiltersLocally(t, anotherContent)
|
||||
f.URL = (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: ipp.String(),
|
||||
}).String()
|
||||
f.URL = serveFiltersLocally(t, anotherContent)
|
||||
t.Cleanup(func() { f.URL = oldURL })
|
||||
|
||||
updateAndAssert(t, require.True, 1)
|
||||
|
||||
@@ -190,6 +190,8 @@ type DNSFilter struct {
|
||||
|
||||
// filterTitleRegexp is the regular expression to retrieve a name of a
|
||||
// filter list.
|
||||
//
|
||||
// TODO(e.burkov): Don't use regexp for such a simple text processing task.
|
||||
filterTitleRegexp *regexp.Regexp
|
||||
|
||||
hostCheckers []hostChecker
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -30,11 +29,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
|
||||
endpoint: &badRulesEndpoint,
|
||||
content: []byte(`<html></html>`),
|
||||
}} {
|
||||
ipp := serveFiltersLocally(t, rulesSource.content)
|
||||
*rulesSource.endpoint = (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: ipp.String(),
|
||||
}).String()
|
||||
*rulesSource.endpoint = serveFiltersLocally(t, rulesSource.content)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
|
||||
@@ -106,6 +106,8 @@ type configuration struct {
|
||||
ProxyURL string `yaml:"http_proxy"`
|
||||
// Language is a two-letter ISO 639-1 language code.
|
||||
Language string `yaml:"language"`
|
||||
// Theme is a UI theme for current user.
|
||||
Theme Theme `yaml:"theme"`
|
||||
// DebugPProf defines if the profiling HTTP handler will listen on :6060.
|
||||
DebugPProf bool `yaml:"debug_pprof"`
|
||||
|
||||
@@ -322,6 +324,7 @@ var config = &configuration{
|
||||
},
|
||||
OSConfig: &osConfig{},
|
||||
SchemaVersion: currentSchemaVersion,
|
||||
Theme: ThemeAuto,
|
||||
}
|
||||
|
||||
// getConfigFilename returns path to the current config file
|
||||
|
||||
@@ -149,19 +149,6 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
type profileJSON struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
u := Context.auth.getCurrentUser(r)
|
||||
resp := &profileJSON{
|
||||
Name: u.Name,
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// ------------------------
|
||||
// registration of handlers
|
||||
// ------------------------
|
||||
@@ -172,6 +159,7 @@ func registerControlHandlers() {
|
||||
Context.mux.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
|
||||
httpRegister(http.MethodPost, "/control/update", handleUpdate)
|
||||
httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
|
||||
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)
|
||||
|
||||
// No auth is necessary for DoH/DoT configurations
|
||||
Context.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
|
||||
|
||||
@@ -123,7 +123,7 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = Context.updater.Update()
|
||||
err = Context.updater.Update(false)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
@@ -39,17 +41,13 @@ func onConfigModified() {
|
||||
}
|
||||
}
|
||||
|
||||
// initDNSServer creates an instance of the dnsforward.Server
|
||||
// Please note that we must do it even if we don't start it
|
||||
// so that we had access to the query log and the stats
|
||||
func initDNSServer() (err error) {
|
||||
// initDNS updates all the fields of the [Context] needed to initialize the DNS
|
||||
// server and initializes it at last. It also must not be called unless
|
||||
// [config] and [Context] are initialized.
|
||||
func initDNS() (err error) {
|
||||
baseDir := Context.getDataDir()
|
||||
|
||||
var anonFunc aghnet.IPMutFunc
|
||||
if config.DNS.AnonymizeClientIP {
|
||||
anonFunc = querylog.AnonymizeIP
|
||||
}
|
||||
anonymizer := aghnet.NewIPMut(anonFunc)
|
||||
anonymizer := config.anonymizer()
|
||||
|
||||
statsConf := stats.Config{
|
||||
Filename: filepath.Join(baseDir, "stats.db"),
|
||||
@@ -82,34 +80,46 @@ func initDNSServer() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
var privateNets netutil.SubnetSet
|
||||
switch len(config.DNS.PrivateNets) {
|
||||
case 0:
|
||||
// Use an optimized locally-served matcher.
|
||||
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
case 1:
|
||||
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
default:
|
||||
var nets []*net.IPNet
|
||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
|
||||
privateNets = netutil.SliceSubnetSet(nets)
|
||||
return initDNSServer(
|
||||
Context.filters,
|
||||
Context.stats,
|
||||
Context.queryLog,
|
||||
Context.dhcpServer,
|
||||
anonymizer,
|
||||
httpRegister,
|
||||
tlsConf,
|
||||
)
|
||||
}
|
||||
|
||||
// initDNSServer initializes the [context.dnsServer]. To only use the internal
|
||||
// proxy, none of the arguments are required, but tlsConf still must not be nil,
|
||||
// in other cases all the arguments also must not be nil. It also must not be
|
||||
// called unless [config] and [Context] are initialized.
|
||||
func initDNSServer(
|
||||
filters *filtering.DNSFilter,
|
||||
sts stats.Interface,
|
||||
qlog querylog.QueryLog,
|
||||
dhcpSrv dhcpd.Interface,
|
||||
anonymizer *aghnet.IPMut,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
tlsConf *tlsConfigSettings,
|
||||
) (err error) {
|
||||
privateNets, err := parseSubnetSet(config.DNS.PrivateNets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
p := dnsforward.DNSCreateParams{
|
||||
DNSFilter: Context.filters,
|
||||
Stats: Context.stats,
|
||||
QueryLog: Context.queryLog,
|
||||
DNSFilter: filters,
|
||||
Stats: sts,
|
||||
QueryLog: qlog,
|
||||
PrivateNets: privateNets,
|
||||
Anonymizer: anonymizer,
|
||||
LocalDomain: config.DHCP.LocalDomainName,
|
||||
DHCPServer: Context.dhcpServer,
|
||||
DHCPServer: dhcpSrv,
|
||||
}
|
||||
|
||||
Context.dnsServer, err = dnsforward.NewServer(p)
|
||||
@@ -120,15 +130,15 @@ func initDNSServer() (err error) {
|
||||
}
|
||||
|
||||
Context.clients.dnsServer = Context.dnsServer
|
||||
var dnsConfig dnsforward.ServerConfig
|
||||
dnsConfig, err = generateServerConfig()
|
||||
|
||||
dnsConf, err := generateServerConfig(tlsConf, httpReg)
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
return fmt.Errorf("generateServerConfig: %w", err)
|
||||
}
|
||||
|
||||
err = Context.dnsServer.Prepare(&dnsConfig)
|
||||
err = Context.dnsServer.Prepare(&dnsConf)
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
@@ -146,6 +156,32 @@ func initDNSServer() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
|
||||
// a subnet set that matches all locally served networks, see
|
||||
// [netutil.IsLocallyServed].
|
||||
func parseSubnetSet(nets []string) (s netutil.SubnetSet, err error) {
|
||||
switch len(nets) {
|
||||
case 0:
|
||||
// Use an optimized function-based matcher.
|
||||
return netutil.SubnetSetFunc(netutil.IsLocallyServed), nil
|
||||
case 1:
|
||||
s, err = netutil.ParseSubnet(nets[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
default:
|
||||
var nets []*net.IPNet
|
||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return netutil.SliceSubnetSet(nets), nil
|
||||
}
|
||||
}
|
||||
|
||||
func isRunning() bool {
|
||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||
}
|
||||
@@ -193,7 +229,10 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
|
||||
return udpAddrs
|
||||
}
|
||||
|
||||
func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
func generateServerConfig(
|
||||
tlsConf *tlsConfigSettings,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
) (newConf dnsforward.ServerConfig, err error) {
|
||||
dnsConf := config.DNS
|
||||
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
||||
newConf = dnsforward.ServerConfig{
|
||||
@@ -201,12 +240,10 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
|
||||
FilteringConfig: dnsConf.FilteringConfig,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
HTTPRegister: httpReg,
|
||||
OnDNSRequest: onDNSRequest,
|
||||
}
|
||||
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
if tlsConf.Enabled {
|
||||
newConf.TLSConfig = tlsConf.TLSConfig
|
||||
newConf.TLSConfig.ServerName = tlsConf.ServerName
|
||||
@@ -224,7 +261,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSCrypt != 0 {
|
||||
newConf.DNSCryptConfig, err = newDNSCrypt(hosts, tlsConf)
|
||||
newConf.DNSCryptConfig, err = newDNSCrypt(hosts, *tlsConf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's already
|
||||
// wrapped by newDNSCrypt.
|
||||
@@ -413,7 +450,11 @@ func startDNSServer() error {
|
||||
|
||||
func reconfigureDNSServer() (err error) {
|
||||
var newConf dnsforward.ServerConfig
|
||||
newConf, err = generateServerConfig()
|
||||
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
|
||||
newConf, err = generateServerConfig(tlsConf, httpRegister)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating forwarding dns server config: %w", err)
|
||||
}
|
||||
|
||||
@@ -455,6 +455,10 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
err = setupConfig(opts)
|
||||
fatalOnError(err)
|
||||
|
||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||
// effect.
|
||||
cmdlineUpdate(opts)
|
||||
|
||||
if !Context.firstRun {
|
||||
// Save the updated config
|
||||
err = config.write()
|
||||
@@ -522,7 +526,7 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
fatalOnError(err)
|
||||
|
||||
if !Context.firstRun {
|
||||
err = initDNSServer()
|
||||
err = initDNS()
|
||||
fatalOnError(err)
|
||||
|
||||
Context.tls.start()
|
||||
@@ -543,20 +547,24 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(a.garipov): This could be made much earlier and could be done on
|
||||
// the first run as well, but to achieve this we need to bypass requests
|
||||
// over dnsforward resolver.
|
||||
cmdlineUpdate(opts)
|
||||
|
||||
Context.web.Start()
|
||||
|
||||
// wait indefinitely for other go-routines to complete their job
|
||||
select {}
|
||||
}
|
||||
|
||||
func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
|
||||
var anonFunc aghnet.IPMutFunc
|
||||
if c.DNS.AnonymizeClientIP {
|
||||
anonFunc = querylog.AnonymizeIP
|
||||
}
|
||||
|
||||
return aghnet.NewIPMut(anonFunc)
|
||||
}
|
||||
|
||||
// startMods initializes and starts the DNS server after installation.
|
||||
func startMods() error {
|
||||
err := initDNSServer()
|
||||
func startMods() (err error) {
|
||||
err = initDNS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -927,8 +935,8 @@ func getHTTPProxy(_ *http.Request) (*url.URL, error) {
|
||||
|
||||
// jsonError is a generic JSON error response.
|
||||
//
|
||||
// TODO(a.garipov): Merge together with the implementations in .../dhcpd and
|
||||
// other packages after refactoring the web handler registering.
|
||||
// TODO(a.garipov): Merge together with the implementations in [dhcpd] and other
|
||||
// packages after refactoring the web handler registering.
|
||||
type jsonError struct {
|
||||
// Message is the error message, an opaque string.
|
||||
Message string `json:"message"`
|
||||
@@ -940,30 +948,40 @@ func cmdlineUpdate(opts options) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("starting update")
|
||||
// Initialize the DNS server to use the internal resolver which the updater
|
||||
// needs to be able to resolve the update source hostname.
|
||||
//
|
||||
// TODO(e.burkov): We could probably initialize the internal resolver
|
||||
// separately.
|
||||
err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{})
|
||||
fatalOnError(err)
|
||||
|
||||
if Context.firstRun {
|
||||
log.Info("update not allowed on first run")
|
||||
log.Info("cmdline update: performing update")
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
_, err := Context.updater.VersionInfo(true)
|
||||
updater := Context.updater
|
||||
info, err := updater.VersionInfo(true)
|
||||
if err != nil {
|
||||
vcu := Context.updater.VersionCheckURL()
|
||||
vcu := updater.VersionCheckURL()
|
||||
log.Error("getting version info from %s: %s", vcu, err)
|
||||
|
||||
os.Exit(0)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if Context.updater.NewVersion() == "" {
|
||||
if info.NewVersion == version.Version() {
|
||||
log.Info("no updates available")
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
err = Context.updater.Update()
|
||||
err = updater.Update(Context.firstRun)
|
||||
fatalOnError(err)
|
||||
|
||||
err = restartService()
|
||||
if err != nil {
|
||||
log.Debug("restarting service: %s", err)
|
||||
log.Info("AdGuard Home was not installed as a service. " +
|
||||
"Please restart running instances of AdGuardHome manually.")
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ type languageJSON struct {
|
||||
Language string `json:"language"`
|
||||
}
|
||||
|
||||
// TODO(d.kolyshev): Deprecated, remove it later.
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("home: language is %s", config.Language)
|
||||
|
||||
@@ -62,6 +63,7 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// TODO(d.kolyshev): Deprecated, remove it later.
|
||||
func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
if aghhttp.WriteTextPlainDeprecated(w, r) {
|
||||
return
|
||||
|
||||
@@ -229,7 +229,7 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
updateNoValue: func(o options) (options, error) { o.performUpdate = true; return o, nil },
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) { return "", o.performUpdate },
|
||||
description: "Update application and exit.",
|
||||
description: "Update the current binary and restart the service in case it's installed.",
|
||||
longName: "update",
|
||||
shortName: "",
|
||||
}, {
|
||||
|
||||
102
internal/home/profilehttp.go
Normal file
102
internal/home/profilehttp.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Theme is an enum of all allowed UI themes.
|
||||
type Theme string
|
||||
|
||||
// Allowed [Theme] values.
|
||||
//
|
||||
// Keep in sync with client/src/helpers/constants.js.
|
||||
const (
|
||||
ThemeAuto Theme = "auto"
|
||||
ThemeLight Theme = "light"
|
||||
ThemeDark Theme = "dark"
|
||||
)
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler] interface for *Theme.
|
||||
func (t *Theme) UnmarshalText(b []byte) (err error) {
|
||||
switch string(b) {
|
||||
case "auto":
|
||||
*t = ThemeAuto
|
||||
case "dark":
|
||||
*t = ThemeDark
|
||||
case "light":
|
||||
*t = ThemeLight
|
||||
default:
|
||||
return fmt.Errorf("invalid theme %q, supported: %q, %q, %q", b, ThemeAuto, ThemeDark, ThemeLight)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// profileJSON is an object for /control/profile and /control/profile/update
|
||||
// endpoints.
|
||||
type profileJSON struct {
|
||||
Name string `json:"name"`
|
||||
Language string `json:"language"`
|
||||
Theme Theme `json:"theme"`
|
||||
}
|
||||
|
||||
// handleGetProfile is the handler for GET /control/profile endpoint.
|
||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
u := Context.auth.getCurrentUser(r)
|
||||
|
||||
var resp profileJSON
|
||||
func() {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
|
||||
resp = profileJSON{
|
||||
Name: u.Name,
|
||||
Language: config.Language,
|
||||
Theme: config.Theme,
|
||||
}
|
||||
}()
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// handlePutProfile is the handler for PUT /control/profile/update endpoint.
|
||||
func handlePutProfile(w http.ResponseWriter, r *http.Request) {
|
||||
if aghhttp.WriteTextPlainDeprecated(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
profileReq := &profileJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(profileReq)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
lang := profileReq.Language
|
||||
if !allowedLanguages.Has(lang) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "unknown language: %q", lang)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
theme := profileReq.Theme
|
||||
|
||||
func() {
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
|
||||
config.Language = lang
|
||||
config.Theme = theme
|
||||
log.Printf("home: language is set to %s", lang)
|
||||
log.Printf("home: theme is set to %s", theme)
|
||||
}()
|
||||
|
||||
onConfigModified()
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
@@ -159,6 +159,38 @@ func sendSigReload() {
|
||||
log.Debug("service: sent signal to pid %d", pid)
|
||||
}
|
||||
|
||||
// restartService restarts the service. It returns error if the service is not
|
||||
// running.
|
||||
func restartService() (err error) {
|
||||
// Call chooseSystem explicitly to introduce OpenBSD support for service
|
||||
// package. It's a noop for other GOOS values.
|
||||
chooseSystem()
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting current directory: %w", err)
|
||||
}
|
||||
|
||||
svcConfig := &service.Config{
|
||||
Name: serviceName,
|
||||
DisplayName: serviceDisplayName,
|
||||
Description: serviceDescription,
|
||||
WorkingDirectory: pwd,
|
||||
}
|
||||
configureService(svcConfig)
|
||||
|
||||
var s service.Service
|
||||
if s, err = service.New(&program{}, svcConfig); err != nil {
|
||||
return fmt.Errorf("initializing service: %w", err)
|
||||
}
|
||||
|
||||
if err = svcAction(s, "restart"); err != nil {
|
||||
return fmt.Errorf("restarting service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleServiceControlAction one of the possible control actions:
|
||||
//
|
||||
// - install: Installs a service/daemon.
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// chooseSystem checks the current system detected and substitutes it with local
|
||||
// implementation if needed.
|
||||
func chooseSystem() {
|
||||
sys := service.ChosenSystem()
|
||||
// By default, package service uses the SysV system if it cannot detect
|
||||
|
||||
@@ -30,6 +30,8 @@ import (
|
||||
// sysVersion is the version of local service.System interface implementation.
|
||||
const sysVersion = "openbsd-runcom"
|
||||
|
||||
// chooseSystem checks the current system detected and substitutes it with local
|
||||
// implementation if needed.
|
||||
func chooseSystem() {
|
||||
service.ChooseSystem(openbsdSystem{})
|
||||
}
|
||||
|
||||
@@ -180,7 +180,7 @@ func withRecovered(orig *error) {
|
||||
// type check
|
||||
var _ Interface = (*StatsCtx)(nil)
|
||||
|
||||
// Start implements the Interface interface for *StatsCtx.
|
||||
// Start implements the [Interface] interface for *StatsCtx.
|
||||
func (s *StatsCtx) Start() {
|
||||
s.initWeb()
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
|
||||
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
|
||||
}
|
||||
|
||||
u.prevCheckTime = time.Now()
|
||||
u.prevCheckTime = now
|
||||
u.prevCheckResult, u.prevCheckError = u.parseVersionResponse(body)
|
||||
|
||||
return u.prevCheckResult, u.prevCheckError
|
||||
@@ -92,7 +92,11 @@ func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) {
|
||||
info.AnnouncementURL = versionJSON["announcement_url"]
|
||||
|
||||
packageURL, ok := u.downloadURL(versionJSON)
|
||||
info.CanAutoUpdate = aghalg.BoolToNullBool(ok && info.NewVersion != u.version)
|
||||
if !ok {
|
||||
return info, fmt.Errorf("version.json: packageURL not found")
|
||||
}
|
||||
|
||||
info.CanAutoUpdate = aghalg.BoolToNullBool(info.NewVersion != u.version)
|
||||
|
||||
u.newVersion = info.NewVersion
|
||||
u.packageURL = packageURL
|
||||
|
||||
@@ -104,49 +104,58 @@ func NewUpdater(conf *Config) *Updater {
|
||||
}
|
||||
}
|
||||
|
||||
// Update performs the auto-update.
|
||||
func (u *Updater) Update() (err error) {
|
||||
// Update performs the auto-update. It returns an error if the update failed.
|
||||
// If firstRun is true, it assumes the configuration file doesn't exist.
|
||||
func (u *Updater) Update(firstRun bool) (err error) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
log.Info("updater: updating")
|
||||
defer func() { log.Info("updater: finished; errors: %v", err) }()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.Error("updater: failed: %v", err)
|
||||
} else {
|
||||
log.Info("updater: finished")
|
||||
}
|
||||
}()
|
||||
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("getting executable path: %w", err)
|
||||
}
|
||||
|
||||
err = u.prepare(execPath)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("preparing: %w", err)
|
||||
}
|
||||
|
||||
defer u.clean()
|
||||
|
||||
err = u.downloadPackageFile(u.packageURL, u.packageName)
|
||||
err = u.downloadPackageFile()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("downloading package file: %w", err)
|
||||
}
|
||||
|
||||
err = u.unpack()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("unpacking: %w", err)
|
||||
}
|
||||
|
||||
err = u.check()
|
||||
if err != nil {
|
||||
return err
|
||||
if !firstRun {
|
||||
err = u.check()
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = u.backup()
|
||||
err = u.backup(firstRun)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("making backup: %w", err)
|
||||
}
|
||||
|
||||
err = u.replace()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("replacing: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -174,7 +183,7 @@ func (u *Updater) prepare(exePath string) (err error) {
|
||||
|
||||
_, pkgNameOnly := filepath.Split(u.packageURL)
|
||||
if pkgNameOnly == "" {
|
||||
return fmt.Errorf("invalid PackageURL")
|
||||
return fmt.Errorf("invalid PackageURL: %q", u.packageURL)
|
||||
}
|
||||
|
||||
u.packageName = filepath.Join(u.updateDir, pkgNameOnly)
|
||||
@@ -204,6 +213,7 @@ func (u *Updater) prepare(exePath string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// unpack extracts the files from the downloaded archive.
|
||||
func (u *Updater) unpack() error {
|
||||
var err error
|
||||
_, pkgNameOnly := filepath.Split(u.packageURL)
|
||||
@@ -228,38 +238,48 @@ func (u *Updater) unpack() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// check returns an error if the configuration file couldn't be used with the
|
||||
// version of AdGuard Home just downloaded.
|
||||
func (u *Updater) check() error {
|
||||
log.Debug("updater: checking configuration")
|
||||
|
||||
err := copyFile(u.confName, filepath.Join(u.updateDir, "AdGuardHome.yaml"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("copyFile() failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command(u.updateExeName, "--check-config")
|
||||
err = cmd.Run()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Updater) backup() error {
|
||||
// backup makes a backup of the current configuration and supporting files. It
|
||||
// ignores the configuration file if firstRun is true.
|
||||
func (u *Updater) backup(firstRun bool) (err error) {
|
||||
log.Debug("updater: backing up current configuration")
|
||||
_ = os.Mkdir(u.backupDir, 0o755)
|
||||
err := copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("copyFile() failed: %w", err)
|
||||
if !firstRun {
|
||||
err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("copyFile() failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
wd := u.workDir
|
||||
err = copySupportingFiles(u.unpackedFiles, wd, u.backupDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s",
|
||||
wd, u.backupDir, err)
|
||||
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", wd, u.backupDir, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// replace moves the current executable with the updated one and also copies the
|
||||
// supporting files.
|
||||
func (u *Updater) replace() error {
|
||||
err := copySupportingFiles(u.unpackedFiles, u.updateDir, u.workDir)
|
||||
if err != nil {
|
||||
@@ -287,6 +307,7 @@ func (u *Updater) replace() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// clean removes the temporary directory itself and all it's contents.
|
||||
func (u *Updater) clean() {
|
||||
_ = os.RemoveAll(u.updateDir)
|
||||
}
|
||||
@@ -297,9 +318,9 @@ func (u *Updater) clean() {
|
||||
const MaxPackageFileSize = 32 * 1024 * 1024
|
||||
|
||||
// Download package file and save it to disk
|
||||
func (u *Updater) downloadPackageFile(url, filename string) (err error) {
|
||||
func (u *Updater) downloadPackageFile() (err error) {
|
||||
var resp *http.Response
|
||||
resp, err = u.client.Get(url)
|
||||
resp, err = u.client.Get(u.packageURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http request failed: %w", err)
|
||||
}
|
||||
@@ -321,7 +342,7 @@ func (u *Updater) downloadPackageFile(url, filename string) (err error) {
|
||||
_ = os.Mkdir(u.updateDir, 0o755)
|
||||
|
||||
log.Debug("updater: saving package to file")
|
||||
err = os.WriteFile(filename, body, 0o644)
|
||||
err = os.WriteFile(u.packageName, body, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("os.WriteFile() failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -136,10 +136,10 @@ func TestUpdate(t *testing.T) {
|
||||
u.packageURL = fakeURL.String()
|
||||
|
||||
require.NoError(t, u.prepare(exePath))
|
||||
require.NoError(t, u.downloadPackageFile(u.packageURL, u.packageName))
|
||||
require.NoError(t, u.downloadPackageFile())
|
||||
require.NoError(t, u.unpack())
|
||||
// require.NoError(t, u.check())
|
||||
require.NoError(t, u.backup())
|
||||
require.NoError(t, u.backup(false))
|
||||
require.NoError(t, u.replace())
|
||||
|
||||
u.clean()
|
||||
@@ -215,10 +215,10 @@ func TestUpdateWindows(t *testing.T) {
|
||||
u.packageURL = fakeURL.String()
|
||||
|
||||
require.NoError(t, u.prepare(exePath))
|
||||
require.NoError(t, u.downloadPackageFile(u.packageURL, u.packageName))
|
||||
require.NoError(t, u.downloadPackageFile())
|
||||
require.NoError(t, u.unpack())
|
||||
// assert.Nil(t, u.check())
|
||||
require.NoError(t, u.backup())
|
||||
require.NoError(t, u.backup(false))
|
||||
require.NoError(t, u.replace())
|
||||
|
||||
u.clean()
|
||||
|
||||
Reference in New Issue
Block a user