all: sync with master
This commit is contained in:
@@ -14,12 +14,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// HTTP scheme constants.
|
||||
const (
|
||||
SchemeHTTP = "http"
|
||||
SchemeHTTPS = "https"
|
||||
)
|
||||
|
||||
// RegisterFunc is the function that sets the handler to handle the URL for the
|
||||
// method.
|
||||
//
|
||||
|
||||
@@ -146,16 +146,6 @@ func IsOpenWrt() (ok bool) {
|
||||
return isOpenWrt()
|
||||
}
|
||||
|
||||
// NotifyReconfigureSignal notifies c on receiving reconfigure signals.
|
||||
func NotifyReconfigureSignal(c chan<- os.Signal) {
|
||||
notifyReconfigureSignal(c)
|
||||
}
|
||||
|
||||
// IsReconfigureSignal returns true if sig is a reconfigure signal.
|
||||
func IsReconfigureSignal(sig os.Signal) (ok bool) {
|
||||
return isReconfigureSignal(sig)
|
||||
}
|
||||
|
||||
// SendShutdownSignal sends the shutdown signal to the channel.
|
||||
func SendShutdownSignal(c chan<- os.Signal) {
|
||||
sendShutdownSignal(c)
|
||||
|
||||
@@ -1,22 +1,11 @@
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
//go:build unix
|
||||
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func notifyReconfigureSignal(c chan<- os.Signal) {
|
||||
signal.Notify(c, unix.SIGHUP)
|
||||
}
|
||||
|
||||
func isReconfigureSignal(sig os.Signal) (ok bool) {
|
||||
return sig == unix.SIGHUP
|
||||
}
|
||||
|
||||
func sendShutdownSignal(_ chan<- os.Signal) {
|
||||
// On Unix we are already notified by the system.
|
||||
}
|
||||
|
||||
@@ -4,12 +4,11 @@ package aghos
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func setRlimit(val uint64) (err error) {
|
||||
func setRlimit(_ uint64) (err error) {
|
||||
return Unsupported("setrlimit")
|
||||
}
|
||||
|
||||
@@ -38,14 +37,6 @@ func isOpenWrt() (ok bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
func notifyReconfigureSignal(c chan<- os.Signal) {
|
||||
signal.Notify(c, windows.SIGHUP)
|
||||
}
|
||||
|
||||
func isReconfigureSignal(sig os.Signal) (ok bool) {
|
||||
return sig == windows.SIGHUP
|
||||
}
|
||||
|
||||
func sendShutdownSignal(c chan<- os.Signal) {
|
||||
c <- os.Interrupt
|
||||
}
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
)
|
||||
|
||||
// TODO(e.burkov): Add platform-independent tests.
|
||||
|
||||
// Chmod is an extension for [os.Chmod] that properly handles Windows access
|
||||
// rights.
|
||||
func Chmod(name string, perm fs.FileMode) (err error) {
|
||||
return chmod(name, perm)
|
||||
}
|
||||
|
||||
// Mkdir is an extension for [os.Mkdir] that properly handles Windows access
|
||||
// rights.
|
||||
func Mkdir(name string, perm fs.FileMode) (err error) {
|
||||
return mkdir(name, perm)
|
||||
}
|
||||
|
||||
// MkdirAll is an extension for [os.MkdirAll] that properly handles Windows
|
||||
// access rights.
|
||||
func MkdirAll(path string, perm fs.FileMode) (err error) {
|
||||
return mkdirAll(path, perm)
|
||||
}
|
||||
|
||||
// WriteFile is an extension for [os.WriteFile] that properly handles Windows
|
||||
// access rights.
|
||||
func WriteFile(filename string, data []byte, perm fs.FileMode) (err error) {
|
||||
return writeFile(filename, data, perm)
|
||||
}
|
||||
|
||||
// OpenFile is an extension for [os.OpenFile] that properly handles Windows
|
||||
// access rights.
|
||||
func OpenFile(name string, flag int, perm fs.FileMode) (file *os.File, err error) {
|
||||
return openFile(name, flag, perm)
|
||||
}
|
||||
|
||||
// Stat is an extension for [os.Stat] that properly handles Windows access
|
||||
// rights.
|
||||
//
|
||||
// Note that on Windows the "other" permission bits combines the access rights
|
||||
// of any trustee that is neither the owner nor the owning group for the file.
|
||||
//
|
||||
// TODO(e.burkov): Inspect the behavior for the World (everyone) well-known
|
||||
// SID and, perhaps, use it.
|
||||
func Stat(name string) (fi fs.FileInfo, err error) {
|
||||
return stat(name)
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
//go:build unix
|
||||
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
)
|
||||
|
||||
// chmod is a Unix implementation of [Chmod].
|
||||
func chmod(name string, perm fs.FileMode) (err error) {
|
||||
return os.Chmod(name, perm)
|
||||
}
|
||||
|
||||
// mkdir is a Unix implementation of [Mkdir].
|
||||
func mkdir(name string, perm fs.FileMode) (err error) {
|
||||
return os.Mkdir(name, perm)
|
||||
}
|
||||
|
||||
// mkdirAll is a Unix implementation of [MkdirAll].
|
||||
func mkdirAll(path string, perm fs.FileMode) (err error) {
|
||||
return os.MkdirAll(path, perm)
|
||||
}
|
||||
|
||||
// writeFile is a Unix implementation of [WriteFile].
|
||||
func writeFile(filename string, data []byte, perm fs.FileMode) (err error) {
|
||||
return maybe.WriteFile(filename, data, perm)
|
||||
}
|
||||
|
||||
// openFile is a Unix implementation of [OpenFile].
|
||||
func openFile(name string, flag int, perm fs.FileMode) (file *os.File, err error) {
|
||||
// #nosec G304 -- This function simply wraps the [os.OpenFile] function, so
|
||||
// the security concerns should be addressed to the [OpenFile] calls.
|
||||
return os.OpenFile(name, flag, perm)
|
||||
}
|
||||
|
||||
// stat is a Unix implementation of [Stat].
|
||||
func stat(name string) (fi os.FileInfo, err error) {
|
||||
return os.Stat(name)
|
||||
}
|
||||
@@ -1,392 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"unsafe"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// fileInfo is a Windows implementation of [fs.FileInfo], that contains the
|
||||
// filemode converted from the security descriptor.
|
||||
type fileInfo struct {
|
||||
// fs.FileInfo is embedded to provide the default implementations and data
|
||||
// successfully retrieved by [os.Stat].
|
||||
fs.FileInfo
|
||||
|
||||
// mode is the file mode converted from the security descriptor.
|
||||
mode fs.FileMode
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.FileInfo = (*fileInfo)(nil)
|
||||
|
||||
// Mode implements [fs.FileInfo.Mode] for [*fileInfo].
|
||||
func (fi *fileInfo) Mode() (mode fs.FileMode) { return fi.mode }
|
||||
|
||||
// stat is a Windows implementation of [Stat].
|
||||
func stat(name string) (fi os.FileInfo, err error) {
|
||||
absName, err := filepath.Abs(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("computing absolute path: %w", err)
|
||||
}
|
||||
|
||||
fi, err = os.Stat(absName)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dacl, owner, group, err := retrieveDACL(absName)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ownerMask, groupMask, otherMask windows.ACCESS_MASK
|
||||
for i := range uint32(dacl.AceCount) {
|
||||
var ace *windows.ACCESS_ALLOWED_ACE
|
||||
err = windows.GetAce(dacl, i, &ace)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting access control entry at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
entrySid := (*windows.SID)(unsafe.Pointer(&ace.SidStart))
|
||||
switch {
|
||||
case entrySid.Equals(owner):
|
||||
ownerMask |= ace.Mask
|
||||
case entrySid.Equals(group):
|
||||
groupMask |= ace.Mask
|
||||
default:
|
||||
otherMask |= ace.Mask
|
||||
}
|
||||
}
|
||||
|
||||
mode := fi.Mode()
|
||||
perm := masksToPerm(ownerMask, groupMask, otherMask, mode.IsDir())
|
||||
|
||||
return &fileInfo{
|
||||
FileInfo: fi,
|
||||
// Use the file mode from the security descriptor, but use the
|
||||
// calculated permission bits.
|
||||
mode: perm | mode&^fs.FileMode(0o777),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// retrieveDACL retrieves the discretionary access control list, owner, and
|
||||
// group from the security descriptor of the file with the specified absolute
|
||||
// name.
|
||||
func retrieveDACL(absName string) (dacl *windows.ACL, owner, group *windows.SID, err error) {
|
||||
// desiredSecInfo defines the parts of a security descriptor to retrieve.
|
||||
const desiredSecInfo windows.SECURITY_INFORMATION = windows.OWNER_SECURITY_INFORMATION |
|
||||
windows.GROUP_SECURITY_INFORMATION |
|
||||
windows.DACL_SECURITY_INFORMATION |
|
||||
windows.PROTECTED_DACL_SECURITY_INFORMATION |
|
||||
windows.UNPROTECTED_DACL_SECURITY_INFORMATION
|
||||
|
||||
sd, err := windows.GetNamedSecurityInfo(absName, windows.SE_FILE_OBJECT, desiredSecInfo)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("getting security descriptor: %w", err)
|
||||
}
|
||||
|
||||
dacl, _, err = sd.DACL()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("getting discretionary access control list: %w", err)
|
||||
}
|
||||
|
||||
owner, _, err = sd.Owner()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("getting owner sid: %w", err)
|
||||
}
|
||||
|
||||
group, _, err = sd.Group()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("getting group sid: %w", err)
|
||||
}
|
||||
|
||||
return dacl, owner, group, nil
|
||||
}
|
||||
|
||||
// chmod is a Windows implementation of [Chmod].
|
||||
func chmod(name string, perm fs.FileMode) (err error) {
|
||||
fi, err := os.Stat(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting file info: %w", err)
|
||||
}
|
||||
|
||||
entries := make([]windows.EXPLICIT_ACCESS, 0, 3)
|
||||
creatorMask, groupMask, worldMask := permToMasks(perm, fi.IsDir())
|
||||
|
||||
sidMasks := []struct {
|
||||
Key windows.WELL_KNOWN_SID_TYPE
|
||||
Value windows.ACCESS_MASK
|
||||
}{{
|
||||
Key: windows.WinCreatorOwnerSid,
|
||||
Value: creatorMask,
|
||||
}, {
|
||||
Key: windows.WinCreatorGroupSid,
|
||||
Value: groupMask,
|
||||
}, {
|
||||
Key: windows.WinWorldSid,
|
||||
Value: worldMask,
|
||||
}}
|
||||
|
||||
var errs []error
|
||||
for _, sidMask := range sidMasks {
|
||||
if sidMask.Value == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var trustee windows.TRUSTEE
|
||||
trustee, err = newWellKnownTrustee(sidMask.Key)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
entries = append(entries, windows.EXPLICIT_ACCESS{
|
||||
AccessPermissions: sidMask.Value,
|
||||
AccessMode: windows.GRANT_ACCESS,
|
||||
Inheritance: windows.NO_INHERITANCE,
|
||||
Trustee: trustee,
|
||||
})
|
||||
}
|
||||
|
||||
if err = errors.Join(errs...); err != nil {
|
||||
return fmt.Errorf("creating access control entries: %w", err)
|
||||
}
|
||||
|
||||
acl, err := windows.ACLFromEntries(entries, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating access control list: %w", err)
|
||||
}
|
||||
|
||||
// secInfo defines the parts of a security descriptor to set.
|
||||
const secInfo windows.SECURITY_INFORMATION = windows.DACL_SECURITY_INFORMATION |
|
||||
windows.PROTECTED_DACL_SECURITY_INFORMATION
|
||||
|
||||
err = windows.SetNamedSecurityInfo(name, windows.SE_FILE_OBJECT, secInfo, nil, nil, acl, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting security descriptor: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mkdir is a Windows implementation of [Mkdir].
|
||||
//
|
||||
// TODO(e.burkov): Consider using [windows.CreateDirectory] instead of
|
||||
// [os.Mkdir] to reduce the number of syscalls.
|
||||
func mkdir(name string, perm os.FileMode) (err error) {
|
||||
name, err = filepath.Abs(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("computing absolute path: %w", err)
|
||||
}
|
||||
|
||||
err = os.Mkdir(name, perm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating directory: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = errors.WithDeferred(err, os.Remove(name))
|
||||
}
|
||||
}()
|
||||
|
||||
return chmod(name, perm)
|
||||
}
|
||||
|
||||
// mkdirAll is a Windows implementation of [MkdirAll].
|
||||
func mkdirAll(path string, perm os.FileMode) (err error) {
|
||||
parent, _ := filepath.Split(path)
|
||||
|
||||
if parent != "" {
|
||||
err = os.MkdirAll(parent, perm)
|
||||
if err != nil && !errors.Is(err, os.ErrExist) {
|
||||
return fmt.Errorf("creating parent directories: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = mkdir(path, perm)
|
||||
if errors.Is(err, os.ErrExist) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// writeFile is a Windows implementation of [WriteFile].
|
||||
func writeFile(filename string, data []byte, perm os.FileMode) (err error) {
|
||||
file, err := openFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening file: %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, file.Close()) }()
|
||||
|
||||
_, err = file.Write(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing data: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// openFile is a Windows implementation of [OpenFile].
|
||||
func openFile(name string, flag int, perm os.FileMode) (file *os.File, err error) {
|
||||
// Only change permissions if the file not yet exists, but should be
|
||||
// created.
|
||||
if flag&os.O_CREATE == 0 {
|
||||
return os.OpenFile(name, flag, perm)
|
||||
}
|
||||
|
||||
_, err = stat(name)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
defer func() { err = errors.WithDeferred(err, chmod(name, perm)) }()
|
||||
} else {
|
||||
return nil, fmt.Errorf("getting file info: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return os.OpenFile(name, flag, perm)
|
||||
}
|
||||
|
||||
// newWellKnownTrustee returns a trustee for a well-known SID.
|
||||
func newWellKnownTrustee(stype windows.WELL_KNOWN_SID_TYPE) (t windows.TRUSTEE, err error) {
|
||||
sid, err := windows.CreateWellKnownSid(stype)
|
||||
if err != nil {
|
||||
return windows.TRUSTEE{}, fmt.Errorf("creating sid for type %d: %w", stype, err)
|
||||
}
|
||||
|
||||
return windows.TRUSTEE{
|
||||
TrusteeForm: windows.TRUSTEE_IS_SID,
|
||||
TrusteeValue: windows.TrusteeValueFromSID(sid),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UNIX file mode permission bits.
|
||||
const (
|
||||
permRead = 0b100
|
||||
permWrite = 0b010
|
||||
permExecute = 0b001
|
||||
)
|
||||
|
||||
// Windows access masks for appropriate UNIX file mode permission bits and
|
||||
// file types.
|
||||
const (
|
||||
fileReadRights windows.ACCESS_MASK = windows.READ_CONTROL |
|
||||
windows.FILE_READ_DATA |
|
||||
windows.FILE_READ_ATTRIBUTES |
|
||||
windows.FILE_READ_EA |
|
||||
windows.SYNCHRONIZE |
|
||||
windows.ACCESS_SYSTEM_SECURITY
|
||||
|
||||
fileWriteRights windows.ACCESS_MASK = windows.WRITE_DAC |
|
||||
windows.WRITE_OWNER |
|
||||
windows.FILE_WRITE_DATA |
|
||||
windows.FILE_WRITE_ATTRIBUTES |
|
||||
windows.FILE_WRITE_EA |
|
||||
windows.DELETE |
|
||||
windows.FILE_APPEND_DATA |
|
||||
windows.SYNCHRONIZE |
|
||||
windows.ACCESS_SYSTEM_SECURITY
|
||||
|
||||
fileExecuteRights windows.ACCESS_MASK = windows.FILE_EXECUTE
|
||||
|
||||
dirReadRights windows.ACCESS_MASK = windows.READ_CONTROL |
|
||||
windows.FILE_LIST_DIRECTORY |
|
||||
windows.FILE_READ_EA |
|
||||
windows.FILE_READ_ATTRIBUTES<<1 |
|
||||
windows.SYNCHRONIZE |
|
||||
windows.ACCESS_SYSTEM_SECURITY
|
||||
|
||||
dirWriteRights windows.ACCESS_MASK = windows.WRITE_DAC |
|
||||
windows.WRITE_OWNER |
|
||||
windows.DELETE |
|
||||
windows.FILE_WRITE_DATA |
|
||||
windows.FILE_APPEND_DATA |
|
||||
windows.FILE_WRITE_EA |
|
||||
windows.FILE_WRITE_ATTRIBUTES<<1 |
|
||||
windows.SYNCHRONIZE |
|
||||
windows.ACCESS_SYSTEM_SECURITY
|
||||
|
||||
dirExecuteRights windows.ACCESS_MASK = windows.FILE_TRAVERSE
|
||||
)
|
||||
|
||||
// permToMasks converts a UNIX file mode permissions to the corresponding
|
||||
// Windows access masks. The [isDir] argument is used to set specific access
|
||||
// bits for directories.
|
||||
func permToMasks(fm os.FileMode, isDir bool) (owner, group, world windows.ACCESS_MASK) {
|
||||
mask := fm.Perm()
|
||||
|
||||
owner = permToMask(byte((mask>>6)&0b111), isDir)
|
||||
group = permToMask(byte((mask>>3)&0b111), isDir)
|
||||
world = permToMask(byte(mask&0b111), isDir)
|
||||
|
||||
return owner, group, world
|
||||
}
|
||||
|
||||
// permToMask converts a UNIX file mode permission bits within p byte to the
|
||||
// corresponding Windows access mask. The [isDir] argument is used to set
|
||||
// specific access bits for directories.
|
||||
func permToMask(p byte, isDir bool) (mask windows.ACCESS_MASK) {
|
||||
readRights, writeRights, executeRights := fileReadRights, fileWriteRights, fileExecuteRights
|
||||
if isDir {
|
||||
readRights, writeRights, executeRights = dirReadRights, dirWriteRights, dirExecuteRights
|
||||
}
|
||||
|
||||
if p&permRead != 0 {
|
||||
mask |= readRights
|
||||
}
|
||||
if p&permWrite != 0 {
|
||||
mask |= writeRights
|
||||
}
|
||||
if p&permExecute != 0 {
|
||||
mask |= executeRights
|
||||
}
|
||||
|
||||
return mask
|
||||
}
|
||||
|
||||
// masksToPerm converts Windows access masks to the corresponding UNIX file
|
||||
// mode permission bits.
|
||||
func masksToPerm(u, g, o windows.ACCESS_MASK, isDir bool) (perm fs.FileMode) {
|
||||
perm |= fs.FileMode(maskToPerm(u, isDir)) << 6
|
||||
perm |= fs.FileMode(maskToPerm(g, isDir)) << 3
|
||||
perm |= fs.FileMode(maskToPerm(o, isDir))
|
||||
|
||||
return perm
|
||||
}
|
||||
|
||||
// maskToPerm converts a Windows access mask to the corresponding UNIX file
|
||||
// mode permission bits.
|
||||
func maskToPerm(mask windows.ACCESS_MASK, isDir bool) (perm byte) {
|
||||
readMask, writeMask, executeMask := fileReadRights, fileWriteRights, fileExecuteRights
|
||||
if isDir {
|
||||
readMask, writeMask, executeMask = dirReadRights, dirWriteRights, dirExecuteRights
|
||||
}
|
||||
|
||||
// Remove common bits to avoid false positive detection of unset rights.
|
||||
readMask ^= windows.SYNCHRONIZE | windows.ACCESS_SYSTEM_SECURITY
|
||||
writeMask ^= windows.SYNCHRONIZE | windows.ACCESS_SYSTEM_SECURITY
|
||||
|
||||
if mask&readMask != 0 {
|
||||
perm |= permRead
|
||||
}
|
||||
if mask&writeMask != 0 {
|
||||
perm |= permWrite
|
||||
}
|
||||
if mask&executeMask != 0 {
|
||||
perm |= permExecute
|
||||
}
|
||||
|
||||
return perm
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestPermToMasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
perm fs.FileMode
|
||||
wantUser windows.ACCESS_MASK
|
||||
wantGroup windows.ACCESS_MASK
|
||||
wantOther windows.ACCESS_MASK
|
||||
isDir bool
|
||||
}{{
|
||||
name: "all",
|
||||
perm: 0b111_111_111,
|
||||
wantUser: fileReadRights | fileWriteRights | fileExecuteRights,
|
||||
wantGroup: fileReadRights | fileWriteRights | fileExecuteRights,
|
||||
wantOther: fileReadRights | fileWriteRights | fileExecuteRights,
|
||||
isDir: false,
|
||||
}, {
|
||||
name: "user_write",
|
||||
perm: 0b010_000_000,
|
||||
wantUser: fileWriteRights,
|
||||
wantGroup: 0,
|
||||
wantOther: 0,
|
||||
isDir: false,
|
||||
}, {
|
||||
name: "group_read",
|
||||
perm: 0b000_100_000,
|
||||
wantUser: 0,
|
||||
wantGroup: fileReadRights,
|
||||
wantOther: 0,
|
||||
isDir: false,
|
||||
}, {
|
||||
name: "all_dir",
|
||||
perm: 0b111_111_111,
|
||||
wantUser: dirReadRights | dirWriteRights | dirExecuteRights,
|
||||
wantGroup: dirReadRights | dirWriteRights | dirExecuteRights,
|
||||
wantOther: dirReadRights | dirWriteRights | dirExecuteRights,
|
||||
isDir: true,
|
||||
}, {
|
||||
name: "user_write_dir",
|
||||
perm: 0b010_000_000,
|
||||
wantUser: dirWriteRights,
|
||||
wantGroup: 0,
|
||||
wantOther: 0,
|
||||
isDir: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
user, group, other := permToMasks(tc.perm, tc.isDir)
|
||||
assert.Equal(t, tc.wantUser, user)
|
||||
assert.Equal(t, tc.wantGroup, group)
|
||||
assert.Equal(t, tc.wantOther, other)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMasksToPerm(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
user windows.ACCESS_MASK
|
||||
group windows.ACCESS_MASK
|
||||
other windows.ACCESS_MASK
|
||||
wantPerm fs.FileMode
|
||||
isDir bool
|
||||
}{{
|
||||
name: "all",
|
||||
user: fileReadRights | fileWriteRights | fileExecuteRights,
|
||||
group: fileReadRights | fileWriteRights | fileExecuteRights,
|
||||
other: fileReadRights | fileWriteRights | fileExecuteRights,
|
||||
wantPerm: 0b111_111_111,
|
||||
isDir: false,
|
||||
}, {
|
||||
name: "user_write",
|
||||
user: fileWriteRights,
|
||||
group: 0,
|
||||
other: 0,
|
||||
wantPerm: 0b010_000_000,
|
||||
isDir: false,
|
||||
}, {
|
||||
name: "group_read",
|
||||
user: 0,
|
||||
group: fileReadRights,
|
||||
other: 0,
|
||||
wantPerm: 0b000_100_000,
|
||||
isDir: false,
|
||||
}, {
|
||||
name: "no_access",
|
||||
user: 0,
|
||||
group: 0,
|
||||
other: 0,
|
||||
wantPerm: 0,
|
||||
isDir: false,
|
||||
}, {
|
||||
name: "all_dir",
|
||||
user: dirReadRights | dirWriteRights | dirExecuteRights,
|
||||
group: dirReadRights | dirWriteRights | dirExecuteRights,
|
||||
other: dirReadRights | dirWriteRights | dirExecuteRights,
|
||||
wantPerm: 0b111_111_111,
|
||||
isDir: true,
|
||||
}, {
|
||||
name: "user_write_dir",
|
||||
user: dirWriteRights,
|
||||
group: 0,
|
||||
other: 0,
|
||||
wantPerm: 0b010_000_000,
|
||||
isDir: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Don't call [fs.FileMode.Perm] since the result is expected to
|
||||
// contain only the permission bits.
|
||||
assert.Equal(t, tc.wantPerm, masksToPerm(tc.user, tc.group, tc.other, tc.isDir))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
@@ -63,7 +62,9 @@ func newPendingFile(filePath string, mode fs.FileMode) (f PendingFile, err error
|
||||
return nil, fmt.Errorf("opening pending file: %w", err)
|
||||
}
|
||||
|
||||
err = aghos.Chmod(file.Name(), mode)
|
||||
// TODO(e.burkov): The [os.Chmod] implementation is useless on Windows,
|
||||
// investigate if it can be removed.
|
||||
err = os.Chmod(file.Name(), mode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing pending file: %w", err)
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (w *FSWatcher) Add(name string) (err error) {
|
||||
|
||||
// ServiceWithConfig is a fake [agh.ServiceWithConfig] implementation for tests.
|
||||
type ServiceWithConfig[ConfigType any] struct {
|
||||
OnStart func() (err error)
|
||||
OnStart func(ctx context.Context) (err error)
|
||||
OnShutdown func(ctx context.Context) (err error)
|
||||
OnConfig func() (c ConfigType)
|
||||
}
|
||||
@@ -68,8 +68,8 @@ var _ agh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil)
|
||||
|
||||
// Start implements the [agh.ServiceWithConfig] interface for
|
||||
// *ServiceWithConfig.
|
||||
func (s *ServiceWithConfig[_]) Start() (err error) {
|
||||
return s.OnStart()
|
||||
func (s *ServiceWithConfig[_]) Start(ctx context.Context) (err error) {
|
||||
return s.OnStart(ctx)
|
||||
}
|
||||
|
||||
// Shutdown implements the [agh.ServiceWithConfig] interface for
|
||||
|
||||
@@ -2,6 +2,7 @@ package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@@ -9,7 +10,6 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
|
||||
@@ -330,12 +330,14 @@ func (ci *index) size() (n int) {
|
||||
// rangeByName is like [Index.Range] but sorts the persistent clients by name
|
||||
// before iterating ensuring a predictable order.
|
||||
func (ci *index) rangeByName(f func(c *Persistent) (cont bool)) {
|
||||
cs := maps.Values(ci.uidToClient)
|
||||
slices.SortFunc(cs, func(a, b *Persistent) (n int) {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
clients := slices.SortedStableFunc(
|
||||
maps.Values(ci.uidToClient),
|
||||
func(a, b *Persistent) (res int) {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
},
|
||||
)
|
||||
|
||||
for _, c := range cs {
|
||||
for _, c := range clients {
|
||||
if !f(c) {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
@@ -506,7 +505,7 @@ func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
|
||||
|
||||
// RemoveByName removes persistent client information. ok is false if no such
|
||||
// client exists by that name.
|
||||
func (s *Storage) RemoveByName(name string) (ok bool) {
|
||||
func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -516,7 +515,7 @@ func (s *Storage) RemoveByName(name string) (ok bool) {
|
||||
}
|
||||
|
||||
if err := p.CloseUpstreams(); err != nil {
|
||||
log.Error("client storage: removing client %q: %s", p.Name, err)
|
||||
s.logger.ErrorContext(ctx, "removing client", "name", p.Name, slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
s.index.remove(p)
|
||||
|
||||
@@ -735,7 +735,7 @@ func TestStorage_RemoveByName(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.want(t, s.RemoveByName(tc.cliName))
|
||||
tc.want(t, s.RemoveByName(ctx, tc.cliName))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -744,8 +744,8 @@ func TestStorage_RemoveByName(t *testing.T) {
|
||||
err = s.Add(ctx, existingClient)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, s.RemoveByName(existingName))
|
||||
assert.False(t, s.RemoveByName(existingName))
|
||||
assert.True(t, s.RemoveByName(ctx, existingName))
|
||||
assert.False(t, s.RemoveByName(ctx, existingName))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -18,9 +18,11 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/go-ping/ping"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/go-ping/ping"
|
||||
)
|
||||
|
||||
// v4Server is a DHCPv4 server.
|
||||
|
||||
@@ -82,7 +82,7 @@ type Empty struct{}
|
||||
var _ agh.ServiceWithConfig[*Config] = Empty{}
|
||||
|
||||
// Start implements the [Service] interface for Empty.
|
||||
func (Empty) Start() (err error) { return nil }
|
||||
func (Empty) Start(_ context.Context) (err error) { return nil }
|
||||
|
||||
// Shutdown implements the [Service] interface for Empty.
|
||||
func (Empty) Shutdown(_ context.Context) (err error) { return nil }
|
||||
|
||||
@@ -818,6 +818,8 @@ func (s *Server) proxy() (p *proxy.Proxy) {
|
||||
}
|
||||
|
||||
// Reconfigure applies the new configuration to the DNS server.
|
||||
//
|
||||
// TODO(a.garipov): This whole piece of API is weird and needs to be remade.
|
||||
func (s *Server) Reconfigure(conf *ServerConfig) error {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
@@ -831,14 +833,15 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
|
||||
// We wait for some time and hope that this fd will be closed.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// TODO(a.garipov): This whole piece of API is weird and needs to be remade.
|
||||
if s.addrProc != nil {
|
||||
err := s.addrProc.Close()
|
||||
if err != nil {
|
||||
log.Error("dnsforward: closing address processor: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if conf == nil {
|
||||
conf = &s.conf
|
||||
} else {
|
||||
closeErr := s.addrProc.Close()
|
||||
if closeErr != nil {
|
||||
log.Error("dnsforward: closing address processor: %s", closeErr)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): It seems an error here brings the server down, which is
|
||||
|
||||
@@ -500,6 +500,10 @@ func TestServerRace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
const (
|
||||
googleSafeSearch = "forcesafesearch.google.com."
|
||||
)
|
||||
|
||||
safeSearchConf := filtering.SafeSearchConfig{
|
||||
Enabled: true,
|
||||
Google: true,
|
||||
@@ -536,10 +540,17 @@ func TestSafeSearch(t *testing.T) {
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
startDeferStop(t, s)
|
||||
|
||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
pt := testutil.PanicT{}
|
||||
assert.Equal(pt, googleSafeSearch, req.Question[0].Name)
|
||||
|
||||
return aghtest.MatchedResponse(req, dns.TypeA, googleSafeSearch, "1.2.3.4"), nil
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups}
|
||||
|
||||
startDeferStop(t, s)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
client := &dns.Client{}
|
||||
|
||||
yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56})
|
||||
|
||||
@@ -585,15 +596,9 @@ func TestSafeSearch(t *testing.T) {
|
||||
t.Run(tc.host, func(t *testing.T) {
|
||||
req := createTestMessage(tc.host)
|
||||
|
||||
// TODO(a.garipov): Create our own helper for this.
|
||||
var reply *dns.Msg
|
||||
once := &sync.Once{}
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
r, _, errExch := client.Exchange(req, addr)
|
||||
if assert.NoError(c, errExch) {
|
||||
once.Do(func() { reply = r })
|
||||
}
|
||||
}, testTimeout*10, testTimeout)
|
||||
reply, err = dns.Exchange(req, addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tc.wantCNAME != "" {
|
||||
require.Len(t, reply.Answer, 2)
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -33,7 +33,7 @@ func serveHTTPLocally(t *testing.T, h http.Handler) (urlStr string) {
|
||||
require.IsType(t, (*net.TCPAddr)(nil), addr)
|
||||
|
||||
return (&url.URL{
|
||||
Scheme: aghhttp.SchemeHTTP,
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: addr.String(),
|
||||
}).String()
|
||||
}
|
||||
|
||||
@@ -1057,7 +1057,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
err = aghos.MkdirAll(filepath.Join(d.conf.DataDir, filterDir), aghos.DefaultPermDir)
|
||||
err = os.MkdirAll(filepath.Join(d.conf.DataDir, filterDir), aghos.DefaultPermDir)
|
||||
if err != nil {
|
||||
d.Close()
|
||||
|
||||
|
||||
@@ -13,10 +13,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@@ -26,7 +26,7 @@ func (d *DNSFilter) validateFilterURL(urlStr string) (err error) {
|
||||
|
||||
if filepath.IsAbs(urlStr) {
|
||||
urlStr = filepath.Clean(urlStr)
|
||||
_, err = aghos.Stat(urlStr)
|
||||
_, err = os.Stat(urlStr)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
@@ -41,19 +41,14 @@ func (d *DNSFilter) validateFilterURL(urlStr string) (err error) {
|
||||
|
||||
u, err := url.ParseRequestURI(urlStr)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
if s := u.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS {
|
||||
return &url.Error{
|
||||
Op: "Check scheme",
|
||||
URL: urlStr,
|
||||
Err: fmt.Errorf("only %v allowed", []string{
|
||||
aghhttp.SchemeHTTP,
|
||||
aghhttp.SchemeHTTPS,
|
||||
}),
|
||||
}
|
||||
err = urlutil.ValidateHTTPURL(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -3,11 +3,12 @@ package rulelist
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
"github.com/c2h5oh/datasize"
|
||||
@@ -18,6 +19,9 @@ import (
|
||||
//
|
||||
// TODO(a.garipov): Merge with [TextEngine] in some way?
|
||||
type Engine struct {
|
||||
// logger is used to log the operation of the engine and its refreshes.
|
||||
logger *slog.Logger
|
||||
|
||||
// mu protects engine and storage.
|
||||
//
|
||||
// TODO(a.garipov): See if anything else should be protected.
|
||||
@@ -29,8 +33,7 @@ type Engine struct {
|
||||
// storage is the filtering-rule storage. It is saved here to close it.
|
||||
storage *filterlist.RuleStorage
|
||||
|
||||
// name is the human-readable name of the engine, like "allowed", "blocked",
|
||||
// or "custom".
|
||||
// name is the human-readable name of the engine.
|
||||
name string
|
||||
|
||||
// filters is the data about rule filters in this engine.
|
||||
@@ -40,12 +43,15 @@ type Engine struct {
|
||||
// EngineConfig is the configuration for rule-list filtering engines created by
|
||||
// combining refreshable filters.
|
||||
type EngineConfig struct {
|
||||
// Name is the human-readable name of this engine, like "allowed",
|
||||
// "blocked", or "custom".
|
||||
// Logger is used to log the operation of the engine. It must not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// name is the human-readable name of the engine; see [EngineNameAllow] and
|
||||
// similar constants.
|
||||
Name string
|
||||
|
||||
// Filters is the data about rule lists in this engine. There must be no
|
||||
// other references to the elements of this slice.
|
||||
// other references to the items of this slice. Each item must not be nil.
|
||||
Filters []*Filter
|
||||
}
|
||||
|
||||
@@ -53,6 +59,7 @@ type EngineConfig struct {
|
||||
// refreshed, so a refresh should be performed before use.
|
||||
func NewEngine(c *EngineConfig) (e *Engine) {
|
||||
return &Engine{
|
||||
logger: c.Logger,
|
||||
mu: &sync.RWMutex{},
|
||||
name: c.Name,
|
||||
filters: c.Filters,
|
||||
@@ -85,7 +92,7 @@ func (e *Engine) FilterRequest(
|
||||
}
|
||||
|
||||
// currentEngine returns the current filtering engine.
|
||||
func (e *Engine) currentEngine() (enging *urlfilter.DNSEngine) {
|
||||
func (e *Engine) currentEngine() (engine *urlfilter.DNSEngine) {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
@@ -96,7 +103,7 @@ func (e *Engine) currentEngine() (enging *urlfilter.DNSEngine) {
|
||||
// parseBuf, cli, cacheDir, and maxSize are used for updates of rule-list
|
||||
// filters; see [Filter.Refresh].
|
||||
//
|
||||
// TODO(a.garipov): Unexport and test in an internal test or through enigne
|
||||
// TODO(a.garipov): Unexport and test in an internal test or through engine
|
||||
// tests.
|
||||
func (e *Engine) Refresh(
|
||||
ctx context.Context,
|
||||
@@ -115,20 +122,20 @@ func (e *Engine) Refresh(
|
||||
}
|
||||
|
||||
if len(filtersToRefresh) == 0 {
|
||||
log.Info("filtering: updating engine %q: no rule-list filters", e.name)
|
||||
e.logger.InfoContext(ctx, "updating: no rule-list filters")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
engRefr := &engineRefresh{
|
||||
httpCli: cli,
|
||||
cacheDir: cacheDir,
|
||||
engineName: e.name,
|
||||
parseBuf: parseBuf,
|
||||
maxSize: maxSize,
|
||||
logger: e.logger,
|
||||
httpCli: cli,
|
||||
cacheDir: cacheDir,
|
||||
parseBuf: parseBuf,
|
||||
maxSize: maxSize,
|
||||
}
|
||||
|
||||
ruleLists, errs := engRefr.process(ctx, e.filters)
|
||||
ruleLists, errs := engRefr.process(ctx, filtersToRefresh)
|
||||
if isOneTimeoutError(errs) {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
@@ -141,14 +148,14 @@ func (e *Engine) Refresh(
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
e.resetStorage(storage)
|
||||
e.resetStorage(ctx, storage)
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// resetStorage sets e.storage and e.engine and closes the previous storage.
|
||||
// Errors from closing the previous storage are logged.
|
||||
func (e *Engine) resetStorage(storage *filterlist.RuleStorage) {
|
||||
func (e *Engine) resetStorage(ctx context.Context, storage *filterlist.RuleStorage) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
@@ -161,7 +168,7 @@ func (e *Engine) resetStorage(storage *filterlist.RuleStorage) {
|
||||
|
||||
err := prevStorage.Close()
|
||||
if err != nil {
|
||||
log.Error("filtering: engine %q: closing old storage: %s", e.name, err)
|
||||
e.logger.WarnContext(ctx, "closing old storage", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,11 +186,11 @@ func isOneTimeoutError(errs []error) (ok bool) {
|
||||
|
||||
// engineRefresh represents a single ongoing engine refresh.
|
||||
type engineRefresh struct {
|
||||
httpCli *http.Client
|
||||
cacheDir string
|
||||
engineName string
|
||||
parseBuf []byte
|
||||
maxSize datasize.ByteSize
|
||||
logger *slog.Logger
|
||||
httpCli *http.Client
|
||||
cacheDir string
|
||||
parseBuf []byte
|
||||
maxSize datasize.ByteSize
|
||||
}
|
||||
|
||||
// process runs updates of all given rule-list filters. All errors are logged
|
||||
@@ -216,12 +223,12 @@ func (r *engineRefresh) process(
|
||||
errs = append(errs, err)
|
||||
|
||||
// Also log immediately, since the update can take a lot of time.
|
||||
log.Error(
|
||||
"filtering: updating engine %q: rule list %s from url %q: %s\n",
|
||||
r.engineName,
|
||||
f.uid,
|
||||
f.url,
|
||||
err,
|
||||
r.logger.ErrorContext(
|
||||
ctx,
|
||||
"updating rule list",
|
||||
"uid", f.uid,
|
||||
"url", f.url,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -237,17 +244,17 @@ func (r *engineRefresh) processFilter(ctx context.Context, f *Filter) (err error
|
||||
}
|
||||
|
||||
if prevChecksum == parseRes.Checksum {
|
||||
log.Info("filtering: engine %q: filter %q: no change", r.engineName, f.uid)
|
||||
r.logger.InfoContext(ctx, "no change in filter", "uid", f.uid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info(
|
||||
"filtering: updated engine %q: filter %q: %d bytes, %d rules",
|
||||
r.engineName,
|
||||
f.uid,
|
||||
parseRes.BytesWritten,
|
||||
parseRes.RulesCount,
|
||||
r.logger.InfoContext(
|
||||
ctx,
|
||||
"filter updated",
|
||||
"uid", f.uid,
|
||||
"bytes", parseRes.BytesWritten,
|
||||
"rules", parseRes.RulesCount,
|
||||
)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/miekg/dns"
|
||||
@@ -13,6 +14,8 @@ import (
|
||||
)
|
||||
|
||||
func TestEngine_Refresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cacheDir := t.TempDir()
|
||||
|
||||
fileURL, srvURL := newFilterLocations(t, cacheDir, testRuleTextBlocked, testRuleTextBlocked2)
|
||||
@@ -21,6 +24,7 @@ func TestEngine_Refresh(t *testing.T) {
|
||||
httpFlt := newFilter(t, srvURL, "HTTP Filter")
|
||||
|
||||
eng := rulelist.NewEngine(&rulelist.EngineConfig{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Name: "Engine",
|
||||
Filters: []*rulelist.Filter{fileFlt, httpFlt},
|
||||
})
|
||||
|
||||
@@ -105,7 +105,7 @@ func NewFilter(c *FilterConfig) (f *Filter, err error) {
|
||||
// buffer used to parse information from the data. cli and maxSize are only
|
||||
// used when f is a URL-based list.
|
||||
//
|
||||
// TODO(a.garipov): Unexport and test in an internal test or through enigne
|
||||
// TODO(a.garipov): Unexport and test in an internal test or through engine
|
||||
// tests.
|
||||
//
|
||||
// TODO(a.garipov): Consider not returning parseRes.
|
||||
|
||||
@@ -8,12 +8,15 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFilter_Refresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cacheDir := t.TempDir()
|
||||
uid := rulelist.MustNewUID()
|
||||
|
||||
@@ -37,7 +40,7 @@ func TestFilter_Refresh(t *testing.T) {
|
||||
}, {
|
||||
name: "file",
|
||||
url: &url.URL{
|
||||
Scheme: "file",
|
||||
Scheme: urlutil.SchemeFile,
|
||||
Path: fileURL.Path,
|
||||
},
|
||||
wantNewErrMsg: "",
|
||||
@@ -49,6 +52,8 @@ func TestFilter_Refresh(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f, err := rulelist.NewFilter(&rulelist.FilterConfig{
|
||||
URL: tc.url,
|
||||
Name: tc.name,
|
||||
|
||||
@@ -71,3 +71,10 @@ var _ fmt.Stringer = UID{}
|
||||
func (id UID) String() (s string) {
|
||||
return uuid.UUID(id).String()
|
||||
}
|
||||
|
||||
// Common engine names.
|
||||
const (
|
||||
EngineNameAllow = "allow"
|
||||
EngineNameBlock = "block"
|
||||
EngineNameCustom = "custom"
|
||||
)
|
||||
|
||||
@@ -6,20 +6,16 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
@@ -31,6 +27,7 @@ const testTitle = "Test Title"
|
||||
|
||||
// Common rule texts for tests.
|
||||
const (
|
||||
testRuleTextAllowed = "||allowed.example^\n"
|
||||
testRuleTextBadTab = "||bad-tab-and-comment.example^\t# A comment.\n"
|
||||
testRuleTextBlocked = "||blocked.example^\n"
|
||||
testRuleTextBlocked2 = "||blocked-2.example^\n"
|
||||
@@ -79,8 +76,16 @@ func newFilterLocations(
|
||||
fileData string,
|
||||
httpData string,
|
||||
) (fileURL, srvURL *url.URL) {
|
||||
filePath := filepath.Join(cacheDir, "initial.txt")
|
||||
err := os.WriteFile(filePath, []byte(fileData), 0o644)
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(cacheDir, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = f.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
filePath := f.Name()
|
||||
err = os.WriteFile(filePath, []byte(fileData), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
@@ -88,7 +93,7 @@ func newFilterLocations(
|
||||
})
|
||||
|
||||
fileURL = &url.URL{
|
||||
Scheme: "file",
|
||||
Scheme: urlutil.SchemeFile,
|
||||
Path: filePath,
|
||||
}
|
||||
|
||||
|
||||
112
internal/filtering/rulelist/storage.go
Normal file
112
internal/filtering/rulelist/storage.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package rulelist
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/c2h5oh/datasize"
|
||||
)
|
||||
|
||||
// Storage contains the main filtering engines, including the allowlist, the
|
||||
// blocklist, and the user's custom filtering rules.
|
||||
type Storage struct {
|
||||
// refreshMu makes sure that only one update takes place at a time.
|
||||
refreshMu *sync.Mutex
|
||||
|
||||
allow *Engine
|
||||
block *Engine
|
||||
custom *TextEngine
|
||||
httpCli *http.Client
|
||||
cacheDir string
|
||||
parseBuf []byte
|
||||
maxSize datasize.ByteSize
|
||||
}
|
||||
|
||||
// StorageConfig is the configuration for the filtering-engine storage.
|
||||
type StorageConfig struct {
|
||||
// Logger is used to log the operation of the storage. It must not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// HTTPClient is the HTTP client used to perform updates of rule lists.
|
||||
// It must not be nil.
|
||||
HTTPClient *http.Client
|
||||
|
||||
// CacheDir is the path to the directory used to cache rule-list files.
|
||||
// It must be set.
|
||||
CacheDir string
|
||||
|
||||
// AllowFilters are the filtering-rule lists used to exclude domain names
|
||||
// from the filtering. Each item must not be nil.
|
||||
AllowFilters []*Filter
|
||||
|
||||
// BlockFilters are the filtering-rule lists used to block domain names.
|
||||
// Each item must not be nil.
|
||||
BlockFilters []*Filter
|
||||
|
||||
// CustomRules contains custom rules of the user. They have priority over
|
||||
// both allow- and blacklist rules.
|
||||
CustomRules []string
|
||||
|
||||
// MaxRuleListTextSize is the maximum size of a rule-list file. It must be
|
||||
// greater than zero.
|
||||
MaxRuleListTextSize datasize.ByteSize
|
||||
}
|
||||
|
||||
// NewStorage creates a new filtering-engine storage. The engines are not
|
||||
// refreshed, so a refresh should be performed before use.
|
||||
func NewStorage(c *StorageConfig) (s *Storage, err error) {
|
||||
custom, err := NewTextEngine(&TextEngineConfig{
|
||||
Name: EngineNameCustom,
|
||||
Rules: c.CustomRules,
|
||||
ID: URLFilterIDCustom,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating custom engine: %w", err)
|
||||
}
|
||||
|
||||
return &Storage{
|
||||
refreshMu: &sync.Mutex{},
|
||||
allow: NewEngine(&EngineConfig{
|
||||
Logger: c.Logger.With("engine", EngineNameAllow),
|
||||
Name: EngineNameAllow,
|
||||
Filters: c.AllowFilters,
|
||||
}),
|
||||
block: NewEngine(&EngineConfig{
|
||||
Logger: c.Logger.With("engine", EngineNameBlock),
|
||||
Name: EngineNameBlock,
|
||||
Filters: c.BlockFilters,
|
||||
}),
|
||||
custom: custom,
|
||||
httpCli: c.HTTPClient,
|
||||
cacheDir: c.CacheDir,
|
||||
parseBuf: make([]byte, DefaultRuleBufSize),
|
||||
maxSize: c.MaxRuleListTextSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying rule-list engines.
|
||||
func (s *Storage) Close() (err error) {
|
||||
// Don't wrap the errors since they are informative enough as is.
|
||||
return errors.Join(
|
||||
s.allow.Close(),
|
||||
s.block.Close(),
|
||||
)
|
||||
}
|
||||
|
||||
// Refresh updates all engines in s.
|
||||
//
|
||||
// TODO(a.garipov): Refresh allow and block separately?
|
||||
func (s *Storage) Refresh(ctx context.Context) (err error) {
|
||||
s.refreshMu.Lock()
|
||||
defer s.refreshMu.Unlock()
|
||||
|
||||
// Don't wrap the errors since they are informative enough as is.
|
||||
return errors.Join(
|
||||
s.allow.Refresh(ctx, s.parseBuf, s.httpCli, s.cacheDir, s.maxSize),
|
||||
s.block.Refresh(ctx, s.parseBuf, s.httpCli, s.cacheDir, s.maxSize),
|
||||
)
|
||||
}
|
||||
49
internal/filtering/rulelist/storage_test.go
Normal file
49
internal/filtering/rulelist/storage_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package rulelist_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/c2h5oh/datasize"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStorage_Refresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cacheDir := t.TempDir()
|
||||
|
||||
allowedFileURL, _ := newFilterLocations(t, cacheDir, testRuleTextAllowed, "")
|
||||
allowedFlt := newFilter(t, allowedFileURL, "Allowed 1")
|
||||
|
||||
blockedFileURL, _ := newFilterLocations(t, cacheDir, testRuleTextBlocked, "")
|
||||
blockedFlt := newFilter(t, blockedFileURL, "Blocked 1")
|
||||
|
||||
strg, err := rulelist.NewStorage(&rulelist.StorageConfig{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: testTimeout,
|
||||
},
|
||||
CacheDir: cacheDir,
|
||||
AllowFilters: []*rulelist.Filter{
|
||||
allowedFlt,
|
||||
},
|
||||
BlockFilters: []*rulelist.Filter{
|
||||
blockedFlt,
|
||||
},
|
||||
CustomRules: []string{
|
||||
testRuleTextBlocked2,
|
||||
},
|
||||
MaxRuleListTextSize: 1 * datasize.KB,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, strg.Close)
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
err = strg.Refresh(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -20,15 +20,15 @@ type TextEngine struct {
|
||||
// storage is the filtering-rule storage. It is saved here to close it.
|
||||
storage *filterlist.RuleStorage
|
||||
|
||||
// name is the human-readable name of the engine, like "custom".
|
||||
// name is the human-readable name of the engine.
|
||||
name string
|
||||
}
|
||||
|
||||
// TextEngineConfig is the configuration for a rule-list filtering engine
|
||||
// created from a filtering rule text.
|
||||
type TextEngineConfig struct {
|
||||
// Name is the human-readable name of this engine, like "allowed",
|
||||
// "blocked", or "custom".
|
||||
// name is the human-readable name of the engine; see [EngineNameAllow] and
|
||||
// similar constants.
|
||||
Name string
|
||||
|
||||
// Rules is the text of the filtering rules for this engine.
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
)
|
||||
|
||||
func TestNewTextEngine(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
eng, err := rulelist.NewTextEngine(&rulelist.TextEngineConfig{
|
||||
Name: "RulesEngine",
|
||||
Rules: []string{
|
||||
|
||||
@@ -91,10 +91,7 @@ func InitAuth(
|
||||
}
|
||||
var err error
|
||||
|
||||
opts := *bbolt.DefaultOptions
|
||||
opts.OpenFile = aghos.OpenFile
|
||||
|
||||
a.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, &opts)
|
||||
a.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, nil)
|
||||
if err != nil {
|
||||
log.Error("auth: open DB: %s: %s", dbFilename, err)
|
||||
if err.Error() == "invalid argument" {
|
||||
|
||||
@@ -369,7 +369,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
if !clients.storage.RemoveByName(cj.Name) {
|
||||
if !clients.storage.RemoveByName(r.Context(), cj.Name) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
|
||||
|
||||
return
|
||||
|
||||
@@ -162,6 +162,12 @@ type configuration struct {
|
||||
// SchemaVersion is the version of the configuration schema. See
|
||||
// [configmigrate.LastSchemaVersion].
|
||||
SchemaVersion uint `yaml:"schema_version"`
|
||||
|
||||
// UnsafeUseCustomUpdateIndexURL is the URL to the custom update index.
|
||||
//
|
||||
// NOTE: It's only exists for testing purposes and should not be used in
|
||||
// release.
|
||||
UnsafeUseCustomUpdateIndexURL bool `yaml:"unsafe_use_custom_update_index_url,omitempty"`
|
||||
}
|
||||
|
||||
// httpConfig is a block with HTTP configuration params.
|
||||
@@ -708,7 +714,7 @@ func (c *configuration) write() (err error) {
|
||||
return fmt.Errorf("generating config file: %w", err)
|
||||
}
|
||||
|
||||
err = aghos.WriteFile(confPath, buf.Bytes(), aghos.DefaultPermFile)
|
||||
err = maybe.WriteFile(confPath, buf.Bytes(), aghos.DefaultPermFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing config file: %w", err)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
)
|
||||
|
||||
@@ -376,7 +377,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
//
|
||||
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
|
||||
originURL := &url.URL{
|
||||
Scheme: aghhttp.SchemeHTTP,
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: r.Host,
|
||||
}
|
||||
|
||||
@@ -395,7 +396,7 @@ func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL)
|
||||
}
|
||||
|
||||
return &url.URL{
|
||||
Scheme: aghhttp.SchemeHTTPS,
|
||||
Scheme: urlutil.SchemeHTTPS,
|
||||
Host: hostPort,
|
||||
Path: u.Path,
|
||||
RawQuery: u.RawQuery,
|
||||
|
||||
@@ -75,30 +75,31 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
// update server.
|
||||
func (web *webAPI) requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
|
||||
updater := web.conf.updater
|
||||
for i := 0; i != 3; i++ {
|
||||
for range 3 {
|
||||
resp.VersionInfo, err = updater.VersionInfo(recheck)
|
||||
if err != nil {
|
||||
var terr temporaryError
|
||||
if errors.As(err, &terr) && terr.Temporary() {
|
||||
// Temporary network error. This case may happen while we're
|
||||
// restarting our DNS server. Log and sleep for some time.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/934.
|
||||
d := time.Duration(i) * time.Second
|
||||
log.Info("update: temp net error: %q; sleeping for %s and retrying", err, d)
|
||||
time.Sleep(d)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
var terr temporaryError
|
||||
if errors.As(err, &terr) && terr.Temporary() {
|
||||
// Temporary network error. This case may happen while we're
|
||||
// restarting our DNS server. Log and sleep for some time.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/934.
|
||||
const sleepTime = 2 * time.Second
|
||||
|
||||
log.Info("update: temp net error: %v; sleeping for %s and retrying", err, sleepTime)
|
||||
time.Sleep(sleepTime)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
vcu := updater.VersionCheckURL()
|
||||
|
||||
return fmt.Errorf("getting version info from %s: %w", vcu, err)
|
||||
return fmt.Errorf("getting version info: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -47,11 +48,11 @@ func onConfigModified() {
|
||||
// initDNS updates all the fields of the [Context] needed to initialize the DNS
|
||||
// server and initializes it at last. It also must not be called unless
|
||||
// [config] and [Context] are initialized. l must not be nil.
|
||||
func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
|
||||
func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) {
|
||||
anonymizer := config.anonymizer()
|
||||
|
||||
statsConf := stats.Config{
|
||||
Logger: l.With(slogutil.KeyPrefix, "stats"),
|
||||
Logger: baseLogger.With(slogutil.KeyPrefix, "stats"),
|
||||
Filename: filepath.Join(statsDir, "stats.db"),
|
||||
Limit: config.Stats.Interval.Duration,
|
||||
ConfigModified: onConfigModified,
|
||||
@@ -72,6 +73,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
|
||||
}
|
||||
|
||||
conf := querylog.Config{
|
||||
Logger: baseLogger.With(slogutil.KeyPrefix, "querylog"),
|
||||
Anonymizer: anonymizer,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
@@ -112,7 +114,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
|
||||
anonymizer,
|
||||
httpRegister,
|
||||
tlsConf,
|
||||
l,
|
||||
baseLogger,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -371,7 +373,7 @@ func getDNSEncryption() (de dnsEncryption) {
|
||||
}
|
||||
|
||||
de.https = (&url.URL{
|
||||
Scheme: "https",
|
||||
Scheme: urlutil.SchemeHTTPS,
|
||||
Host: addr,
|
||||
Path: "/dns-query",
|
||||
}).String()
|
||||
@@ -456,7 +458,8 @@ func startDNSServer() error {
|
||||
Context.filters.EnableFilters(false)
|
||||
|
||||
// TODO(s.chzhen): Pass context.
|
||||
err := Context.clients.Start(context.TODO())
|
||||
ctx := context.TODO()
|
||||
err := Context.clients.Start(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting clients container: %w", err)
|
||||
}
|
||||
@@ -468,7 +471,11 @@ func startDNSServer() error {
|
||||
|
||||
Context.filters.Start()
|
||||
Context.stats.Start()
|
||||
Context.queryLog.Start()
|
||||
|
||||
err = Context.queryLog.Start(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting query log: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -524,12 +531,16 @@ func closeDNSServer() {
|
||||
if Context.stats != nil {
|
||||
err := Context.stats.Close()
|
||||
if err != nil {
|
||||
log.Debug("closing stats: %s", err)
|
||||
log.Error("closing stats: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if Context.queryLog != nil {
|
||||
Context.queryLog.Close()
|
||||
// TODO(s.chzhen): Pass context.
|
||||
err := Context.queryLog.Shutdown(context.TODO())
|
||||
if err != nil {
|
||||
log.Error("closing query log: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("all dns modules are closed")
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
@@ -21,7 +20,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
@@ -42,6 +40,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
)
|
||||
|
||||
@@ -159,7 +158,7 @@ func setupContext(opts options) (err error) {
|
||||
|
||||
if Context.firstRun {
|
||||
log.Info("This is the first time AdGuard Home is launched")
|
||||
checkPermissions()
|
||||
checkNetworkPermissions()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -495,11 +494,42 @@ func checkPorts() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// isUpdateEnabled returns true if the update is enabled for current
|
||||
// configuration. It also logs the decision. customURL should be true if the
|
||||
// updater is using a custom URL.
|
||||
func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customURL bool) (ok bool) {
|
||||
if opts.disableUpdate {
|
||||
l.DebugContext(ctx, "updates are disabled by command-line option")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
switch version.Channel() {
|
||||
case
|
||||
version.ChannelDevelopment,
|
||||
version.ChannelCandidate:
|
||||
if customURL {
|
||||
l.DebugContext(ctx, "updates are enabled because custom url is used")
|
||||
} else {
|
||||
l.DebugContext(ctx, "updates are disabled for development and candidate builds")
|
||||
}
|
||||
|
||||
return customURL
|
||||
default:
|
||||
l.DebugContext(ctx, "updates are enabled")
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// initWeb initializes the web module.
|
||||
func initWeb(
|
||||
ctx context.Context,
|
||||
opts options,
|
||||
clientBuildFS fs.FS,
|
||||
upd *updater.Updater,
|
||||
l *slog.Logger,
|
||||
customURL bool,
|
||||
) (web *webAPI, err error) {
|
||||
var clientFS fs.FS
|
||||
if opts.localFrontend {
|
||||
@@ -513,17 +543,7 @@ func initWeb(
|
||||
}
|
||||
}
|
||||
|
||||
disableUpdate := opts.disableUpdate
|
||||
switch version.Channel() {
|
||||
case
|
||||
version.ChannelDevelopment,
|
||||
version.ChannelCandidate:
|
||||
disableUpdate = true
|
||||
}
|
||||
|
||||
if disableUpdate {
|
||||
log.Info("AdGuard Home updates are disabled")
|
||||
}
|
||||
disableUpdate := !isUpdateEnabled(ctx, l, &opts, customURL)
|
||||
|
||||
webConf := &webConfig{
|
||||
updater: upd,
|
||||
@@ -544,7 +564,7 @@ func initWeb(
|
||||
|
||||
web = newWebAPI(webConf, l)
|
||||
if web == nil {
|
||||
return nil, fmt.Errorf("initializing web: %w", err)
|
||||
return nil, errors.Error("can not initialize web")
|
||||
}
|
||||
|
||||
return web, nil
|
||||
@@ -557,6 +577,8 @@ func fatalOnError(err error) {
|
||||
}
|
||||
|
||||
// run configures and starts AdGuard Home.
|
||||
//
|
||||
// TODO(e.burkov): Make opts a pointer.
|
||||
func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
// Configure working dir.
|
||||
err := initWorkingDir(opts)
|
||||
@@ -604,33 +626,13 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
execPath, err := os.Executable()
|
||||
fatalOnError(errors.Annotate(err, "getting executable path: %w"))
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
// TODO(a.garipov): Make configurable.
|
||||
Host: "static.adtidy.org",
|
||||
Path: path.Join("adguardhome", version.Channel(), "version.json"),
|
||||
}
|
||||
|
||||
confPath := configFilePath()
|
||||
log.Debug("using config path %q for updater", confPath)
|
||||
|
||||
upd := updater.NewUpdater(&updater.Config{
|
||||
Client: config.Filtering.HTTPClient,
|
||||
Version: version.Version(),
|
||||
Channel: version.Channel(),
|
||||
GOARCH: runtime.GOARCH,
|
||||
GOOS: runtime.GOOS,
|
||||
GOARM: version.GOARM(),
|
||||
GOMIPS: version.GOMIPS(),
|
||||
WorkDir: Context.workDir,
|
||||
ConfName: confPath,
|
||||
ExecPath: execPath,
|
||||
VersionCheckURL: u.String(),
|
||||
})
|
||||
upd, customURL := newUpdater(ctx, slogLogger, Context.workDir, confPath, execPath, config)
|
||||
|
||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||
// effect.
|
||||
cmdlineUpdate(opts, upd, slogLogger)
|
||||
cmdlineUpdate(ctx, slogLogger, opts, upd)
|
||||
|
||||
if !Context.firstRun {
|
||||
// Save the updated config.
|
||||
@@ -643,7 +645,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
}
|
||||
|
||||
dataDir := Context.getDataDir()
|
||||
err = aghos.MkdirAll(dataDir, aghos.DefaultPermDir)
|
||||
err = os.MkdirAll(dataDir, aghos.DefaultPermDir)
|
||||
fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dataDir))
|
||||
|
||||
GLMode = opts.glinetMode
|
||||
@@ -658,7 +660,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
Context.web, err = initWeb(opts, clientBuildFS, upd, slogLogger)
|
||||
Context.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, customURL)
|
||||
fatalOnError(err)
|
||||
|
||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
|
||||
@@ -686,18 +688,87 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
}
|
||||
}
|
||||
|
||||
if permcheck.NeedsMigration(confPath) {
|
||||
permcheck.Migrate(Context.workDir, dataDir, statsDir, querylogDir, confPath)
|
||||
if !opts.noPermCheck {
|
||||
checkPermissions(ctx, slogLogger, Context.workDir, confPath, dataDir, statsDir, querylogDir)
|
||||
}
|
||||
|
||||
permcheck.Check(Context.workDir, dataDir, statsDir, querylogDir, confPath)
|
||||
|
||||
Context.web.start()
|
||||
|
||||
// Wait for other goroutines to complete their job.
|
||||
<-done
|
||||
}
|
||||
|
||||
// newUpdater creates a new AdGuard Home updater. customURL is true if the user
|
||||
// has specified a custom version announcement URL.
|
||||
func newUpdater(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
workDir string,
|
||||
confPath string,
|
||||
execPath string,
|
||||
config *configuration,
|
||||
) (upd *updater.Updater, customURL bool) {
|
||||
// envName is the name of the environment variable that can be used to
|
||||
// override the default version check URL.
|
||||
const envName = "ADGUARD_HOME_TEST_UPDATE_VERSION_URL"
|
||||
|
||||
customURLStr := os.Getenv(envName)
|
||||
|
||||
var versionURL *url.URL
|
||||
switch {
|
||||
case version.Channel() == version.ChannelRelease:
|
||||
// Only enable custom version URL for development builds.
|
||||
l.DebugContext(ctx, "custom version url is disabled for release builds")
|
||||
case !config.UnsafeUseCustomUpdateIndexURL:
|
||||
l.DebugContext(ctx, "custom version url is disabled in config")
|
||||
default:
|
||||
versionURL, _ = url.Parse(customURLStr)
|
||||
}
|
||||
|
||||
err := urlutil.ValidateHTTPURL(versionURL)
|
||||
if customURL = err == nil; !customURL {
|
||||
l.DebugContext(ctx, "parsing custom version url", slogutil.KeyError, err)
|
||||
|
||||
versionURL = updater.DefaultVersionURL()
|
||||
}
|
||||
|
||||
l.DebugContext(ctx, "creating updater", "config_path", confPath)
|
||||
|
||||
return updater.NewUpdater(&updater.Config{
|
||||
Client: config.Filtering.HTTPClient,
|
||||
Version: version.Version(),
|
||||
Channel: version.Channel(),
|
||||
GOARCH: runtime.GOARCH,
|
||||
GOOS: runtime.GOOS,
|
||||
GOARM: version.GOARM(),
|
||||
GOMIPS: version.GOMIPS(),
|
||||
WorkDir: workDir,
|
||||
ConfName: confPath,
|
||||
ExecPath: execPath,
|
||||
VersionCheckURL: versionURL,
|
||||
}), customURL
|
||||
}
|
||||
|
||||
// checkPermissions checks and migrates permissions of the files and directories
|
||||
// used by AdGuard Home, if needed.
|
||||
func checkPermissions(
|
||||
ctx context.Context,
|
||||
baseLogger *slog.Logger,
|
||||
workDir string,
|
||||
confPath string,
|
||||
dataDir string,
|
||||
statsDir string,
|
||||
querylogDir string,
|
||||
) {
|
||||
l := baseLogger.With(slogutil.KeyPrefix, "permcheck")
|
||||
|
||||
if permcheck.NeedsMigration(ctx, l, workDir, confPath) {
|
||||
permcheck.Migrate(ctx, l, workDir, dataDir, statsDir, querylogDir, confPath)
|
||||
}
|
||||
|
||||
permcheck.Check(ctx, l, workDir, dataDir, statsDir, querylogDir, confPath)
|
||||
}
|
||||
|
||||
// initUsers initializes context auth module. Clears config users field.
|
||||
func initUsers() (auth *Auth, err error) {
|
||||
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
|
||||
@@ -757,8 +828,9 @@ func startMods(l *slog.Logger) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the current user permissions are enough to run AdGuard Home
|
||||
func checkPermissions() {
|
||||
// checkNetworkPermissions checks if the current user permissions are enough to
|
||||
// use the required networking functionality.
|
||||
func checkNetworkPermissions() {
|
||||
log.Info("Checking if AdGuard Home has necessary permissions")
|
||||
|
||||
if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil {
|
||||
@@ -936,12 +1008,12 @@ func printHTTPAddresses(proto string) {
|
||||
}
|
||||
|
||||
port := config.HTTPConfig.Address.Port()
|
||||
if proto == aghhttp.SchemeHTTPS {
|
||||
if proto == urlutil.SchemeHTTPS {
|
||||
port = tlsConf.PortHTTPS
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
|
||||
if proto == aghhttp.SchemeHTTPS && tlsConf.ServerName != "" {
|
||||
if proto == urlutil.SchemeHTTPS && tlsConf.ServerName != "" {
|
||||
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS)
|
||||
|
||||
return
|
||||
@@ -1001,7 +1073,7 @@ type jsonError struct {
|
||||
}
|
||||
|
||||
// cmdlineUpdate updates current application and exits. l must not be nil.
|
||||
func cmdlineUpdate(opts options, upd *updater.Updater, l *slog.Logger) {
|
||||
func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updater.Updater) {
|
||||
if !opts.performUpdate {
|
||||
return
|
||||
}
|
||||
@@ -1014,20 +1086,19 @@ func cmdlineUpdate(opts options, upd *updater.Updater, l *slog.Logger) {
|
||||
err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}, l)
|
||||
fatalOnError(err)
|
||||
|
||||
log.Info("cmdline update: performing update")
|
||||
l.InfoContext(ctx, "performing update via cli")
|
||||
|
||||
info, err := upd.VersionInfo(true)
|
||||
if err != nil {
|
||||
vcu := upd.VersionCheckURL()
|
||||
log.Error("getting version info from %s: %s", vcu, err)
|
||||
l.ErrorContext(ctx, "getting version info", slogutil.KeyError, err)
|
||||
|
||||
os.Exit(1)
|
||||
os.Exit(osutil.ExitCodeFailure)
|
||||
}
|
||||
|
||||
if info.NewVersion == version.Version() {
|
||||
log.Info("no updates available")
|
||||
l.InfoContext(ctx, "no updates available")
|
||||
|
||||
os.Exit(0)
|
||||
os.Exit(osutil.ExitCodeSuccess)
|
||||
}
|
||||
|
||||
err = upd.Update(Context.firstRun)
|
||||
@@ -1035,10 +1106,10 @@ func cmdlineUpdate(opts options, upd *updater.Updater, l *slog.Logger) {
|
||||
|
||||
err = restartService()
|
||||
if err != nil {
|
||||
log.Debug("restarting service: %s", err)
|
||||
log.Info("AdGuard Home was not installed as a service. " +
|
||||
l.DebugContext(ctx, "restarting service", slogutil.KeyError, err)
|
||||
l.InfoContext(ctx, "AdGuard Home was not installed as a service. "+
|
||||
"Please restart running instances of AdGuardHome manually.")
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
os.Exit(osutil.ExitCodeSuccess)
|
||||
}
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"net/url"
|
||||
"path"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/google/uuid"
|
||||
"howett.net/plist"
|
||||
)
|
||||
@@ -84,7 +84,7 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
|
||||
case dnsProtoHTTPS:
|
||||
dspName = fmt.Sprintf("%s DoH", d.ServerName)
|
||||
u := &url.URL{
|
||||
Scheme: aghhttp.SchemeHTTPS,
|
||||
Scheme: urlutil.SchemeHTTPS,
|
||||
Host: d.ServerName,
|
||||
Path: path.Join("/dns-query", clientID),
|
||||
}
|
||||
|
||||
@@ -78,6 +78,10 @@ type options struct {
|
||||
// localFrontend forces AdGuard Home to use the frontend files from disk
|
||||
// rather than the ones that have been compiled into the binary.
|
||||
localFrontend bool
|
||||
|
||||
// noPermCheck disables checking and migration of permissions for the
|
||||
// security-sensitive files.
|
||||
noPermCheck bool
|
||||
}
|
||||
|
||||
// initCmdLineOpts completes initialization of the global command-line option
|
||||
@@ -305,6 +309,15 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
description: "Run in GL-Inet compatibility mode.",
|
||||
longName: "glinet",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: func(o options) (options, error) { o.noPermCheck = true; return o, nil },
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) { return "", o.noPermCheck },
|
||||
description: "Skip checking and migration of permissions " +
|
||||
"of security-sensitive files.",
|
||||
longName: "no-permcheck",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: nil,
|
||||
|
||||
@@ -10,11 +10,11 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
@@ -336,7 +336,7 @@ AdGuard Home is successfully installed and will automatically start on boot.
|
||||
There are a few more things that must be configured before you can use it.
|
||||
Click on the link below and follow the Installation Wizard steps to finish setup.
|
||||
AdGuard Home is now available at the following addresses:`)
|
||||
printHTTPAddresses(aghhttp.SchemeHTTP)
|
||||
printHTTPAddresses(urlutil.SchemeHTTP)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,13 +11,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"golang.org/x/net/http2"
|
||||
@@ -101,6 +101,8 @@ type webAPI struct {
|
||||
|
||||
// newWebAPI creates a new instance of the web UI and API server. l must not be
|
||||
// nil.
|
||||
//
|
||||
// TODO(a.garipov): Return a proper error.
|
||||
func newWebAPI(conf *webConfig, l *slog.Logger) (w *webAPI) {
|
||||
log.Info("web: initializing")
|
||||
|
||||
@@ -192,7 +194,7 @@ func (web *webAPI) start() {
|
||||
|
||||
// this loop is used as an ability to change listening host and/or port
|
||||
for !web.httpsServer.inShutdown {
|
||||
printHTTPAddresses(aghhttp.SchemeHTTP)
|
||||
printHTTPAddresses(urlutil.SchemeHTTP)
|
||||
errs := make(chan error, 2)
|
||||
|
||||
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
|
||||
@@ -286,7 +288,7 @@ func (web *webAPI) tlsServerLoop() {
|
||||
WriteTimeout: web.conf.WriteTimeout,
|
||||
}
|
||||
|
||||
printHTTPAddresses(aghhttp.SchemeHTTPS)
|
||||
printHTTPAddresses(urlutil.SchemeHTTPS)
|
||||
|
||||
if web.conf.serveHTTP3 {
|
||||
go web.mustStartHTTP3(addr)
|
||||
|
||||
@@ -1,36 +1,9 @@
|
||||
// Package agh contains common entities and interfaces of AdGuard Home.
|
||||
package agh
|
||||
|
||||
import "context"
|
||||
|
||||
// Service is the interface for API servers.
|
||||
//
|
||||
// TODO(a.garipov): Consider adding a context to Start.
|
||||
//
|
||||
// TODO(a.garipov): Consider adding a Wait method or making an extension
|
||||
// interface for that.
|
||||
type Service interface {
|
||||
// Start starts the service. It does not block.
|
||||
Start() (err error)
|
||||
|
||||
// Shutdown gracefully stops the service. ctx is used to determine
|
||||
// a timeout before trying to stop the service less gracefully.
|
||||
Shutdown(ctx context.Context) (err error)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ Service = EmptyService{}
|
||||
|
||||
// EmptyService is a [Service] that does nothing.
|
||||
//
|
||||
// TODO(a.garipov): Remove if unnecessary.
|
||||
type EmptyService struct{}
|
||||
|
||||
// Start implements the [Service] interface for EmptyService.
|
||||
func (EmptyService) Start() (err error) { return nil }
|
||||
|
||||
// Shutdown implements the [Service] interface for EmptyService.
|
||||
func (EmptyService) Shutdown(_ context.Context) (err error) { return nil }
|
||||
import (
|
||||
"github.com/AdguardTeam/golibs/service"
|
||||
)
|
||||
|
||||
// ServiceWithConfig is an extension of the [Service] interface for services
|
||||
// that can return their configuration.
|
||||
@@ -38,7 +11,7 @@ func (EmptyService) Shutdown(_ context.Context) (err error) { return nil }
|
||||
// TODO(a.garipov): Consider removing this generic interface if we figure out
|
||||
// how to make it testable in a better way.
|
||||
type ServiceWithConfig[ConfigType any] interface {
|
||||
Service
|
||||
service.Interface
|
||||
|
||||
Config() (c ConfigType)
|
||||
}
|
||||
@@ -51,7 +24,7 @@ var _ ServiceWithConfig[struct{}] = (*EmptyServiceWithConfig[struct{}])(nil)
|
||||
//
|
||||
// TODO(a.garipov): Remove if unnecessary.
|
||||
type EmptyServiceWithConfig[ConfigType any] struct {
|
||||
EmptyService
|
||||
service.Empty
|
||||
|
||||
Conf ConfigType
|
||||
}
|
||||
|
||||
@@ -12,11 +12,15 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/service"
|
||||
)
|
||||
|
||||
// Main is the entry point of AdGuard Home.
|
||||
func Main(embeddedFrontend fs.FS) {
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
cmdName := os.Args[0]
|
||||
@@ -26,70 +30,69 @@ func Main(embeddedFrontend fs.FS) {
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
err = setLog(opts)
|
||||
check(err)
|
||||
baseLogger := newBaseLogger(opts)
|
||||
|
||||
log.Info("starting adguard home, version %s, pid %d", version.Version(), os.Getpid())
|
||||
baseLogger.InfoContext(
|
||||
ctx,
|
||||
"starting adguard home",
|
||||
"version", version.Version(),
|
||||
"pid", os.Getpid(),
|
||||
)
|
||||
|
||||
if opts.workDir != "" {
|
||||
log.Info("changing working directory to %q", opts.workDir)
|
||||
baseLogger.InfoContext(ctx, "changing working directory", "dir", opts.workDir)
|
||||
|
||||
err = os.Chdir(opts.workDir)
|
||||
check(err)
|
||||
errors.Check(err)
|
||||
}
|
||||
|
||||
frontend, err := frontendFromOpts(opts, embeddedFrontend)
|
||||
check(err)
|
||||
frontend, err := frontendFromOpts(ctx, baseLogger, opts, embeddedFrontend)
|
||||
errors.Check(err)
|
||||
|
||||
startCtx, startCancel := context.WithTimeout(ctx, defaultTimeoutStart)
|
||||
defer startCancel()
|
||||
|
||||
confMgrConf := &configmgr.Config{
|
||||
Frontend: frontend,
|
||||
WebAddr: opts.webAddr,
|
||||
Start: start,
|
||||
FileName: opts.confFile,
|
||||
BaseLogger: baseLogger,
|
||||
Logger: baseLogger.With(slogutil.KeyPrefix, "configmgr"),
|
||||
Frontend: frontend,
|
||||
WebAddr: opts.webAddr,
|
||||
Start: start,
|
||||
FileName: opts.confFile,
|
||||
}
|
||||
|
||||
confMgr, err := newConfigMgr(confMgrConf)
|
||||
check(err)
|
||||
confMgr, err := configmgr.New(startCtx, confMgrConf)
|
||||
errors.Check(err)
|
||||
|
||||
web := confMgr.Web()
|
||||
err = web.Start()
|
||||
check(err)
|
||||
err = web.Start(startCtx)
|
||||
errors.Check(err)
|
||||
|
||||
dns := confMgr.DNS()
|
||||
err = dns.Start()
|
||||
check(err)
|
||||
err = dns.Start(startCtx)
|
||||
errors.Check(err)
|
||||
|
||||
sigHdlr := newSignalHandler(
|
||||
baseLogger.With(slogutil.KeyPrefix, service.SignalHandlerPrefix),
|
||||
confMgrConf,
|
||||
opts.pidFile,
|
||||
web,
|
||||
dns,
|
||||
)
|
||||
|
||||
sigHdlr.handle()
|
||||
os.Exit(sigHdlr.handle(ctx))
|
||||
}
|
||||
|
||||
// defaultTimeout is the timeout used for some operations where another timeout
|
||||
// hasn't been defined yet.
|
||||
const defaultTimeout = 5 * time.Second
|
||||
|
||||
// ctxWithDefaultTimeout is a helper function that returns a context with
|
||||
// timeout set to defaultTimeout.
|
||||
func ctxWithDefaultTimeout() (ctx context.Context, cancel context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), defaultTimeout)
|
||||
}
|
||||
// Default timeouts.
|
||||
//
|
||||
// TODO(a.garipov): Make configurable.
|
||||
const (
|
||||
defaultTimeoutStart = 1 * time.Minute
|
||||
defaultTimeoutShutdown = 5 * time.Second
|
||||
)
|
||||
|
||||
// newConfigMgr returns a new configuration manager using defaultTimeout as the
|
||||
// context timeout.
|
||||
func newConfigMgr(c *configmgr.Config) (m *configmgr.Manager, err error) {
|
||||
ctx, cancel := ctxWithDefaultTimeout()
|
||||
defer cancel()
|
||||
|
||||
func newConfigMgr(ctx context.Context, c *configmgr.Config) (m *configmgr.Manager, err error) {
|
||||
return configmgr.New(ctx, c)
|
||||
}
|
||||
|
||||
// check is a simple error-checking helper. It must only be used within Main.
|
||||
func check(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// syslogServiceName is the name of the AdGuard Home service used for writing
|
||||
// logs to the system log.
|
||||
const syslogServiceName = "AdGuardHome"
|
||||
|
||||
// setLog sets up the text logging.
|
||||
//
|
||||
// TODO(a.garipov): Add parameters from configuration file.
|
||||
func setLog(opts *options) (err error) {
|
||||
// newBaseLogger constructs a base logger based on the command-line options.
|
||||
// opts must not be nil.
|
||||
func newBaseLogger(opts *options) (baseLogger *slog.Logger) {
|
||||
var output io.Writer
|
||||
switch opts.confFile {
|
||||
case "stdout":
|
||||
log.SetOutput(os.Stdout)
|
||||
output = os.Stdout
|
||||
case "stderr":
|
||||
log.SetOutput(os.Stderr)
|
||||
output = os.Stderr
|
||||
case "syslog":
|
||||
err = aghos.ConfigureSyslog(syslogServiceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initializing syslog: %w", err)
|
||||
}
|
||||
// TODO(a.garipov): Add a syslog handler to golibs.
|
||||
default:
|
||||
// TODO(a.garipov): Use the path.
|
||||
// TODO(a.garipov): Use the path.
|
||||
}
|
||||
|
||||
lvl := slog.LevelInfo
|
||||
if opts.verbose {
|
||||
log.SetLevel(log.DEBUG)
|
||||
log.Debug("verbose logging enabled")
|
||||
lvl = slog.LevelDebug
|
||||
}
|
||||
|
||||
return nil
|
||||
return slogutil.New(&slogutil.Config{
|
||||
Output: output,
|
||||
// TODO(a.garipov): Get from config?
|
||||
Format: slogutil.FormatText,
|
||||
Level: lvl,
|
||||
// TODO(a.garipov): Get from config.
|
||||
AddTimestamp: true,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
@@ -14,7 +16,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/configmigrate"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
)
|
||||
|
||||
// options contains all command-line options for the AdGuardHome(.exe) binary.
|
||||
@@ -87,6 +89,12 @@ type options struct {
|
||||
// TODO(a.garipov): Use.
|
||||
performUpdate bool
|
||||
|
||||
// noPermCheck, if true, instructs AdGuard Home to skip checking and
|
||||
// migrating the permissions of its security-sensitive files.
|
||||
//
|
||||
// TODO(e.burkov): Use.
|
||||
noPermCheck bool
|
||||
|
||||
// verbose, if true, instructs AdGuard Home to enable verbose logging.
|
||||
verbose bool
|
||||
|
||||
@@ -108,7 +116,8 @@ const (
|
||||
disableUpdateIdx
|
||||
glinetModeIdx
|
||||
helpIdx
|
||||
localFrontend
|
||||
localFrontendIdx
|
||||
noPermCheckIdx
|
||||
performUpdateIdx
|
||||
verboseIdx
|
||||
versionIdx
|
||||
@@ -212,7 +221,7 @@ var commandLineOptions = []*commandLineOption{
|
||||
valueType: "",
|
||||
},
|
||||
|
||||
localFrontend: {
|
||||
localFrontendIdx: {
|
||||
defaultValue: false,
|
||||
description: "Use local frontend directories.",
|
||||
long: "local-frontend",
|
||||
@@ -220,6 +229,14 @@ var commandLineOptions = []*commandLineOption{
|
||||
valueType: "",
|
||||
},
|
||||
|
||||
noPermCheckIdx: {
|
||||
defaultValue: false,
|
||||
description: "Skip checking the permissions of security-sensitive files.",
|
||||
long: "no-permcheck",
|
||||
short: "",
|
||||
valueType: "",
|
||||
},
|
||||
|
||||
performUpdateIdx: {
|
||||
defaultValue: false,
|
||||
description: "Update the current binary and restart the service in case it's installed.",
|
||||
@@ -262,7 +279,8 @@ func parseOptions(cmdName string, args []string) (opts *options, err error) {
|
||||
disableUpdateIdx: &opts.disableUpdate,
|
||||
glinetModeIdx: &opts.glinetMode,
|
||||
helpIdx: &opts.help,
|
||||
localFrontend: &opts.localFrontend,
|
||||
localFrontendIdx: &opts.localFrontend,
|
||||
noPermCheckIdx: &opts.noPermCheck,
|
||||
performUpdateIdx: &opts.performUpdate,
|
||||
verboseIdx: &opts.verbose,
|
||||
versionIdx: &opts.version,
|
||||
@@ -372,13 +390,13 @@ func processOptions(
|
||||
) (exitCode int, needExit bool) {
|
||||
if parseErr != nil {
|
||||
// Assume that usage has already been printed.
|
||||
return statusArgumentError, true
|
||||
return osutil.ExitCodeArgumentError, true
|
||||
}
|
||||
|
||||
if opts.help {
|
||||
usage(cmdName, os.Stdout)
|
||||
|
||||
return statusSuccess, true
|
||||
return osutil.ExitCodeSuccess, true
|
||||
}
|
||||
|
||||
if opts.version {
|
||||
@@ -388,7 +406,7 @@ func processOptions(
|
||||
fmt.Printf("AdGuard Home %s\n", version.Version())
|
||||
}
|
||||
|
||||
return statusSuccess, true
|
||||
return osutil.ExitCodeSuccess, true
|
||||
}
|
||||
|
||||
if opts.checkConfig {
|
||||
@@ -396,21 +414,26 @@ func processOptions(
|
||||
if err != nil {
|
||||
_, _ = io.WriteString(os.Stdout, err.Error()+"\n")
|
||||
|
||||
return statusError, true
|
||||
return osutil.ExitCodeFailure, true
|
||||
}
|
||||
|
||||
return statusSuccess, true
|
||||
return osutil.ExitCodeSuccess, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// frontendFromOpts returns the frontend to use based on the options.
|
||||
func frontendFromOpts(opts *options, embeddedFrontend fs.FS) (frontend fs.FS, err error) {
|
||||
func frontendFromOpts(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
opts *options,
|
||||
embeddedFrontend fs.FS,
|
||||
) (frontend fs.FS, err error) {
|
||||
const frontendSubdir = "build/static"
|
||||
|
||||
if opts.localFrontend {
|
||||
log.Info("warning: using local frontend files")
|
||||
logger.WarnContext(ctx, "using local frontend files")
|
||||
|
||||
return os.DirFS(frontendSubdir), nil
|
||||
}
|
||||
|
||||
@@ -1,18 +1,26 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/service"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
)
|
||||
|
||||
// signalHandler processes incoming signals and shuts services down.
|
||||
type signalHandler struct {
|
||||
// logger is used for logging the operation of the signal handler.
|
||||
logger *slog.Logger
|
||||
|
||||
// confMgrConf contains the configuration parameters for the configuration
|
||||
// manager.
|
||||
confMgrConf *configmgr.Config
|
||||
@@ -24,145 +32,172 @@ type signalHandler struct {
|
||||
pidFile string
|
||||
|
||||
// services are the services that are shut down before application exiting.
|
||||
services []agh.Service
|
||||
services []service.Interface
|
||||
|
||||
// shutdownTimeout is the timeout for the shutdown operation.
|
||||
shutdownTimeout time.Duration
|
||||
}
|
||||
|
||||
// handle processes OS signals.
|
||||
func (h *signalHandler) handle() {
|
||||
defer log.OnPanic("signalHandler.handle")
|
||||
// handle processes OS signals. It blocks until a termination or a
|
||||
// reconfiguration signal is received, after which it either shuts down all
|
||||
// services or reconfigures them. ctx is used for logging and serves as the
|
||||
// base for the shutdown timeout. status is [osutil.ExitCodeSuccess] on success
|
||||
// and [osutil.ExitCodeFailure] on error.
|
||||
//
|
||||
// TODO(a.garipov): Add reconfiguration logic to golibs.
|
||||
func (h *signalHandler) handle(ctx context.Context) (status osutil.ExitCode) {
|
||||
defer slogutil.RecoverAndLog(ctx, h.logger)
|
||||
|
||||
h.writePID()
|
||||
h.writePID(ctx)
|
||||
|
||||
for sig := range h.signal {
|
||||
log.Info("sighdlr: received signal %q", sig)
|
||||
h.logger.InfoContext(ctx, "received", "signal", sig)
|
||||
|
||||
if aghos.IsReconfigureSignal(sig) {
|
||||
h.reconfigure()
|
||||
if osutil.IsReconfigureSignal(sig) {
|
||||
err := h.reconfigure(ctx)
|
||||
if err != nil {
|
||||
h.logger.ErrorContext(ctx, "reconfiguration error", slogutil.KeyError, err)
|
||||
|
||||
return osutil.ExitCodeFailure
|
||||
}
|
||||
} else if osutil.IsShutdownSignal(sig) {
|
||||
status := h.shutdown()
|
||||
h.removePID()
|
||||
status = h.shutdown(ctx)
|
||||
|
||||
log.Info("sighdlr: exiting with status %d", status)
|
||||
h.removePID(ctx)
|
||||
|
||||
os.Exit(status)
|
||||
return status
|
||||
}
|
||||
}
|
||||
|
||||
// Shouldn't happen, since h.signal is currently never closed.
|
||||
panic("unexpected close of h.signal")
|
||||
}
|
||||
|
||||
// writePID writes the PID to the file, if needed. Any errors are reported to
|
||||
// log.
|
||||
func (h *signalHandler) writePID(ctx context.Context) {
|
||||
if h.pidFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
pid := os.Getpid()
|
||||
data := strconv.AppendInt(nil, int64(pid), 10)
|
||||
data = append(data, '\n')
|
||||
|
||||
err := maybe.WriteFile(h.pidFile, data, 0o644)
|
||||
if err != nil {
|
||||
h.logger.ErrorContext(ctx, "writing pidfile", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.DebugContext(ctx, "wrote pid", "file", h.pidFile, "pid", pid)
|
||||
}
|
||||
|
||||
// reconfigure rereads the configuration file and updates and restarts services.
|
||||
func (h *signalHandler) reconfigure() {
|
||||
log.Info("sighdlr: reconfiguring adguard home")
|
||||
func (h *signalHandler) reconfigure(ctx context.Context) (err error) {
|
||||
h.logger.InfoContext(ctx, "reconfiguring started")
|
||||
|
||||
status := h.shutdown()
|
||||
if status != statusSuccess {
|
||||
log.Info("sighdlr: reconfiguring: exiting with status %d", status)
|
||||
|
||||
os.Exit(status)
|
||||
status := h.shutdown(ctx)
|
||||
if status != osutil.ExitCodeSuccess {
|
||||
return errors.Error("shutdown failed")
|
||||
}
|
||||
|
||||
// TODO(a.garipov): This is a very rough way to do it. Some services can be
|
||||
// reconfigured without the full shutdown, and the error handling is
|
||||
// TODO(a.garipov): This is a very rough way to do it. Some services can
|
||||
// be reconfigured without the full shutdown, and the error handling is
|
||||
// currently not the best.
|
||||
|
||||
confMgr, err := newConfigMgr(h.confMgrConf)
|
||||
check(err)
|
||||
var errs []error
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultTimeoutStart)
|
||||
defer cancel()
|
||||
|
||||
confMgr, err := newConfigMgr(ctx, h.confMgrConf)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("configuration manager: %w", err))
|
||||
}
|
||||
|
||||
web := confMgr.Web()
|
||||
err = web.Start()
|
||||
check(err)
|
||||
err = web.Start(ctx)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("starting web: %w", err))
|
||||
}
|
||||
|
||||
dns := confMgr.DNS()
|
||||
err = dns.Start()
|
||||
check(err)
|
||||
err = dns.Start(ctx)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("starting dns: %w", err))
|
||||
}
|
||||
|
||||
h.services = []agh.Service{
|
||||
if len(errs) > 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
h.services = []service.Interface{
|
||||
dns,
|
||||
web,
|
||||
}
|
||||
|
||||
log.Info("sighdlr: successfully reconfigured adguard home")
|
||||
h.logger.InfoContext(ctx, "reconfiguring finished")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exit status constants.
|
||||
const (
|
||||
statusSuccess = 0
|
||||
statusError = 1
|
||||
statusArgumentError = 2
|
||||
)
|
||||
|
||||
// shutdown gracefully shuts down all services.
|
||||
func (h *signalHandler) shutdown() (status int) {
|
||||
ctx, cancel := ctxWithDefaultTimeout()
|
||||
func (h *signalHandler) shutdown(ctx context.Context) (status int) {
|
||||
ctx, cancel := context.WithTimeout(ctx, h.shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
status = statusSuccess
|
||||
status = osutil.ExitCodeSuccess
|
||||
|
||||
log.Info("sighdlr: shutting down services")
|
||||
for i, service := range h.services {
|
||||
err := service.Shutdown(ctx)
|
||||
h.logger.InfoContext(ctx, "shutting down")
|
||||
for i, svc := range h.services {
|
||||
err := svc.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Error("sighdlr: shutting down service at index %d: %s", i, err)
|
||||
status = statusError
|
||||
h.logger.ErrorContext(ctx, "shutting down service", "idx", i, slogutil.KeyError, err)
|
||||
status = osutil.ExitCodeFailure
|
||||
}
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
// newSignalHandler returns a new signalHandler that shuts down svcs.
|
||||
// newSignalHandler returns a new signalHandler that shuts down svcs. logger
|
||||
// and confMgrConf must not be nil.
|
||||
func newSignalHandler(
|
||||
logger *slog.Logger,
|
||||
confMgrConf *configmgr.Config,
|
||||
pidFile string,
|
||||
svcs ...agh.Service,
|
||||
svcs ...service.Interface,
|
||||
) (h *signalHandler) {
|
||||
h = &signalHandler{
|
||||
confMgrConf: confMgrConf,
|
||||
signal: make(chan os.Signal, 1),
|
||||
pidFile: pidFile,
|
||||
services: svcs,
|
||||
logger: logger,
|
||||
confMgrConf: confMgrConf,
|
||||
signal: make(chan os.Signal, 1),
|
||||
pidFile: pidFile,
|
||||
services: svcs,
|
||||
shutdownTimeout: defaultTimeoutShutdown,
|
||||
}
|
||||
|
||||
notifier := osutil.DefaultSignalNotifier{}
|
||||
osutil.NotifyShutdownSignal(notifier, h.signal)
|
||||
aghos.NotifyReconfigureSignal(h.signal)
|
||||
osutil.NotifyReconfigureSignal(notifier, h.signal)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// writePID writes the PID to the file, if needed. Any errors are reported to
|
||||
// log.
|
||||
func (h *signalHandler) writePID() {
|
||||
if h.pidFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Use 8, since most PIDs will fit.
|
||||
data := make([]byte, 0, 8)
|
||||
data = strconv.AppendInt(data, int64(os.Getpid()), 10)
|
||||
data = append(data, '\n')
|
||||
|
||||
err := aghos.WriteFile(h.pidFile, data, 0o644)
|
||||
if err != nil {
|
||||
log.Error("sighdlr: writing pidfile: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("sighdlr: wrote pid to %q", h.pidFile)
|
||||
}
|
||||
|
||||
// removePID removes the PID file, if any.
|
||||
func (h *signalHandler) removePID() {
|
||||
func (h *signalHandler) removePID(ctx context.Context) {
|
||||
if h.pidFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := os.Remove(h.pidFile)
|
||||
if err != nil {
|
||||
log.Error("sighdlr: removing pidfile: %s", err)
|
||||
h.logger.ErrorContext(ctx, "removing pidfile", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("sighdlr: removed pid at %q", h.pidFile)
|
||||
h.logger.DebugContext(ctx, "removed pidfile", "file", h.pidFile)
|
||||
}
|
||||
|
||||
@@ -4,12 +4,11 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
// Configuration Structures
|
||||
|
||||
// config is the top-level on-disk configuration structure.
|
||||
type config struct {
|
||||
DNS *dnsConfig `yaml:"dns"`
|
||||
@@ -19,35 +18,33 @@ type config struct {
|
||||
SchemaVersion int `yaml:"schema_version"`
|
||||
}
|
||||
|
||||
const errNoConf errors.Error = "configuration not found"
|
||||
// type check
|
||||
var _ validator = (*config)(nil)
|
||||
|
||||
// validate returns an error if the configuration structure is invalid.
|
||||
// validate implements the [validator] interface for *config.
|
||||
func (c *config) validate() (err error) {
|
||||
if c == nil {
|
||||
return errNoConf
|
||||
return errors.ErrNoValue
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Add more validations.
|
||||
|
||||
// Keep this in the same order as the fields in the config.
|
||||
validators := []struct {
|
||||
validate func() (err error)
|
||||
name string
|
||||
}{{
|
||||
validate: c.DNS.validate,
|
||||
name: "dns",
|
||||
validators := container.KeyValues[string, validator]{{
|
||||
Key: "dns",
|
||||
Value: c.DNS,
|
||||
}, {
|
||||
validate: c.HTTP.validate,
|
||||
name: "http",
|
||||
Key: "http",
|
||||
Value: c.HTTP,
|
||||
}, {
|
||||
validate: c.Log.validate,
|
||||
name: "log",
|
||||
Key: "log",
|
||||
Value: c.Log,
|
||||
}}
|
||||
|
||||
for _, v := range validators {
|
||||
err = v.validate()
|
||||
for _, kv := range validators {
|
||||
err = kv.Value.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", v.name, err)
|
||||
return fmt.Errorf("%s: %w", kv.Key, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,16 +62,19 @@ type dnsConfig struct {
|
||||
UseDNS64 bool `yaml:"use_dns64"`
|
||||
}
|
||||
|
||||
// validate returns an error if the DNS configuration structure is invalid.
|
||||
// type check
|
||||
var _ validator = (*dnsConfig)(nil)
|
||||
|
||||
// validate implements the [validator] interface for *dnsConfig.
|
||||
//
|
||||
// TODO(a.garipov): Add more validations.
|
||||
func (c *dnsConfig) validate() (err error) {
|
||||
// TODO(a.garipov): Add more validations.
|
||||
switch {
|
||||
case c == nil:
|
||||
return errNoConf
|
||||
return errors.ErrNoValue
|
||||
case c.UpstreamTimeout.Duration <= 0:
|
||||
return newMustBePositiveError("upstream_timeout", c.UpstreamTimeout)
|
||||
return newErrNotPositive("upstream_timeout", c.UpstreamTimeout)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -91,15 +91,18 @@ type httpConfig struct {
|
||||
ForceHTTPS bool `yaml:"force_https"`
|
||||
}
|
||||
|
||||
// validate returns an error if the HTTP configuration structure is invalid.
|
||||
// type check
|
||||
var _ validator = (*httpConfig)(nil)
|
||||
|
||||
// validate implements the [validator] interface for *httpConfig.
|
||||
//
|
||||
// TODO(a.garipov): Add more validations.
|
||||
func (c *httpConfig) validate() (err error) {
|
||||
switch {
|
||||
case c == nil:
|
||||
return errNoConf
|
||||
return errors.ErrNoValue
|
||||
case c.Timeout.Duration <= 0:
|
||||
return newMustBePositiveError("timeout", c.Timeout)
|
||||
return newErrNotPositive("timeout", c.Timeout)
|
||||
default:
|
||||
return c.Pprof.validate()
|
||||
}
|
||||
@@ -111,10 +114,13 @@ type httpPprofConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// validate returns an error if the pprof configuration structure is invalid.
|
||||
// type check
|
||||
var _ validator = (*httpPprofConfig)(nil)
|
||||
|
||||
// validate implements the [validator] interface for *httpPprofConfig.
|
||||
func (c *httpPprofConfig) validate() (err error) {
|
||||
if c == nil {
|
||||
return errNoConf
|
||||
return errors.ErrNoValue
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -126,12 +132,15 @@ type logConfig struct {
|
||||
Verbose bool `yaml:"verbose"`
|
||||
}
|
||||
|
||||
// validate returns an error if the HTTP configuration structure is invalid.
|
||||
// type check
|
||||
var _ validator = (*logConfig)(nil)
|
||||
|
||||
// validate implements the [validator] interface for *logConfig.
|
||||
//
|
||||
// TODO(a.garipov): Add more validations.
|
||||
func (c *logConfig) validate() (err error) {
|
||||
if c == nil {
|
||||
return errNoConf
|
||||
return errors.ErrNoValue
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
@@ -19,18 +20,23 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Configuration Manager
|
||||
|
||||
// Manager handles full and partial changes in the configuration, persisting
|
||||
// them to disk if necessary.
|
||||
//
|
||||
// TODO(a.garipov): Support missing configs and default values.
|
||||
type Manager struct {
|
||||
// baseLogger is used to create loggers for other entities.
|
||||
baseLogger *slog.Logger
|
||||
|
||||
// logger is used for logging the operation of the configuration manager.
|
||||
logger *slog.Logger
|
||||
|
||||
// updMu makes sure that at most one reconfiguration is performed at a time.
|
||||
// updMu protects all fields below.
|
||||
updMu *sync.RWMutex
|
||||
@@ -57,12 +63,24 @@ func Validate(fileName string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return conf.validate()
|
||||
err = conf.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Config contains the configuration parameters for the configuration manager.
|
||||
type Config struct {
|
||||
// BaseLogger is used to create loggers for other entities. It must not be
|
||||
// nil.
|
||||
BaseLogger *slog.Logger
|
||||
|
||||
// Logger is used for logging the operation of the configuration manager.
|
||||
// It must not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// Frontend is the filesystem with the frontend files.
|
||||
Frontend fs.FS
|
||||
|
||||
@@ -93,9 +111,11 @@ func New(ctx context.Context, c *Config) (m *Manager, err error) {
|
||||
}
|
||||
|
||||
m = &Manager{
|
||||
updMu: &sync.RWMutex{},
|
||||
current: conf,
|
||||
fileName: c.FileName,
|
||||
baseLogger: c.BaseLogger,
|
||||
logger: c.Logger,
|
||||
updMu: &sync.RWMutex{},
|
||||
current: conf,
|
||||
fileName: c.FileName,
|
||||
}
|
||||
|
||||
err = m.assemble(ctx, conf, c.Frontend, c.WebAddr, c.Start)
|
||||
@@ -137,6 +157,7 @@ func (m *Manager) assemble(
|
||||
start time.Time,
|
||||
) (err error) {
|
||||
dnsConf := &dnssvc.Config{
|
||||
Logger: m.baseLogger.With(slogutil.KeyPrefix, "dnssvc"),
|
||||
Addresses: conf.DNS.Addresses,
|
||||
BootstrapServers: conf.DNS.BootstrapDNS,
|
||||
UpstreamServers: conf.DNS.UpstreamDNS,
|
||||
@@ -151,6 +172,7 @@ func (m *Manager) assemble(
|
||||
}
|
||||
|
||||
webSvcConf := &websvc.Config{
|
||||
Logger: m.baseLogger.With(slogutil.KeyPrefix, "websvc"),
|
||||
Pprof: &websvc.PprofConfig{
|
||||
Port: conf.HTTP.Pprof.Port,
|
||||
Enabled: conf.HTTP.Pprof.Enabled,
|
||||
@@ -176,18 +198,18 @@ func (m *Manager) assemble(
|
||||
}
|
||||
|
||||
// write writes the current configuration to disk.
|
||||
func (m *Manager) write() (err error) {
|
||||
func (m *Manager) write(ctx context.Context) (err error) {
|
||||
b, err := yaml.Marshal(m.current)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding: %w", err)
|
||||
}
|
||||
|
||||
err = aghos.WriteFile(m.fileName, b, aghos.DefaultPermFile)
|
||||
err = maybe.WriteFile(m.fileName, b, aghos.DefaultPermFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing: %w", err)
|
||||
}
|
||||
|
||||
log.Info("configmgr: written to %q", m.fileName)
|
||||
m.logger.InfoContext(ctx, "config file written", "path", m.fileName)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -216,7 +238,7 @@ func (m *Manager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
|
||||
|
||||
m.updateCurrentDNS(c)
|
||||
|
||||
return m.write()
|
||||
return m.write(ctx)
|
||||
}
|
||||
|
||||
// updateDNS recreates the DNS service. m.updMu is expected to be locked.
|
||||
@@ -270,7 +292,7 @@ func (m *Manager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
|
||||
|
||||
m.updateCurrentWeb(c)
|
||||
|
||||
return m.write()
|
||||
return m.write(ctx)
|
||||
}
|
||||
|
||||
// updateWeb recreates the web service. m.upd is expected to be locked.
|
||||
|
||||
@@ -3,25 +3,29 @@ package configmgr
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
// validator is the interface for configuration entities that can validate
|
||||
// themselves.
|
||||
type validator interface {
|
||||
// validate returns an error if the entity isn't valid.
|
||||
validate() (err error)
|
||||
}
|
||||
|
||||
// numberOrDuration is the constraint for integer types along with
|
||||
// timeutil.Duration.
|
||||
type numberOrDuration interface {
|
||||
constraints.Integer | timeutil.Duration
|
||||
}
|
||||
|
||||
// newMustBePositiveError returns an error about the value that must be positive
|
||||
// but isn't. prop is the name of the property to mention in the error message.
|
||||
// newErrNotPositive returns an error about the value that must be positive but
|
||||
// isn't. prop is the name of the property to mention in the error message.
|
||||
//
|
||||
// TODO(a.garipov): Consider moving such helpers to golibs and use in AdGuardDNS
|
||||
// as well.
|
||||
func newMustBePositiveError[T numberOrDuration](prop string, v T) (err error) {
|
||||
if s, ok := any(v).(fmt.Stringer); ok {
|
||||
return fmt.Errorf("%s must be positive, got %s", prop, s)
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s must be positive, got %d", prop, v)
|
||||
func newErrNotPositive[T numberOrDuration](prop string, v T) (err error) {
|
||||
return fmt.Errorf("%s: %w, got %v", prop, errors.ErrNotPositive, v)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dnssvc
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
@@ -9,6 +10,10 @@ import (
|
||||
//
|
||||
// TODO(a.garipov): Add timeout for incoming requests.
|
||||
type Config struct {
|
||||
// Logger is used for logging the operation of the web API service. It must
|
||||
// not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// Addresses are the addresses on which to serve plain DNS queries.
|
||||
Addresses []netip.AddrPort
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ package dnssvc
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
@@ -28,6 +29,7 @@ import (
|
||||
// TODO(a.garipov): Consider saving a [*proxy.Config] instance for those
|
||||
// fields that are only used in [New] and [Service.Config].
|
||||
type Service struct {
|
||||
logger *slog.Logger
|
||||
proxy *proxy.Proxy
|
||||
bootstraps []string
|
||||
bootstrapResolvers []*upstream.UpstreamResolver
|
||||
@@ -48,6 +50,7 @@ func New(c *Config) (svc *Service, err error) {
|
||||
}
|
||||
|
||||
svc = &Service{
|
||||
logger: c.Logger,
|
||||
bootstraps: c.BootstrapServers,
|
||||
upstreams: c.UpstreamServers,
|
||||
dns64Prefixes: c.DNS64Prefixes,
|
||||
@@ -68,6 +71,7 @@ func New(c *Config) (svc *Service, err error) {
|
||||
|
||||
svc.bootstrapResolvers = resolvers
|
||||
svc.proxy, err = proxy.New(&proxy.Config{
|
||||
Logger: svc.logger,
|
||||
UDPListenAddr: udpAddrs(c.Addresses),
|
||||
TCPListenAddr: tcpAddrs(c.Addresses),
|
||||
UpstreamConfig: &proxy.UpstreamConfig{
|
||||
@@ -153,12 +157,12 @@ func udpAddrs(addrPorts []netip.AddrPort) (udpAddrs []*net.UDPAddr) {
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ agh.Service = (*Service)(nil)
|
||||
var _ agh.ServiceWithConfig[*Config] = (*Service)(nil)
|
||||
|
||||
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
||||
// After Start exits, all DNS servers have tried to start, but there is no
|
||||
// guarantee that they did. Errors from the servers are written to the log.
|
||||
func (svc *Service) Start() (err error) {
|
||||
func (svc *Service) Start(ctx context.Context) (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -170,7 +174,7 @@ func (svc *Service) Start() (err error) {
|
||||
svc.running.Store(err == nil)
|
||||
}()
|
||||
|
||||
return svc.proxy.Start(context.Background())
|
||||
return svc.proxy.Start(ctx)
|
||||
}
|
||||
|
||||
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
||||
@@ -215,6 +219,7 @@ func (svc *Service) Config() (c *Config) {
|
||||
}
|
||||
|
||||
c = &Config{
|
||||
Logger: svc.logger,
|
||||
Addresses: addrs,
|
||||
BootstrapServers: svc.bootstraps,
|
||||
UpstreamServers: svc.upstreams,
|
||||
|
||||
@@ -6,16 +6,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
@@ -59,6 +56,7 @@ func TestService(t *testing.T) {
|
||||
_, _ = testutil.RequireReceive(t, upstreamStartedCh, testTimeout)
|
||||
|
||||
c := &dnssvc.Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort(listenAddr)},
|
||||
BootstrapServers: []string{upstreamSrv.PacketConn.LocalAddr().String()},
|
||||
UpstreamServers: []string{upstreamAddr},
|
||||
@@ -71,7 +69,7 @@ func TestService(t *testing.T) {
|
||||
svc, err := dnssvc.New(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.Start()
|
||||
err = svc.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
gotConf := svc.Config()
|
||||
|
||||
@@ -3,12 +3,17 @@ package websvc
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config is the AdGuard Home web service configuration structure.
|
||||
type Config struct {
|
||||
// Logger is used for logging the operation of the web API service. It must
|
||||
// not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// Pprof is the configuration for the pprof debug API. It must not be nil.
|
||||
Pprof *PprofConfig
|
||||
|
||||
@@ -60,17 +65,20 @@ type PprofConfig struct {
|
||||
// finished.
|
||||
func (svc *Service) Config() (c *Config) {
|
||||
c = &Config{
|
||||
Logger: svc.logger,
|
||||
Pprof: &PprofConfig{
|
||||
Port: svc.pprofPort,
|
||||
Enabled: svc.pprof != nil,
|
||||
},
|
||||
ConfigManager: svc.confMgr,
|
||||
Frontend: svc.frontend,
|
||||
TLS: svc.tls,
|
||||
// Leave Addresses and SecureAddresses empty and get the actual
|
||||
// addresses that include the :0 ones later.
|
||||
Start: svc.start,
|
||||
Timeout: svc.timeout,
|
||||
ForceHTTPS: svc.forceHTTPS,
|
||||
Start: svc.start,
|
||||
OverrideAddress: svc.overrideAddr,
|
||||
Timeout: svc.timeout,
|
||||
ForceHTTPS: svc.forceHTTPS,
|
||||
}
|
||||
|
||||
c.Addresses, c.SecureAddresses = svc.addrs()
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
)
|
||||
|
||||
// DNS Settings Handlers
|
||||
|
||||
// ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns
|
||||
// HTTP API.
|
||||
type ReqPatchSettingsDNS struct {
|
||||
@@ -60,6 +58,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
newConf := &dnssvc.Config{
|
||||
Logger: svc.logger,
|
||||
Addresses: req.Addresses,
|
||||
BootstrapServers: req.BootstrapServers,
|
||||
UpstreamServers: req.UpstreamServers,
|
||||
@@ -78,7 +77,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
newSvc := svc.confMgr.DNS()
|
||||
err = newSvc.Start()
|
||||
err = newSvc.Start(ctx)
|
||||
if err != nil {
|
||||
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("starting new service: %w", err))
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -34,7 +35,7 @@ func TestService_HandlePatchSettingsDNS(t *testing.T) {
|
||||
confMgr := newConfigManager()
|
||||
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
|
||||
return &aghtest.ServiceWithConfig[*dnssvc.Config]{
|
||||
OnStart: func() (err error) {
|
||||
OnStart: func(_ context.Context) (err error) {
|
||||
started.Store(true)
|
||||
|
||||
return nil
|
||||
@@ -49,9 +50,9 @@ func TestService_HandlePatchSettingsDNS(t *testing.T) {
|
||||
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathV1SettingsDNS,
|
||||
Path: websvc.PathPatternV1SettingsDNS,
|
||||
}
|
||||
|
||||
req := jobj{
|
||||
|
||||
@@ -10,11 +10,9 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// HTTP Settings Handlers
|
||||
|
||||
// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http
|
||||
// HTTP API.
|
||||
type ReqPatchSettingsHTTP struct {
|
||||
@@ -53,6 +51,7 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
newConf := &Config{
|
||||
Logger: svc.logger,
|
||||
Pprof: &PprofConfig{
|
||||
Port: svc.pprofPort,
|
||||
Enabled: svc.pprof != nil,
|
||||
@@ -89,13 +88,13 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
|
||||
// relaunch updates the web service in the configuration manager and starts it.
|
||||
// It is intended to be used as a goroutine.
|
||||
func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, newConf *Config) {
|
||||
defer log.OnPanic("websvc: relaunching")
|
||||
defer slogutil.RecoverAndLog(ctx, svc.logger)
|
||||
|
||||
defer cancel()
|
||||
|
||||
err := svc.confMgr.UpdateWeb(ctx, newConf)
|
||||
if err != nil {
|
||||
log.Error("websvc: updating web: %s", err)
|
||||
svc.logger.ErrorContext(ctx, "updating web", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -106,18 +105,18 @@ func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, new
|
||||
var newSvc agh.ServiceWithConfig[*Config]
|
||||
for newSvc = svc.confMgr.Web(); newSvc == svc; {
|
||||
if time.Since(updStart) >= maxUpdDur {
|
||||
log.Error("websvc: failed to update svc after %s", maxUpdDur)
|
||||
svc.logger.ErrorContext(ctx, "failed to update service on time", "duration", maxUpdDur)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("websvc: waiting for new websvc to be configured")
|
||||
svc.logger.DebugContext(ctx, "waiting for new service")
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
err = newSvc.Start()
|
||||
err = newSvc.Start(ctx)
|
||||
if err != nil {
|
||||
log.Error("websvc: new svc failed to start with error: %s", err)
|
||||
svc.logger.ErrorContext(ctx, "new service failed", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -26,14 +28,15 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) {
|
||||
}
|
||||
|
||||
svc, err := websvc.New(&websvc.Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Pprof: &websvc.PprofConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
TLS: &tls.Config{
|
||||
Certificates: []tls.Certificate{{}},
|
||||
},
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
|
||||
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
|
||||
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
|
||||
Timeout: 5 * time.Second,
|
||||
ForceHTTPS: true,
|
||||
})
|
||||
@@ -45,9 +48,9 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) {
|
||||
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathV1SettingsHTTP,
|
||||
Path: websvc.PathPatternV1SettingsHTTP,
|
||||
}
|
||||
|
||||
req := jobj{
|
||||
|
||||
@@ -2,15 +2,11 @@ package websvc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Middlewares
|
||||
|
||||
// jsonMw sets the content type of the response to application/json.
|
||||
func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
|
||||
f := func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -21,18 +17,3 @@ func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
|
||||
|
||||
return http.HandlerFunc(f)
|
||||
}
|
||||
|
||||
// logMw logs the queries with level debug.
|
||||
func logMw(h http.Handler) (wrapped http.HandlerFunc) {
|
||||
f := func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
m, u := r.Method, r.RequestURI
|
||||
|
||||
log.Debug("websvc: %s %s started", m, u)
|
||||
defer func() { log.Debug("websvc: %s %s finished in %s", m, u, time.Since(start)) }()
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(f)
|
||||
}
|
||||
|
||||
73
internal/next/websvc/route.go
Normal file
73
internal/next/websvc/route.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||
)
|
||||
|
||||
// Path pattern constants.
|
||||
const (
|
||||
PathPatternFrontend = "/"
|
||||
PathPatternHealthCheck = "/health-check"
|
||||
PathPatternV1SettingsAll = "/api/v1/settings/all"
|
||||
PathPatternV1SettingsDNS = "/api/v1/settings/dns"
|
||||
PathPatternV1SettingsHTTP = "/api/v1/settings/http"
|
||||
PathPatternV1SystemInfo = "/api/v1/system/info"
|
||||
)
|
||||
|
||||
// Route pattern constants.
|
||||
const (
|
||||
routePatternFrontend = http.MethodGet + " " + PathPatternFrontend
|
||||
routePatternGetV1SettingsAll = http.MethodGet + " " + PathPatternV1SettingsAll
|
||||
routePatternGetV1SystemInfo = http.MethodGet + " " + PathPatternV1SystemInfo
|
||||
routePatternHealthCheck = http.MethodGet + " " + PathPatternHealthCheck
|
||||
routePatternPatchV1SettingsDNS = http.MethodPatch + " " + PathPatternV1SettingsDNS
|
||||
routePatternPatchV1SettingsHTTP = http.MethodPatch + " " + PathPatternV1SettingsHTTP
|
||||
)
|
||||
|
||||
// route registers all necessary handlers in mux.
|
||||
func (svc *Service) route(mux *http.ServeMux) {
|
||||
routes := []struct {
|
||||
handler http.Handler
|
||||
pattern string
|
||||
isJSON bool
|
||||
}{{
|
||||
handler: httputil.HealthCheckHandler,
|
||||
pattern: routePatternHealthCheck,
|
||||
isJSON: false,
|
||||
}, {
|
||||
handler: http.FileServer(http.FS(svc.frontend)),
|
||||
pattern: routePatternFrontend,
|
||||
isJSON: false,
|
||||
}, {
|
||||
handler: http.HandlerFunc(svc.handleGetSettingsAll),
|
||||
pattern: routePatternGetV1SettingsAll,
|
||||
isJSON: true,
|
||||
}, {
|
||||
handler: http.HandlerFunc(svc.handlePatchSettingsDNS),
|
||||
pattern: routePatternPatchV1SettingsDNS,
|
||||
isJSON: true,
|
||||
}, {
|
||||
handler: http.HandlerFunc(svc.handlePatchSettingsHTTP),
|
||||
pattern: routePatternPatchV1SettingsHTTP,
|
||||
isJSON: true,
|
||||
}, {
|
||||
handler: http.HandlerFunc(svc.handleGetV1SystemInfo),
|
||||
pattern: routePatternGetV1SystemInfo,
|
||||
isJSON: true,
|
||||
}}
|
||||
|
||||
logMw := httputil.NewLogMiddleware(svc.logger, slog.LevelDebug)
|
||||
for _, r := range routes {
|
||||
var hdlr http.Handler
|
||||
if r.isJSON {
|
||||
hdlr = jsonMw(r.handler)
|
||||
} else {
|
||||
hdlr = r.handler
|
||||
}
|
||||
|
||||
mux.Handle(r.pattern, logMw.Wrap(hdlr))
|
||||
}
|
||||
}
|
||||
156
internal/next/websvc/server.go
Normal file
156
internal/next/websvc/server.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
)
|
||||
|
||||
// server contains an *http.Server as well as entities and data associated with
|
||||
// it.
|
||||
//
|
||||
// TODO(a.garipov): Join with similar structs in other projects and move to
|
||||
// golibs/netutil/httputil.
|
||||
//
|
||||
// TODO(a.garipov): Once the above standardization is complete, consider
|
||||
// merging debugsvc and websvc into a single httpsvc.
|
||||
type server struct {
|
||||
// mu protects http, logger, tcpListener, and url.
|
||||
mu *sync.Mutex
|
||||
http *http.Server
|
||||
logger *slog.Logger
|
||||
tcpListener *net.TCPListener
|
||||
url *url.URL
|
||||
|
||||
tlsConf *tls.Config
|
||||
initialAddr netip.AddrPort
|
||||
}
|
||||
|
||||
// loggerKeyServer is the key used by [server] to identify itself.
|
||||
const loggerKeyServer = "server"
|
||||
|
||||
// newServer returns a *server that is ready to serve HTTP queries. The TCP
|
||||
// listener is not started. handler must not be nil.
|
||||
func newServer(
|
||||
baseLogger *slog.Logger,
|
||||
initialAddr netip.AddrPort,
|
||||
tlsConf *tls.Config,
|
||||
handler http.Handler,
|
||||
timeout time.Duration,
|
||||
) (s *server) {
|
||||
u := &url.URL{
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: initialAddr.String(),
|
||||
}
|
||||
|
||||
if tlsConf != nil {
|
||||
u.Scheme = urlutil.SchemeHTTPS
|
||||
}
|
||||
|
||||
logger := baseLogger.With(loggerKeyServer, u)
|
||||
|
||||
return &server{
|
||||
mu: &sync.Mutex{},
|
||||
http: &http.Server{
|
||||
Handler: handler,
|
||||
ReadTimeout: timeout,
|
||||
ReadHeaderTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
IdleTimeout: timeout,
|
||||
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
|
||||
},
|
||||
logger: logger,
|
||||
url: u,
|
||||
|
||||
tlsConf: tlsConf,
|
||||
initialAddr: initialAddr,
|
||||
}
|
||||
}
|
||||
|
||||
// localAddr returns the local address of the server if the server has started
|
||||
// listening; otherwise, it returns nil.
|
||||
func (s *server) localAddr() (addr net.Addr) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if l := s.tcpListener; l != nil {
|
||||
return l.Addr()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// serve starts s. baseLogger is used as a base logger for s. If s fails to
|
||||
// serve with anything other than [http.ErrServerClosed], it causes an unhandled
|
||||
// panic. It is intended to be used as a goroutine.
|
||||
//
|
||||
// TODO(a.garipov): Improve error handling.
|
||||
func (s *server) serve(ctx context.Context, baseLogger *slog.Logger) {
|
||||
l, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(s.initialAddr))
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(ctx, "listening tcp", slogutil.KeyError, err)
|
||||
|
||||
panic(fmt.Errorf("websvc: listening tcp: %w", err))
|
||||
}
|
||||
|
||||
func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.tcpListener = l
|
||||
|
||||
// Reassign the address in case the port was zero.
|
||||
s.url.Host = l.Addr().String()
|
||||
s.logger = baseLogger.With(loggerKeyServer, s.url)
|
||||
s.http.ErrorLog = slog.NewLogLogger(s.logger.Handler(), slog.LevelError)
|
||||
}()
|
||||
|
||||
s.logger.InfoContext(ctx, "starting")
|
||||
defer s.logger.InfoContext(ctx, "started")
|
||||
|
||||
err = s.http.Serve(l)
|
||||
if err == nil || errors.Is(err, http.ErrServerClosed) {
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.ErrorContext(ctx, "serving", slogutil.KeyError, err)
|
||||
|
||||
panic(fmt.Errorf("websvc: serving: %w", err))
|
||||
}
|
||||
|
||||
// shutdown shuts s down.
|
||||
func (s *server) shutdown(ctx context.Context) (err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var errs []error
|
||||
err = s.http.Shutdown(ctx)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("shutting down server %s: %w", s.url, err))
|
||||
}
|
||||
|
||||
// Close the listener separately, as it might not have been closed if the
|
||||
// context has been canceled.
|
||||
//
|
||||
// NOTE: The listener could remain uninitialized if [net.ListenTCP] failed
|
||||
// in [s.serve].
|
||||
if l := s.tcpListener; l != nil {
|
||||
err = l.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
errs = append(errs, fmt.Errorf("closing listener for server %s: %w", s.url, err))
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package websvc_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -13,6 +12,8 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -28,16 +29,10 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
|
||||
BootstrapPreferIPv6: true,
|
||||
}
|
||||
|
||||
wantWeb := &websvc.HTTPAPIHTTPSettings{
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
|
||||
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
|
||||
Timeout: aghhttp.JSONDuration(5 * time.Second),
|
||||
ForceHTTPS: true,
|
||||
}
|
||||
|
||||
confMgr := newConfigManager()
|
||||
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
|
||||
c, err := dnssvc.New(&dnssvc.Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Addresses: wantDNS.Addresses,
|
||||
UpstreamServers: wantDNS.UpstreamServers,
|
||||
BootstrapServers: wantDNS.BootstrapServers,
|
||||
@@ -49,34 +44,27 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
|
||||
return c
|
||||
}
|
||||
|
||||
svc, err := websvc.New(&websvc.Config{
|
||||
Pprof: &websvc.PprofConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
TLS: &tls.Config{
|
||||
Certificates: []tls.Certificate{{}},
|
||||
},
|
||||
Addresses: wantWeb.Addresses,
|
||||
SecureAddresses: wantWeb.SecureAddresses,
|
||||
Timeout: time.Duration(wantWeb.Timeout),
|
||||
ForceHTTPS: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
svc, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathPatternV1SettingsAll,
|
||||
}
|
||||
|
||||
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) {
|
||||
return svc
|
||||
}
|
||||
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathV1SettingsAll,
|
||||
wantWeb := &websvc.HTTPAPIHTTPSettings{
|
||||
Addresses: []netip.AddrPort{addr},
|
||||
SecureAddresses: nil,
|
||||
Timeout: aghhttp.JSONDuration(testTimeout),
|
||||
ForceHTTPS: false,
|
||||
}
|
||||
|
||||
body := httpGet(t, u, http.StatusOK)
|
||||
resp := &websvc.RespGetV1SettingsAll{}
|
||||
err = json.Unmarshal(body, resp)
|
||||
err := json.Unmarshal(body, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantDNS, resp.DNS)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -17,9 +18,9 @@ func TestService_handleGetV1SystemInfo(t *testing.T) {
|
||||
confMgr := newConfigManager()
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathV1SystemInfo,
|
||||
Path: websvc.PathPatternV1SystemInfo,
|
||||
}
|
||||
|
||||
body := httpGet(t, u, http.StatusOK)
|
||||
|
||||
@@ -10,22 +10,18 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/mathutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||
httptreemux "github.com/dimfeld/httptreemux/v5"
|
||||
)
|
||||
|
||||
// ConfigManager is the configuration manager interface.
|
||||
@@ -40,13 +36,14 @@ type ConfigManager interface {
|
||||
// Service is the AdGuard Home web service. A nil *Service is a valid
|
||||
// [agh.Service] that does nothing.
|
||||
type Service struct {
|
||||
logger *slog.Logger
|
||||
confMgr ConfigManager
|
||||
frontend fs.FS
|
||||
tls *tls.Config
|
||||
pprof *http.Server
|
||||
pprof *server
|
||||
start time.Time
|
||||
overrideAddr netip.AddrPort
|
||||
servers []*http.Server
|
||||
servers []*server
|
||||
timeout time.Duration
|
||||
pprofPort uint16
|
||||
forceHTTPS bool
|
||||
@@ -64,6 +61,7 @@ func New(c *Config) (svc *Service, err error) {
|
||||
}
|
||||
|
||||
svc = &Service{
|
||||
logger: c.Logger,
|
||||
confMgr: c.ConfigManager,
|
||||
frontend: c.Frontend,
|
||||
tls: c.TLS,
|
||||
@@ -73,17 +71,18 @@ func New(c *Config) (svc *Service, err error) {
|
||||
forceHTTPS: c.ForceHTTPS,
|
||||
}
|
||||
|
||||
mux := newMux(svc)
|
||||
mux := http.NewServeMux()
|
||||
svc.route(mux)
|
||||
|
||||
if svc.overrideAddr != (netip.AddrPort{}) {
|
||||
svc.servers = []*http.Server{newSrv(svc.overrideAddr, nil, mux, c.Timeout)}
|
||||
svc.servers = []*server{newServer(svc.logger, svc.overrideAddr, nil, mux, c.Timeout)}
|
||||
} else {
|
||||
for _, a := range c.Addresses {
|
||||
svc.servers = append(svc.servers, newSrv(a, nil, mux, c.Timeout))
|
||||
svc.servers = append(svc.servers, newServer(svc.logger, a, nil, mux, c.Timeout))
|
||||
}
|
||||
|
||||
for _, a := range c.SecureAddresses {
|
||||
svc.servers = append(svc.servers, newSrv(a, c.TLS, mux, c.Timeout))
|
||||
svc.servers = append(svc.servers, newServer(svc.logger, a, c.TLS, mux, c.Timeout))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,96 +111,7 @@ func (svc *Service) setupPprof(c *PprofConfig) {
|
||||
svc.pprofPort = c.Port
|
||||
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port)
|
||||
|
||||
// TODO(a.garipov): Consider making pprof timeout configurable.
|
||||
svc.pprof = newSrv(addr, nil, pprofMux, 10*time.Minute)
|
||||
}
|
||||
|
||||
// newSrv returns a new *http.Server with the given parameters.
|
||||
func newSrv(
|
||||
addr netip.AddrPort,
|
||||
tlsConf *tls.Config,
|
||||
h http.Handler,
|
||||
timeout time.Duration,
|
||||
) (srv *http.Server) {
|
||||
addrStr := addr.String()
|
||||
srv = &http.Server{
|
||||
Addr: addrStr,
|
||||
Handler: h,
|
||||
TLSConfig: tlsConf,
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
IdleTimeout: timeout,
|
||||
ReadHeaderTimeout: timeout,
|
||||
}
|
||||
|
||||
if tlsConf == nil {
|
||||
srv.ErrorLog = log.StdLog("websvc: plain http: "+addrStr, log.ERROR)
|
||||
} else {
|
||||
srv.ErrorLog = log.StdLog("websvc: https: "+addrStr, log.ERROR)
|
||||
}
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
// newMux returns a new HTTP request multiplexer for the AdGuard Home web
|
||||
// service.
|
||||
func newMux(svc *Service) (mux *httptreemux.ContextMux) {
|
||||
mux = httptreemux.NewContextMux()
|
||||
|
||||
routes := []struct {
|
||||
handler http.HandlerFunc
|
||||
method string
|
||||
pattern string
|
||||
isJSON bool
|
||||
}{{
|
||||
handler: svc.handleGetHealthCheck,
|
||||
method: http.MethodGet,
|
||||
pattern: PathHealthCheck,
|
||||
isJSON: false,
|
||||
}, {
|
||||
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
|
||||
method: http.MethodGet,
|
||||
pattern: PathFrontend,
|
||||
isJSON: false,
|
||||
}, {
|
||||
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
|
||||
method: http.MethodGet,
|
||||
pattern: PathRoot,
|
||||
isJSON: false,
|
||||
}, {
|
||||
handler: svc.handleGetSettingsAll,
|
||||
method: http.MethodGet,
|
||||
pattern: PathV1SettingsAll,
|
||||
isJSON: true,
|
||||
}, {
|
||||
handler: svc.handlePatchSettingsDNS,
|
||||
method: http.MethodPatch,
|
||||
pattern: PathV1SettingsDNS,
|
||||
isJSON: true,
|
||||
}, {
|
||||
handler: svc.handlePatchSettingsHTTP,
|
||||
method: http.MethodPatch,
|
||||
pattern: PathV1SettingsHTTP,
|
||||
isJSON: true,
|
||||
}, {
|
||||
handler: svc.handleGetV1SystemInfo,
|
||||
method: http.MethodGet,
|
||||
pattern: PathV1SystemInfo,
|
||||
isJSON: true,
|
||||
}}
|
||||
|
||||
for _, r := range routes {
|
||||
var hdlr http.Handler
|
||||
if r.isJSON {
|
||||
hdlr = jsonMw(r.handler)
|
||||
} else {
|
||||
hdlr = r.handler
|
||||
}
|
||||
|
||||
mux.Handle(r.method, r.pattern, logMw(hdlr))
|
||||
}
|
||||
|
||||
return mux
|
||||
svc.pprof = newServer(svc.logger, addr, nil, pprofMux, 10*time.Minute)
|
||||
}
|
||||
|
||||
// addrs returns all addresses on which this server serves the HTTP API. addrs
|
||||
@@ -214,14 +124,12 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
|
||||
}
|
||||
|
||||
for _, srv := range svc.servers {
|
||||
// Use MustParseAddrPort, since no errors should technically happen
|
||||
// here, because all servers must have a valid address.
|
||||
addrPort := netip.MustParseAddrPort(srv.Addr)
|
||||
addrPort := netutil.NetAddrToAddrPort(srv.localAddr())
|
||||
if addrPort == (netip.AddrPort{}) {
|
||||
continue
|
||||
}
|
||||
|
||||
// [srv.Serve] will set TLSConfig to an almost empty value, so, instead
|
||||
// of relying only on the nilness of TLSConfig, check the length of the
|
||||
// certificates field as well.
|
||||
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
|
||||
if srv.tlsConf == nil {
|
||||
addrs = append(addrs, addrPort)
|
||||
} else {
|
||||
secureAddrs = append(secureAddrs, addrPort)
|
||||
@@ -231,74 +139,60 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
|
||||
return addrs, secureAddrs
|
||||
}
|
||||
|
||||
// handleGetHealthCheck is the handler for the GET /health-check HTTP API.
|
||||
func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "OK")
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ agh.Service = (*Service)(nil)
|
||||
var _ agh.ServiceWithConfig[*Config] = (*Service)(nil)
|
||||
|
||||
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
||||
// After Start exits, all HTTP servers have tried to start, possibly failing and
|
||||
// writing error messages to the log.
|
||||
func (svc *Service) Start() (err error) {
|
||||
//
|
||||
// TODO(a.garipov): Use the context for cancelation as well.
|
||||
func (svc *Service) Start(ctx context.Context) (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pprofEnabled := svc.pprof != nil
|
||||
srvNum := len(svc.servers) + mathutil.BoolToNumber[int](pprofEnabled)
|
||||
svc.logger.InfoContext(ctx, "starting")
|
||||
defer svc.logger.InfoContext(ctx, "started")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(srvNum)
|
||||
for _, srv := range svc.servers {
|
||||
go serve(srv, wg)
|
||||
go srv.serve(ctx, svc.logger)
|
||||
}
|
||||
|
||||
if pprofEnabled {
|
||||
go serve(svc.pprof, wg)
|
||||
if svc.pprof != nil {
|
||||
go svc.pprof.serve(ctx, svc.logger)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return svc.wait(ctx)
|
||||
}
|
||||
|
||||
// wait waits until either the context is canceled or all servers have started.
|
||||
func (svc *Service) wait(ctx context.Context) (err error) {
|
||||
for !svc.serversHaveStarted() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
// Wait and let the other goroutines do their job.
|
||||
runtime.Gosched()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// serve starts and runs srv and writes all errors into its log.
|
||||
func serve(srv *http.Server, wg *sync.WaitGroup) {
|
||||
addr := srv.Addr
|
||||
defer log.OnPanic(addr)
|
||||
|
||||
var proto string
|
||||
var l net.Listener
|
||||
var err error
|
||||
if srv.TLSConfig == nil {
|
||||
proto = "http"
|
||||
l, err = net.Listen("tcp", addr)
|
||||
} else {
|
||||
proto = "https"
|
||||
l, err = tls.Listen("tcp", addr, srv.TLSConfig)
|
||||
}
|
||||
if err != nil {
|
||||
srv.ErrorLog.Printf("starting srv %s: binding: %s", addr, err)
|
||||
// serversHaveStarted returns true if all servers have started serving.
|
||||
func (svc *Service) serversHaveStarted() (started bool) {
|
||||
started = len(svc.servers) != 0
|
||||
for _, srv := range svc.servers {
|
||||
started = started && srv.localAddr() != nil
|
||||
}
|
||||
|
||||
// Update the server's address in case the address had the port zero, which
|
||||
// would mean that a random available port was automatically chosen.
|
||||
srv.Addr = l.Addr().String()
|
||||
|
||||
log.Info("websvc: starting srv %s://%s", proto, srv.Addr)
|
||||
|
||||
l = &waitListener{
|
||||
Listener: l,
|
||||
firstAcceptWG: wg,
|
||||
if svc.pprof != nil {
|
||||
started = started && svc.pprof.localAddr() != nil
|
||||
}
|
||||
|
||||
err = srv.Serve(l)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
srv.ErrorLog.Printf("starting srv %s: %s", addr, err)
|
||||
}
|
||||
return started
|
||||
}
|
||||
|
||||
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
||||
@@ -308,20 +202,24 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
svc.logger.InfoContext(ctx, "shutting down")
|
||||
defer svc.logger.InfoContext(ctx, "shut down")
|
||||
|
||||
defer func() { err = errors.Annotate(err, "shutting down: %w") }()
|
||||
|
||||
var errs []error
|
||||
for _, srv := range svc.servers {
|
||||
shutdownErr := srv.Shutdown(ctx)
|
||||
shutdownErr := srv.shutdown(ctx)
|
||||
if shutdownErr != nil {
|
||||
errs = append(errs, fmt.Errorf("srv %s: %w", srv.Addr, shutdownErr))
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if svc.pprof != nil {
|
||||
shutdownErr := svc.pprof.Shutdown(ctx)
|
||||
shutdownErr := svc.pprof.shutdown(ctx)
|
||||
if shutdownErr != nil {
|
||||
errs = append(errs, fmt.Errorf("pprof srv %s: %w", svc.pprof.Addr, shutdownErr))
|
||||
errs = append(errs, fmt.Errorf("pprof: %w", shutdownErr))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,16 +15,15 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/testutil/fakefs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
@@ -80,8 +79,6 @@ func newConfigManager() (m *configManager) {
|
||||
// newTestServer creates and starts a new web service instance as well as its
|
||||
// sole address. It also registers a cleanup procedure, which shuts the
|
||||
// instance down.
|
||||
//
|
||||
// TODO(a.garipov): Use svc or remove it.
|
||||
func newTestServer(
|
||||
t testing.TB,
|
||||
confMgr websvc.ConfigManager,
|
||||
@@ -89,6 +86,7 @@ func newTestServer(
|
||||
t.Helper()
|
||||
|
||||
c := &websvc.Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Pprof: &websvc.PprofConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
@@ -107,7 +105,7 @@ func newTestServer(
|
||||
svc, err := websvc.New(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.Start()
|
||||
err = svc.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
|
||||
@@ -181,12 +179,12 @@ func TestService_Start_getHealthCheck(t *testing.T) {
|
||||
confMgr := newConfigManager()
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Scheme: urlutil.SchemeHTTP,
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathHealthCheck,
|
||||
Path: websvc.PathPatternHealthCheck,
|
||||
}
|
||||
|
||||
body := httpGet(t, u, http.StatusOK)
|
||||
|
||||
assert.Equal(t, []byte("OK"), body)
|
||||
assert.Equal(t, []byte(httputil.HealthCheckHandler), body)
|
||||
}
|
||||
|
||||
43
internal/permcheck/check_unix.go
Normal file
43
internal/permcheck/check_unix.go
Normal file
@@ -0,0 +1,43 @@
|
||||
//go:build unix
|
||||
|
||||
package permcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
)
|
||||
|
||||
// check is the Unix-specific implementation of [Check].
|
||||
func check(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
workDir string,
|
||||
dataDir string,
|
||||
statsDir string,
|
||||
querylogDir string,
|
||||
confFilePath string,
|
||||
) {
|
||||
dirLoggger, fileLogger := l.With("type", typeDir), l.With("type", typeFile)
|
||||
|
||||
for _, ent := range entities(workDir, dataDir, statsDir, querylogDir, confFilePath) {
|
||||
if ent.Value {
|
||||
checkDir(ctx, dirLoggger, ent.Key)
|
||||
} else {
|
||||
checkFile(ctx, fileLogger, ent.Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkDir checks the permissions of a single directory. The results are
|
||||
// logged at the appropriate level.
|
||||
func checkDir(ctx context.Context, l *slog.Logger, dirPath string) {
|
||||
checkPath(ctx, l, dirPath, aghos.DefaultPermDir)
|
||||
}
|
||||
|
||||
// checkFile checks the permissions of a single file. The results are logged at
|
||||
// the appropriate level.
|
||||
func checkFile(ctx context.Context, l *slog.Logger, filePath string) {
|
||||
checkPath(ctx, l, filePath, aghos.DefaultPermFile)
|
||||
}
|
||||
60
internal/permcheck/check_windows.go
Normal file
60
internal/permcheck/check_windows.go
Normal file
@@ -0,0 +1,60 @@
|
||||
//go:build windows
|
||||
|
||||
package permcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// check is the Windows-specific implementation of [Check].
|
||||
//
|
||||
// Note, that it only checks the owner and the ACEs of the working directory.
|
||||
// This is due to the assumption that the working directory ACEs are inherited
|
||||
// by the underlying files and directories, since at least [migrate] sets this
|
||||
// inheritance mode.
|
||||
func check(ctx context.Context, l *slog.Logger, workDir, _, _, _, _ string) {
|
||||
l = l.With("type", typeDir, "path", workDir)
|
||||
|
||||
dacl, owner, err := getSecurityInfo(workDir)
|
||||
if err != nil {
|
||||
l.ErrorContext(ctx, "getting security info", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !owner.IsWellKnown(windows.WinBuiltinAdministratorsSid) {
|
||||
l.WarnContext(ctx, "owner is not in administrators group")
|
||||
}
|
||||
|
||||
err = rangeACEs(dacl, func(
|
||||
hdr windows.ACE_HEADER,
|
||||
mask windows.ACCESS_MASK,
|
||||
sid *windows.SID,
|
||||
) (cont bool) {
|
||||
l.DebugContext(ctx, "checking access control entry", "mask", mask, "sid", sid)
|
||||
|
||||
warn := false
|
||||
switch {
|
||||
case hdr.AceType != windows.ACCESS_ALLOWED_ACE_TYPE:
|
||||
// Skip non-allowed ACEs.
|
||||
case !sid.IsWellKnown(windows.WinBuiltinAdministratorsSid):
|
||||
// Non-administrator ACEs should not have any access rights.
|
||||
warn = mask > 0
|
||||
default:
|
||||
// Administrators should full control access rights.
|
||||
warn = mask&fullControlMask != fullControlMask
|
||||
}
|
||||
if warn {
|
||||
l.WarnContext(ctx, "unexpected access control entry", "mask", mask, "sid", sid)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
l.ErrorContext(ctx, "checking access control entries", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
package permcheck
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// NeedsMigration returns true if AdGuard Home files need permission migration.
|
||||
//
|
||||
// TODO(a.garipov): Consider ways to detect this better.
|
||||
func NeedsMigration(confFilePath string) (ok bool) {
|
||||
s, err := aghos.Stat(confFilePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// Likely a first run. Don't check.
|
||||
return false
|
||||
}
|
||||
|
||||
log.Error("permcheck: checking if files need migration: %s", err)
|
||||
|
||||
// Unexpected error. Try to migrate just in case.
|
||||
return true
|
||||
}
|
||||
|
||||
return s.Mode().Perm() != aghos.DefaultPermFile
|
||||
}
|
||||
|
||||
// Migrate attempts to change the permissions of AdGuard Home's files. It logs
|
||||
// the results at an appropriate level.
|
||||
func Migrate(workDir, dataDir, statsDir, querylogDir, confFilePath string) {
|
||||
chmodDir(workDir)
|
||||
|
||||
chmodFile(confFilePath)
|
||||
|
||||
// TODO(a.garipov): Put all paths in one place and remove this duplication.
|
||||
chmodDir(dataDir)
|
||||
chmodDir(filepath.Join(dataDir, "filters"))
|
||||
chmodFile(filepath.Join(dataDir, "sessions.db"))
|
||||
chmodFile(filepath.Join(dataDir, "leases.json"))
|
||||
|
||||
if dataDir != querylogDir {
|
||||
chmodDir(querylogDir)
|
||||
}
|
||||
chmodFile(filepath.Join(querylogDir, "querylog.json"))
|
||||
chmodFile(filepath.Join(querylogDir, "querylog.json.1"))
|
||||
|
||||
if dataDir != statsDir {
|
||||
chmodDir(statsDir)
|
||||
}
|
||||
chmodFile(filepath.Join(statsDir, "stats.db"))
|
||||
}
|
||||
|
||||
// chmodDir changes the permissions of a single directory. The results are
|
||||
// logged at the appropriate level.
|
||||
func chmodDir(dirPath string) {
|
||||
chmodPath(dirPath, typeDir, aghos.DefaultPermDir)
|
||||
}
|
||||
|
||||
// chmodFile changes the permissions of a single file. The results are logged
|
||||
// at the appropriate level.
|
||||
func chmodFile(filePath string) {
|
||||
chmodPath(filePath, typeFile, aghos.DefaultPermFile)
|
||||
}
|
||||
|
||||
// chmodPath changes the permissions of a single filesystem entity. The results
|
||||
// are logged at the appropriate level.
|
||||
func chmodPath(entPath, fileType string, fm fs.FileMode) {
|
||||
err := aghos.Chmod(entPath, fm)
|
||||
if err == nil {
|
||||
log.Info("permcheck: changed permissions for %s %q", fileType, entPath)
|
||||
|
||||
return
|
||||
} else if errors.Is(err, os.ErrNotExist) {
|
||||
log.Debug("permcheck: changing permissions for %s %q: %s", fileType, entPath, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Error(
|
||||
"permcheck: SECURITY WARNING: cannot change permissions for %s %q to %#o: %s; "+
|
||||
"this can leave your system vulnerable, see "+
|
||||
"https://adguard-dns.io/kb/adguard-home/running-securely/#os-service-concerns",
|
||||
fileType,
|
||||
entPath,
|
||||
fm,
|
||||
err,
|
||||
)
|
||||
}
|
||||
66
internal/permcheck/migrate_unix.go
Normal file
66
internal/permcheck/migrate_unix.go
Normal file
@@ -0,0 +1,66 @@
|
||||
//go:build unix
|
||||
|
||||
package permcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// needsMigration is a Unix-specific implementation of [NeedsMigration].
|
||||
//
|
||||
// TODO(a.garipov): Consider ways to detect this better.
|
||||
func needsMigration(ctx context.Context, l *slog.Logger, _, confFilePath string) (ok bool) {
|
||||
s, err := os.Stat(confFilePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// Likely a first run. Don't check.
|
||||
return false
|
||||
}
|
||||
|
||||
l.ErrorContext(ctx, "checking a need for permission migration", slogutil.KeyError, err)
|
||||
|
||||
// Unexpected error. Try to migrate just in case.
|
||||
return true
|
||||
}
|
||||
|
||||
return s.Mode().Perm() != aghos.DefaultPermFile
|
||||
}
|
||||
|
||||
// migrate is a Unix-specific implementation of [Migrate].
|
||||
func migrate(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
workDir string,
|
||||
dataDir string,
|
||||
statsDir string,
|
||||
querylogDir string,
|
||||
confFilePath string,
|
||||
) {
|
||||
dirLoggger, fileLogger := l.With("type", typeDir), l.With("type", typeFile)
|
||||
|
||||
for _, ent := range entities(workDir, dataDir, statsDir, querylogDir, confFilePath) {
|
||||
if ent.Value {
|
||||
chmodDir(ctx, dirLoggger, ent.Key)
|
||||
} else {
|
||||
chmodFile(ctx, fileLogger, ent.Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// chmodDir changes the permissions of a single directory. The results are
|
||||
// logged at the appropriate level.
|
||||
func chmodDir(ctx context.Context, l *slog.Logger, dirPath string) {
|
||||
chmodPath(ctx, l, dirPath, aghos.DefaultPermDir)
|
||||
}
|
||||
|
||||
// chmodFile changes the permissions of a single file. The results are logged
|
||||
// at the appropriate level.
|
||||
func chmodFile(ctx context.Context, l *slog.Logger, filePath string) {
|
||||
chmodPath(ctx, l, filePath, aghos.DefaultPermFile)
|
||||
}
|
||||
135
internal/permcheck/migrate_windows.go
Normal file
135
internal/permcheck/migrate_windows.go
Normal file
@@ -0,0 +1,135 @@
|
||||
//go:build windows
|
||||
|
||||
package permcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// needsMigration is the Windows-specific implementation of [NeedsMigration].
|
||||
func needsMigration(ctx context.Context, l *slog.Logger, workDir, _ string) (ok bool) {
|
||||
l = l.With("type", typeDir, "path", workDir)
|
||||
|
||||
dacl, owner, err := getSecurityInfo(workDir)
|
||||
if err != nil {
|
||||
l.ErrorContext(ctx, "getting security info", slogutil.KeyError, err)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if !owner.IsWellKnown(windows.WinBuiltinAdministratorsSid) {
|
||||
return true
|
||||
}
|
||||
|
||||
err = rangeACEs(dacl, func(
|
||||
hdr windows.ACE_HEADER,
|
||||
mask windows.ACCESS_MASK,
|
||||
sid *windows.SID,
|
||||
) (cont bool) {
|
||||
switch {
|
||||
case hdr.AceType != windows.ACCESS_ALLOWED_ACE_TYPE:
|
||||
// Skip non-allowed access control entries.
|
||||
l.DebugContext(ctx, "skipping deny access control entry", "sid", sid)
|
||||
case !sid.IsWellKnown(windows.WinBuiltinAdministratorsSid):
|
||||
// Non-administrator access control entries should not have any
|
||||
// access rights.
|
||||
ok = mask > 0
|
||||
default:
|
||||
// Administrators should have full control.
|
||||
ok = mask&fullControlMask != fullControlMask
|
||||
}
|
||||
|
||||
// Stop ranging if the access control entry is unexpected.
|
||||
return !ok
|
||||
})
|
||||
if err != nil {
|
||||
l.ErrorContext(ctx, "checking access control entries", slogutil.KeyError, err)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// migrate is the Windows-specific implementation of [Migrate].
|
||||
//
|
||||
// It sets the owner to administrators and adds a full control access control
|
||||
// entry for the account. It also removes all non-administrator access control
|
||||
// entries, and keeps deny access control entries. For any created or modified
|
||||
// entry it sets the propagation flags to be inherited by child objects.
|
||||
func migrate(ctx context.Context, logger *slog.Logger, workDir, _, _, _, _ string) {
|
||||
l := logger.With("type", typeDir, "path", workDir)
|
||||
|
||||
dacl, owner, err := getSecurityInfo(workDir)
|
||||
if err != nil {
|
||||
l.ErrorContext(ctx, "getting security info", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
admins, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
|
||||
if err != nil {
|
||||
l.ErrorContext(ctx, "creating administrators sid", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Check for duplicates?
|
||||
var accessEntries []windows.EXPLICIT_ACCESS
|
||||
var setACL bool
|
||||
// Iterate over the access control entries in DACL to determine if its
|
||||
// migration is needed.
|
||||
err = rangeACEs(dacl, func(
|
||||
hdr windows.ACE_HEADER,
|
||||
mask windows.ACCESS_MASK,
|
||||
sid *windows.SID,
|
||||
) (cont bool) {
|
||||
switch {
|
||||
case hdr.AceType != windows.ACCESS_ALLOWED_ACE_TYPE:
|
||||
// Add non-allowed access control entries as is, since they specify
|
||||
// the access restrictions, which shouldn't be lost.
|
||||
l.InfoContext(ctx, "migrating deny access control entry", "sid", sid)
|
||||
accessEntries = append(accessEntries, newDenyExplicitAccess(sid, mask))
|
||||
setACL = true
|
||||
case !sid.IsWellKnown(windows.WinBuiltinAdministratorsSid):
|
||||
// Remove non-administrator ACEs, since such accounts should not
|
||||
// have any access rights.
|
||||
l.InfoContext(ctx, "removing access control entry", "sid", sid)
|
||||
setACL = true
|
||||
default:
|
||||
// Administrators should have full control. Don't add a new entry
|
||||
// here since it will be added later in case there are other
|
||||
// required entries.
|
||||
l.InfoContext(ctx, "migrating access control entry", "sid", sid, "mask", mask)
|
||||
setACL = setACL || mask&fullControlMask != fullControlMask
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
l.ErrorContext(ctx, "ranging through access control entries", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if setACL {
|
||||
accessEntries = append(accessEntries, newFullExplicitAccess(admins))
|
||||
}
|
||||
|
||||
if !owner.IsWellKnown(windows.WinBuiltinAdministratorsSid) {
|
||||
l.InfoContext(ctx, "migrating owner", "sid", owner)
|
||||
owner = admins
|
||||
} else {
|
||||
l.DebugContext(ctx, "owner is already an administrator")
|
||||
owner = nil
|
||||
}
|
||||
|
||||
err = setSecurityInfo(workDir, owner, accessEntries)
|
||||
if err != nil {
|
||||
l.ErrorContext(ctx, "setting security info", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,10 @@
|
||||
// Package permcheck contains code for simplifying permissions checks on files
|
||||
// and directories.
|
||||
//
|
||||
// TODO(a.garipov): Improve the approach on Windows.
|
||||
package permcheck
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"context"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// File type constants for logging.
|
||||
@@ -22,65 +15,33 @@ const (
|
||||
|
||||
// Check checks the permissions on important files. It logs the results at
|
||||
// appropriate levels.
|
||||
func Check(workDir, dataDir, statsDir, querylogDir, confFilePath string) {
|
||||
checkDir(workDir)
|
||||
|
||||
checkFile(confFilePath)
|
||||
|
||||
// TODO(a.garipov): Put all paths in one place and remove this duplication.
|
||||
checkDir(dataDir)
|
||||
checkDir(filepath.Join(dataDir, "filters"))
|
||||
checkFile(filepath.Join(dataDir, "sessions.db"))
|
||||
checkFile(filepath.Join(dataDir, "leases.json"))
|
||||
|
||||
if dataDir != querylogDir {
|
||||
checkDir(querylogDir)
|
||||
}
|
||||
checkFile(filepath.Join(querylogDir, "querylog.json"))
|
||||
checkFile(filepath.Join(querylogDir, "querylog.json.1"))
|
||||
|
||||
if dataDir != statsDir {
|
||||
checkDir(statsDir)
|
||||
}
|
||||
checkFile(filepath.Join(statsDir, "stats.db"))
|
||||
func Check(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
workDir string,
|
||||
dataDir string,
|
||||
statsDir string,
|
||||
querylogDir string,
|
||||
confFilePath string,
|
||||
) {
|
||||
check(ctx, l, workDir, dataDir, statsDir, querylogDir, confFilePath)
|
||||
}
|
||||
|
||||
// checkDir checks the permissions of a single directory. The results are
|
||||
// logged at the appropriate level.
|
||||
func checkDir(dirPath string) {
|
||||
checkPath(dirPath, typeDir, aghos.DefaultPermDir)
|
||||
// NeedsMigration returns true if AdGuard Home files need permission migration.
|
||||
func NeedsMigration(ctx context.Context, l *slog.Logger, workDir, confFilePath string) (ok bool) {
|
||||
return needsMigration(ctx, l, workDir, confFilePath)
|
||||
}
|
||||
|
||||
// checkFile checks the permissions of a single file. The results are logged at
|
||||
// the appropriate level.
|
||||
func checkFile(filePath string) {
|
||||
checkPath(filePath, typeFile, aghos.DefaultPermFile)
|
||||
}
|
||||
|
||||
// checkPath checks the permissions of a single filesystem entity. The results
|
||||
// are logged at the appropriate level.
|
||||
func checkPath(entPath, fileType string, want fs.FileMode) {
|
||||
s, err := aghos.Stat(entPath)
|
||||
if err != nil {
|
||||
logFunc := log.Error
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
logFunc = log.Debug
|
||||
}
|
||||
|
||||
logFunc("permcheck: checking %s %q: %s", fileType, entPath, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Add a more fine-grained check and result reporting.
|
||||
perm := s.Mode().Perm()
|
||||
if perm != want {
|
||||
log.Info(
|
||||
"permcheck: SECURITY WARNING: %s %q has unexpected permissions %#o; want %#o",
|
||||
fileType,
|
||||
entPath,
|
||||
perm,
|
||||
want,
|
||||
)
|
||||
}
|
||||
// Migrate attempts to change the permissions of AdGuard Home's files. It logs
|
||||
// the results at an appropriate level.
|
||||
func Migrate(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
workDir string,
|
||||
dataDir string,
|
||||
statsDir string,
|
||||
querylogDir string,
|
||||
confFilePath string,
|
||||
) {
|
||||
migrate(ctx, l, workDir, dataDir, statsDir, querylogDir, confFilePath)
|
||||
}
|
||||
|
||||
123
internal/permcheck/security_unix.go
Normal file
123
internal/permcheck/security_unix.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//go:build unix
|
||||
|
||||
package permcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// entity is a filesystem entity with a path and a flag indicating whether it is
|
||||
// a directory.
|
||||
type entity = container.KeyValue[string, bool]
|
||||
|
||||
// entities returns a list of filesystem entities that need to be ranged over.
|
||||
//
|
||||
// TODO(a.garipov): Put all paths in one place and remove this duplication.
|
||||
func entities(workDir, dataDir, statsDir, querylogDir, confFilePath string) (ents []entity) {
|
||||
ents = []entity{{
|
||||
Key: workDir,
|
||||
Value: true,
|
||||
}, {
|
||||
Key: confFilePath,
|
||||
Value: false,
|
||||
}, {
|
||||
Key: dataDir,
|
||||
Value: true,
|
||||
}, {
|
||||
Key: filepath.Join(dataDir, "filters"),
|
||||
Value: true,
|
||||
}, {
|
||||
Key: filepath.Join(dataDir, "sessions.db"),
|
||||
Value: false,
|
||||
}, {
|
||||
Key: filepath.Join(dataDir, "leases.json"),
|
||||
Value: false,
|
||||
}}
|
||||
|
||||
if dataDir != querylogDir {
|
||||
ents = append(ents, entity{
|
||||
Key: querylogDir,
|
||||
Value: true,
|
||||
})
|
||||
}
|
||||
ents = append(ents, entity{
|
||||
Key: filepath.Join(querylogDir, "querylog.json"),
|
||||
Value: false,
|
||||
}, entity{
|
||||
Key: filepath.Join(querylogDir, "querylog.json.1"),
|
||||
Value: false,
|
||||
})
|
||||
|
||||
if dataDir != statsDir {
|
||||
ents = append(ents, entity{
|
||||
Key: statsDir,
|
||||
Value: true,
|
||||
})
|
||||
}
|
||||
ents = append(ents, entity{
|
||||
Key: filepath.Join(statsDir, "stats.db"),
|
||||
})
|
||||
|
||||
return ents
|
||||
}
|
||||
|
||||
// checkPath checks the permissions of a single filesystem entity. The results
|
||||
// are logged at the appropriate level.
|
||||
func checkPath(ctx context.Context, l *slog.Logger, entPath string, want fs.FileMode) {
|
||||
l = l.With("path", entPath)
|
||||
|
||||
s, err := os.Stat(entPath)
|
||||
if err != nil {
|
||||
lvl := slog.LevelError
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
lvl = slog.LevelDebug
|
||||
}
|
||||
|
||||
l.Log(ctx, lvl, "checking permissions", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Add a more fine-grained check and result reporting.
|
||||
perm := s.Mode().Perm()
|
||||
if perm == want {
|
||||
return
|
||||
}
|
||||
|
||||
permOct, wantOct := fmt.Sprintf("%#o", perm), fmt.Sprintf("%#o", want)
|
||||
l.WarnContext(ctx, "found unexpected permissions", "perm", permOct, "want", wantOct)
|
||||
}
|
||||
|
||||
// chmodPath changes the permissions of a single filesystem entity. The results
|
||||
// are logged at the appropriate level.
|
||||
func chmodPath(ctx context.Context, l *slog.Logger, entPath string, fm fs.FileMode) {
|
||||
var lvl slog.Level
|
||||
var msg string
|
||||
args := []any{"path", entPath}
|
||||
|
||||
switch err := os.Chmod(entPath, fm); {
|
||||
case err == nil:
|
||||
lvl = slog.LevelInfo
|
||||
msg = "changed permissions"
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
lvl = slog.LevelDebug
|
||||
msg = "checking permissions"
|
||||
args = append(args, slogutil.KeyError, err)
|
||||
default:
|
||||
lvl = slog.LevelError
|
||||
msg = "cannot change permissions; this can leave your system vulnerable, see " +
|
||||
"https://adguard-dns.io/kb/adguard-home/running-securely/#os-service-concerns"
|
||||
args = append(args, "target_perm", fmt.Sprintf("%#o", fm), slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
l.Log(ctx, lvl, msg, args...)
|
||||
}
|
||||
167
internal/permcheck/security_windows.go
Normal file
167
internal/permcheck/security_windows.go
Normal file
@@ -0,0 +1,167 @@
|
||||
//go:build windows
|
||||
|
||||
package permcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// objectType is the type of the object for directories in context of security
|
||||
// API.
|
||||
const objectType windows.SE_OBJECT_TYPE = windows.SE_FILE_OBJECT
|
||||
|
||||
// fileDeleteChildRight is the mask bit for the right to delete a child object.
|
||||
// It seems to be missing from the [windows] package.
|
||||
//
|
||||
// See https://learn.microsoft.com/en-us/windows-hardware/drivers/ifs/access-mask.
|
||||
const fileDeleteChildRight windows.ACCESS_MASK = 0b0100_0000
|
||||
|
||||
// fullControlMask is the mask for full control access rights.
|
||||
const fullControlMask windows.ACCESS_MASK = windows.FILE_LIST_DIRECTORY |
|
||||
windows.FILE_WRITE_DATA |
|
||||
windows.FILE_APPEND_DATA |
|
||||
windows.FILE_READ_EA |
|
||||
windows.FILE_WRITE_EA |
|
||||
windows.FILE_TRAVERSE |
|
||||
fileDeleteChildRight |
|
||||
windows.FILE_READ_ATTRIBUTES |
|
||||
windows.FILE_WRITE_ATTRIBUTES |
|
||||
windows.DELETE |
|
||||
windows.READ_CONTROL |
|
||||
windows.WRITE_DAC |
|
||||
windows.WRITE_OWNER |
|
||||
windows.SYNCHRONIZE
|
||||
|
||||
// aceFunc is a function that handles access control entries in the
|
||||
// discretionary access control list. It should return true to continue
|
||||
// iterating over the entries, or false to stop.
|
||||
type aceFunc = func(
|
||||
hdr windows.ACE_HEADER,
|
||||
mask windows.ACCESS_MASK,
|
||||
sid *windows.SID,
|
||||
) (cont bool)
|
||||
|
||||
// rangeACEs ranges over the access control entries in the discretionary access
|
||||
// control list of the specified security descriptor and calls f for each one.
|
||||
func rangeACEs(dacl *windows.ACL, f aceFunc) (err error) {
|
||||
var errs []error
|
||||
for i := range uint32(dacl.AceCount) {
|
||||
var ace *windows.ACCESS_ALLOWED_ACE
|
||||
err = windows.GetAce(dacl, i, &ace)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("getting entry at index %d: %w", i, err))
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
sid := (*windows.SID)(unsafe.Pointer(&ace.SidStart))
|
||||
if !f(ace.Header, ace.Mask, sid) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err = errors.Join(errs...); err != nil {
|
||||
return fmt.Errorf("checking access control entries: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setSecurityInfo sets the security information on the specified file, using
|
||||
// ents to create a discretionary access control list. Either owner or ents can
|
||||
// be nil, in which case the corresponding information is not set, but at least
|
||||
// one of them should be specified.
|
||||
func setSecurityInfo(fname string, owner *windows.SID, ents []windows.EXPLICIT_ACCESS) (err error) {
|
||||
var secInfo windows.SECURITY_INFORMATION
|
||||
|
||||
var acl *windows.ACL
|
||||
if len(ents) > 0 {
|
||||
// TODO(e.burkov): Investigate if this whole set is necessary.
|
||||
secInfo |= windows.DACL_SECURITY_INFORMATION |
|
||||
windows.PROTECTED_DACL_SECURITY_INFORMATION |
|
||||
windows.UNPROTECTED_DACL_SECURITY_INFORMATION
|
||||
|
||||
acl, err = windows.ACLFromEntries(ents, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating access control list: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if owner != nil {
|
||||
secInfo |= windows.OWNER_SECURITY_INFORMATION
|
||||
}
|
||||
|
||||
if secInfo == 0 {
|
||||
return errors.Error("no security information to set")
|
||||
}
|
||||
|
||||
err = windows.SetNamedSecurityInfo(fname, objectType, secInfo, owner, nil, acl, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting security info: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getSecurityInfo retrieves the security information for the specified file.
|
||||
func getSecurityInfo(fname string) (dacl *windows.ACL, owner *windows.SID, err error) {
|
||||
// desiredSecInfo defines the parts of a security descriptor to retrieve.
|
||||
const desiredSecInfo windows.SECURITY_INFORMATION = windows.OWNER_SECURITY_INFORMATION |
|
||||
windows.DACL_SECURITY_INFORMATION |
|
||||
windows.PROTECTED_DACL_SECURITY_INFORMATION |
|
||||
windows.UNPROTECTED_DACL_SECURITY_INFORMATION
|
||||
|
||||
sd, err := windows.GetNamedSecurityInfo(fname, objectType, desiredSecInfo)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("getting security descriptor: %w", err)
|
||||
}
|
||||
|
||||
owner, _, err = sd.Owner()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("getting owner sid: %w", err)
|
||||
}
|
||||
|
||||
dacl, _, err = sd.DACL()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("getting discretionary access control list: %w", err)
|
||||
}
|
||||
|
||||
return dacl, owner, nil
|
||||
}
|
||||
|
||||
// newFullExplicitAccess creates a new explicit access entry with full control
|
||||
// permissions.
|
||||
func newFullExplicitAccess(sid *windows.SID) (accEnt windows.EXPLICIT_ACCESS) {
|
||||
return windows.EXPLICIT_ACCESS{
|
||||
AccessPermissions: fullControlMask,
|
||||
AccessMode: windows.GRANT_ACCESS,
|
||||
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
|
||||
Trustee: windows.TRUSTEE{
|
||||
TrusteeForm: windows.TRUSTEE_IS_SID,
|
||||
TrusteeType: windows.TRUSTEE_IS_UNKNOWN,
|
||||
TrusteeValue: windows.TrusteeValueFromSID(sid),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newDenyExplicitAccess creates a new explicit access entry with specified deny
|
||||
// permissions.
|
||||
func newDenyExplicitAccess(
|
||||
sid *windows.SID,
|
||||
mask windows.ACCESS_MASK,
|
||||
) (accEnt windows.EXPLICIT_ACCESS) {
|
||||
return windows.EXPLICIT_ACCESS{
|
||||
AccessPermissions: mask,
|
||||
AccessMode: windows.DENY_ACCESS,
|
||||
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
|
||||
Trustee: windows.TRUSTEE{
|
||||
TrusteeForm: windows.TRUSTEE_IS_SID,
|
||||
TrusteeType: windows.TRUSTEE_IS_UNKNOWN,
|
||||
TrusteeValue: windows.TrusteeValueFromSID(sid),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -174,26 +175,32 @@ var logEntryHandlers = map[string]logEntryHandler{
|
||||
}
|
||||
|
||||
// decodeResultRuleKey decodes the token of "Rules" type to logEntry struct.
|
||||
func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
|
||||
func (l *queryLog) decodeResultRuleKey(
|
||||
ctx context.Context,
|
||||
key string,
|
||||
i int,
|
||||
dec *json.Decoder,
|
||||
ent *logEntry,
|
||||
) {
|
||||
var vToken json.Token
|
||||
switch key {
|
||||
case "FilterListID":
|
||||
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules)
|
||||
ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
|
||||
if n, ok := vToken.(json.Number); ok {
|
||||
id, _ := n.Int64()
|
||||
ent.Result.Rules[i].FilterListID = rulelist.URLFilterID(id)
|
||||
}
|
||||
case "IP":
|
||||
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules)
|
||||
ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
|
||||
if ipStr, ok := vToken.(string); ok {
|
||||
if ip, err := netip.ParseAddr(ipStr); err == nil {
|
||||
ent.Result.Rules[i].IP = ip
|
||||
} else {
|
||||
log.Debug("querylog: decoding ipStr value: %s", err)
|
||||
l.logger.DebugContext(ctx, "decoding ip", "value", ipStr, slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
case "Text":
|
||||
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules)
|
||||
ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
|
||||
if s, ok := vToken.(string); ok {
|
||||
ent.Result.Rules[i].Text = s
|
||||
}
|
||||
@@ -204,7 +211,8 @@ func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
|
||||
|
||||
// decodeVTokenAndAddRule decodes the "Rules" toke as [filtering.ResultRule]
|
||||
// and then adds the decoded object to the slice of result rules.
|
||||
func decodeVTokenAndAddRule(
|
||||
func (l *queryLog) decodeVTokenAndAddRule(
|
||||
ctx context.Context,
|
||||
key string,
|
||||
i int,
|
||||
dec *json.Decoder,
|
||||
@@ -215,7 +223,12 @@ func decodeVTokenAndAddRule(
|
||||
vToken, err := dec.Token()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Debug("decodeResultRuleKey %s err: %s", key, err)
|
||||
l.logger.DebugContext(
|
||||
ctx,
|
||||
"decoding result rule key",
|
||||
"key", key,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
}
|
||||
|
||||
return newRules, nil
|
||||
@@ -230,12 +243,14 @@ func decodeVTokenAndAddRule(
|
||||
|
||||
// decodeResultRules parses the dec's tokens into logEntry ent interpreting it
|
||||
// as a slice of the result rules.
|
||||
func decodeResultRules(dec *json.Decoder, ent *logEntry) {
|
||||
func (l *queryLog) decodeResultRules(ctx context.Context, dec *json.Decoder, ent *logEntry) {
|
||||
const msgPrefix = "decoding result rules"
|
||||
|
||||
for {
|
||||
delimToken, err := dec.Token()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Debug("decodeResultRules err: %s", err)
|
||||
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -244,13 +259,17 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) {
|
||||
if d, ok := delimToken.(json.Delim); !ok {
|
||||
return
|
||||
} else if d != '[' {
|
||||
log.Debug("decodeResultRules: unexpected delim %q", d)
|
||||
l.logger.DebugContext(
|
||||
ctx,
|
||||
msgPrefix,
|
||||
slogutil.KeyError, newUnexpectedDelimiterError(d),
|
||||
)
|
||||
}
|
||||
|
||||
err = decodeResultRuleToken(dec, ent)
|
||||
err = l.decodeResultRuleToken(ctx, dec, ent)
|
||||
if err != nil {
|
||||
if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
|
||||
log.Debug("decodeResultRules err: %s", err)
|
||||
l.logger.DebugContext(ctx, msgPrefix+"; rule token", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -259,7 +278,11 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) {
|
||||
}
|
||||
|
||||
// decodeResultRuleToken decodes the tokens of "Rules" type to the logEntry ent.
|
||||
func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
|
||||
func (l *queryLog) decodeResultRuleToken(
|
||||
ctx context.Context,
|
||||
dec *json.Decoder,
|
||||
ent *logEntry,
|
||||
) (err error) {
|
||||
i := 0
|
||||
for {
|
||||
var keyToken json.Token
|
||||
@@ -287,7 +310,7 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
|
||||
return fmt.Errorf("keyToken is %T (%[1]v) and not string", keyToken)
|
||||
}
|
||||
|
||||
decodeResultRuleKey(key, i, dec, ent)
|
||||
l.decodeResultRuleKey(ctx, key, i, dec, ent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,12 +319,14 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
|
||||
// other occurrences of DNSRewriteResult in the entry since hosts container's
|
||||
// rewrites currently has the highest priority along the entire filtering
|
||||
// pipeline.
|
||||
func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
|
||||
func (l *queryLog) decodeResultReverseHosts(ctx context.Context, dec *json.Decoder, ent *logEntry) {
|
||||
const msgPrefix = "decoding result reverse hosts"
|
||||
|
||||
for {
|
||||
itemToken, err := dec.Token()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Debug("decodeResultReverseHosts err: %s", err)
|
||||
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -315,7 +340,11 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("decodeResultReverseHosts: unexpected delim %q", v)
|
||||
l.logger.DebugContext(
|
||||
ctx,
|
||||
msgPrefix,
|
||||
slogutil.KeyError, newUnexpectedDelimiterError(v),
|
||||
)
|
||||
|
||||
return
|
||||
case string:
|
||||
@@ -346,12 +375,14 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
|
||||
|
||||
// decodeResultIPList parses the dec's tokens into logEntry ent interpreting it
|
||||
// as the result IP addresses list.
|
||||
func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
|
||||
func (l *queryLog) decodeResultIPList(ctx context.Context, dec *json.Decoder, ent *logEntry) {
|
||||
const msgPrefix = "decoding result ip list"
|
||||
|
||||
for {
|
||||
itemToken, err := dec.Token()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Debug("decodeResultIPList err: %s", err)
|
||||
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -365,7 +396,11 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("decodeResultIPList: unexpected delim %q", v)
|
||||
l.logger.DebugContext(
|
||||
ctx,
|
||||
msgPrefix,
|
||||
slogutil.KeyError, newUnexpectedDelimiterError(v),
|
||||
)
|
||||
|
||||
return
|
||||
case string:
|
||||
@@ -382,7 +417,14 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
|
||||
|
||||
// decodeResultDNSRewriteResultKey decodes the token of "DNSRewriteResult" type
|
||||
// to the logEntry struct.
|
||||
func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntry) {
|
||||
func (l *queryLog) decodeResultDNSRewriteResultKey(
|
||||
ctx context.Context,
|
||||
key string,
|
||||
dec *json.Decoder,
|
||||
ent *logEntry,
|
||||
) {
|
||||
const msgPrefix = "decoding result dns rewrite result key"
|
||||
|
||||
var err error
|
||||
|
||||
switch key {
|
||||
@@ -391,7 +433,7 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
|
||||
vToken, err = dec.Token()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Debug("decodeResultDNSRewriteResultKey err: %s", err)
|
||||
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -419,7 +461,7 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
|
||||
// decoding and correct the values.
|
||||
err = dec.Decode(&ent.Result.DNSRewriteResult.Response)
|
||||
if err != nil {
|
||||
log.Debug("decodeResultDNSRewriteResultKey response err: %s", err)
|
||||
l.logger.DebugContext(ctx, msgPrefix+"; response", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
ent.parseDNSRewriteResultIPs()
|
||||
@@ -430,12 +472,18 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
|
||||
|
||||
// decodeResultDNSRewriteResult parses the dec's tokens into logEntry ent
|
||||
// interpreting it as the result DNSRewriteResult.
|
||||
func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
|
||||
func (l *queryLog) decodeResultDNSRewriteResult(
|
||||
ctx context.Context,
|
||||
dec *json.Decoder,
|
||||
ent *logEntry,
|
||||
) {
|
||||
const msgPrefix = "decoding result dns rewrite result"
|
||||
|
||||
for {
|
||||
key, err := parseKeyToken(dec)
|
||||
if err != nil {
|
||||
if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
|
||||
log.Debug("decodeResultDNSRewriteResult: %s", err)
|
||||
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -445,7 +493,7 @@ func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
|
||||
continue
|
||||
}
|
||||
|
||||
decodeResultDNSRewriteResultKey(key, dec, ent)
|
||||
l.decodeResultDNSRewriteResultKey(ctx, key, dec, ent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -508,14 +556,16 @@ func parseKeyToken(dec *json.Decoder) (key string, err error) {
|
||||
}
|
||||
|
||||
// decodeResult decodes a token of "Result" type to logEntry struct.
|
||||
func decodeResult(dec *json.Decoder, ent *logEntry) {
|
||||
func (l *queryLog) decodeResult(ctx context.Context, dec *json.Decoder, ent *logEntry) {
|
||||
const msgPrefix = "decoding result"
|
||||
|
||||
defer translateResult(ent)
|
||||
|
||||
for {
|
||||
key, err := parseKeyToken(dec)
|
||||
if err != nil {
|
||||
if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
|
||||
log.Debug("decodeResult: %s", err)
|
||||
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -525,10 +575,8 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
|
||||
continue
|
||||
}
|
||||
|
||||
decHandler, ok := resultDecHandlers[key]
|
||||
ok := l.resultDecHandler(ctx, key, dec, ent)
|
||||
if ok {
|
||||
decHandler(dec, ent)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -543,7 +591,7 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
|
||||
}
|
||||
|
||||
if err = handler(val, ent); err != nil {
|
||||
log.Debug("decodeResult handler err: %s", err)
|
||||
l.logger.DebugContext(ctx, msgPrefix+"; handler", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -636,16 +684,34 @@ var resultHandlers = map[string]logEntryHandler{
|
||||
},
|
||||
}
|
||||
|
||||
// resultDecHandlers is the map of decode handlers for various keys.
|
||||
var resultDecHandlers = map[string]func(dec *json.Decoder, ent *logEntry){
|
||||
"ReverseHosts": decodeResultReverseHosts,
|
||||
"IPList": decodeResultIPList,
|
||||
"Rules": decodeResultRules,
|
||||
"DNSRewriteResult": decodeResultDNSRewriteResult,
|
||||
// resultDecHandlers calls a decode handler for key if there is one.
|
||||
func (l *queryLog) resultDecHandler(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
dec *json.Decoder,
|
||||
ent *logEntry,
|
||||
) (ok bool) {
|
||||
ok = true
|
||||
switch name {
|
||||
case "ReverseHosts":
|
||||
l.decodeResultReverseHosts(ctx, dec, ent)
|
||||
case "IPList":
|
||||
l.decodeResultIPList(ctx, dec, ent)
|
||||
case "Rules":
|
||||
l.decodeResultRules(ctx, dec, ent)
|
||||
case "DNSRewriteResult":
|
||||
l.decodeResultDNSRewriteResult(ctx, dec, ent)
|
||||
default:
|
||||
ok = false
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// decodeLogEntry decodes string str to logEntry ent.
|
||||
func decodeLogEntry(ent *logEntry, str string) {
|
||||
func (l *queryLog) decodeLogEntry(ctx context.Context, ent *logEntry, str string) {
|
||||
const msgPrefix = "decoding log entry"
|
||||
|
||||
dec := json.NewDecoder(strings.NewReader(str))
|
||||
dec.UseNumber()
|
||||
|
||||
@@ -653,7 +719,7 @@ func decodeLogEntry(ent *logEntry, str string) {
|
||||
keyToken, err := dec.Token()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Debug("decodeLogEntry err: %s", err)
|
||||
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -665,13 +731,14 @@ func decodeLogEntry(ent *logEntry, str string) {
|
||||
|
||||
key, ok := keyToken.(string)
|
||||
if !ok {
|
||||
log.Debug("decodeLogEntry: keyToken is %T (%[1]v) and not string", keyToken)
|
||||
err = fmt.Errorf("%s: keyToken is %T (%[2]v) and not string", msgPrefix, keyToken)
|
||||
l.logger.DebugContext(ctx, msgPrefix, slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if key == "Result" {
|
||||
decodeResult(dec, ent)
|
||||
l.decodeResult(ctx, dec, ent)
|
||||
|
||||
continue
|
||||
}
|
||||
@@ -687,9 +754,14 @@ func decodeLogEntry(ent *logEntry, str string) {
|
||||
}
|
||||
|
||||
if err = handler(val, ent); err != nil {
|
||||
log.Debug("decodeLogEntry handler err: %s", err)
|
||||
l.logger.DebugContext(ctx, msgPrefix+"; handler", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newUnexpectedDelimiterError is a helper for creating informative errors.
|
||||
func newUnexpectedDelimiterError(d json.Delim) (err error) {
|
||||
return fmt.Errorf("unexpected delimiter: %q", d)
|
||||
}
|
||||
|
||||
@@ -3,27 +3,35 @@ package querylog
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Common constants for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
func TestDecodeLogEntry(t *testing.T) {
|
||||
logOutput := &bytes.Buffer{}
|
||||
l := &queryLog{
|
||||
logger: slog.New(slog.NewTextHandler(logOutput, &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
ReplaceAttr: slogutil.RemoveTime,
|
||||
})),
|
||||
}
|
||||
|
||||
aghtest.ReplaceLogWriter(t, logOutput)
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==`
|
||||
@@ -92,7 +100,7 @@ func TestDecodeLogEntry(t *testing.T) {
|
||||
}
|
||||
|
||||
got := &logEntry{}
|
||||
decodeLogEntry(got, data)
|
||||
l.decodeLogEntry(ctx, got, data)
|
||||
|
||||
s := logOutput.String()
|
||||
assert.Empty(t, s)
|
||||
@@ -113,11 +121,11 @@ func TestDecodeLogEntry(t *testing.T) {
|
||||
}, {
|
||||
name: "bad_filter_id_old_rule",
|
||||
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"FilterID":1.5},"Elapsed":837429}`,
|
||||
want: "decodeResult handler err: strconv.ParseInt: parsing \"1.5\": invalid syntax\n",
|
||||
want: `level=DEBUG msg="decoding result; handler" err="strconv.ParseInt: parsing \"1.5\": invalid syntax"`,
|
||||
}, {
|
||||
name: "bad_is_filtered",
|
||||
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":trooe,"Reason":3},"Elapsed":837429}`,
|
||||
want: "decodeLogEntry err: invalid character 'o' in literal true (expecting 'u')\n",
|
||||
want: `level=DEBUG msg="decoding log entry; token" err="invalid character 'o' in literal true (expecting 'u')"`,
|
||||
}, {
|
||||
name: "bad_elapsed",
|
||||
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":-1}`,
|
||||
@@ -129,7 +137,7 @@ func TestDecodeLogEntry(t *testing.T) {
|
||||
}, {
|
||||
name: "bad_time",
|
||||
log: `{"IP":"127.0.0.1","T":"12/09/1998T15:00:00.000000+05:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
|
||||
want: "decodeLogEntry handler err: parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\"\n",
|
||||
want: `level=DEBUG msg="decoding log entry; handler" err="parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\""`,
|
||||
}, {
|
||||
name: "bad_host",
|
||||
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":6,"QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
|
||||
@@ -149,7 +157,7 @@ func TestDecodeLogEntry(t *testing.T) {
|
||||
}, {
|
||||
name: "very_bad_client_proto",
|
||||
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"dog","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
|
||||
want: "decodeLogEntry handler err: invalid client proto: \"dog\"\n",
|
||||
want: `level=DEBUG msg="decoding log entry; handler" err="invalid client proto: \"dog\""`,
|
||||
}, {
|
||||
name: "bad_answer",
|
||||
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":0.9,"Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
|
||||
@@ -157,7 +165,7 @@ func TestDecodeLogEntry(t *testing.T) {
|
||||
}, {
|
||||
name: "very_bad_answer",
|
||||
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
|
||||
want: "decodeLogEntry handler err: illegal base64 data at input byte 61\n",
|
||||
want: `level=DEBUG msg="decoding log entry; handler" err="illegal base64 data at input byte 61"`,
|
||||
}, {
|
||||
name: "bad_rule",
|
||||
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"Rule":false},"Elapsed":837429}`,
|
||||
@@ -169,22 +177,25 @@ func TestDecodeLogEntry(t *testing.T) {
|
||||
}, {
|
||||
name: "bad_reverse_hosts",
|
||||
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":[{}]},"Elapsed":837429}`,
|
||||
want: "decodeResultReverseHosts: unexpected delim \"{\"\n",
|
||||
want: `level=DEBUG msg="decoding result reverse hosts" err="unexpected delimiter: \"{\""`,
|
||||
}, {
|
||||
name: "bad_ip_list",
|
||||
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":["example.net"],"IPList":[{}]},"Elapsed":837429}`,
|
||||
want: "decodeResultIPList: unexpected delim \"{\"\n",
|
||||
want: `level=DEBUG msg="decoding result ip list" err="unexpected delimiter: \"{\""`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
decodeLogEntry(new(logEntry), tc.log)
|
||||
|
||||
s := logOutput.String()
|
||||
l.decodeLogEntry(ctx, new(logEntry), tc.log)
|
||||
got := logOutput.String()
|
||||
if tc.want == "" {
|
||||
assert.Empty(t, s)
|
||||
assert.Empty(t, got)
|
||||
} else {
|
||||
assert.True(t, strings.HasSuffix(s, tc.want), "got %q", s)
|
||||
require.NotEmpty(t, got)
|
||||
|
||||
// Remove newline.
|
||||
got = got[:len(got)-1]
|
||||
assert.Equal(t, tc.want, got)
|
||||
}
|
||||
|
||||
logOutput.Reset()
|
||||
@@ -200,6 +211,12 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) {
|
||||
aaaa2 = aaaa1.Next()
|
||||
)
|
||||
|
||||
l := &queryLog{
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
testCases := []struct {
|
||||
want *logEntry
|
||||
entry string
|
||||
@@ -249,7 +266,7 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := &logEntry{}
|
||||
decodeLogEntry(e, tc.entry)
|
||||
l.decodeLogEntry(ctx, e, tc.entry)
|
||||
|
||||
assert.Equal(t, tc.want, e)
|
||||
})
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@@ -52,7 +54,7 @@ func (e *logEntry) shallowClone() (clone *logEntry) {
|
||||
// addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is
|
||||
// true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any
|
||||
// errors are logged.
|
||||
func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) {
|
||||
func (e *logEntry) addResponse(ctx context.Context, l *slog.Logger, resp *dns.Msg, isOrig bool) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
@@ -65,8 +67,9 @@ func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) {
|
||||
e.Answer, err = resp.Pack()
|
||||
err = errors.Annotate(err, "packing answer: %w")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Error("querylog: %s", err)
|
||||
l.ErrorContext(ctx, "adding data from response", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
@@ -15,7 +16,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
@@ -74,7 +75,8 @@ func (l *queryLog) initWeb() {
|
||||
|
||||
// handleQueryLog is the handler for the GET /control/querylog HTTP API.
|
||||
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
||||
params, err := parseSearchParams(r)
|
||||
ctx := r.Context()
|
||||
params, err := l.parseSearchParams(ctx, r)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "parsing params: %s", err)
|
||||
|
||||
@@ -87,18 +89,18 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
|
||||
entries, oldest = l.search(params)
|
||||
entries, oldest = l.search(ctx, params)
|
||||
}()
|
||||
|
||||
resp := entriesToJSON(entries, oldest, l.anonymizer.Load())
|
||||
resp := l.entriesToJSON(ctx, entries, oldest, l.anonymizer.Load())
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
}
|
||||
|
||||
// handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP
|
||||
// API.
|
||||
func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) {
|
||||
l.clear()
|
||||
func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, r *http.Request) {
|
||||
l.clear(r.Context())
|
||||
}
|
||||
|
||||
// handleQueryLogInfo is the handler for the GET /control/querylog_info HTTP
|
||||
@@ -280,11 +282,12 @@ func getDoubleQuotesEnclosedValue(s *string) bool {
|
||||
}
|
||||
|
||||
// parseSearchCriterion parses a search criterion from the query parameter.
|
||||
func parseSearchCriterion(q url.Values, name string, ct criterionType) (
|
||||
ok bool,
|
||||
sc searchCriterion,
|
||||
err error,
|
||||
) {
|
||||
func (l *queryLog) parseSearchCriterion(
|
||||
ctx context.Context,
|
||||
q url.Values,
|
||||
name string,
|
||||
ct criterionType,
|
||||
) (ok bool, sc searchCriterion, err error) {
|
||||
val := q.Get(name)
|
||||
if val == "" {
|
||||
return false, sc, nil
|
||||
@@ -301,7 +304,7 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) (
|
||||
// TODO(e.burkov): Make it work with parts of IDNAs somehow.
|
||||
loweredVal := strings.ToLower(val)
|
||||
if asciiVal, err = idna.ToASCII(loweredVal); err != nil {
|
||||
log.Debug("can't convert %q to ascii: %s", val, err)
|
||||
l.logger.DebugContext(ctx, "converting to ascii", "value", val, slogutil.KeyError, err)
|
||||
} else if asciiVal == loweredVal {
|
||||
// Purge asciiVal to prevent checking the same value
|
||||
// twice.
|
||||
@@ -331,7 +334,10 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) (
|
||||
|
||||
// parseSearchParams parses search parameters from the HTTP request's query
|
||||
// string.
|
||||
func parseSearchParams(r *http.Request) (p *searchParams, err error) {
|
||||
func (l *queryLog) parseSearchParams(
|
||||
ctx context.Context,
|
||||
r *http.Request,
|
||||
) (p *searchParams, err error) {
|
||||
p = newSearchParams()
|
||||
|
||||
q := r.URL.Query()
|
||||
@@ -369,7 +375,7 @@ func parseSearchParams(r *http.Request) (p *searchParams, err error) {
|
||||
}} {
|
||||
var ok bool
|
||||
var c searchCriterion
|
||||
ok, c, err = parseSearchCriterion(q, v.urlField, v.ct)
|
||||
ok, c, err = l.parseSearchCriterion(ctx, q, v.urlField, v.ct)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -8,7 +9,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
@@ -19,7 +20,8 @@ import (
|
||||
type jobject = map[string]any
|
||||
|
||||
// entriesToJSON converts query log entries to JSON.
|
||||
func entriesToJSON(
|
||||
func (l *queryLog) entriesToJSON(
|
||||
ctx context.Context,
|
||||
entries []*logEntry,
|
||||
oldest time.Time,
|
||||
anonFunc aghnet.IPMutFunc,
|
||||
@@ -28,7 +30,7 @@ func entriesToJSON(
|
||||
|
||||
// The elements order is already reversed to be from newer to older.
|
||||
for _, entry := range entries {
|
||||
jsonEntry := entryToJSON(entry, anonFunc)
|
||||
jsonEntry := l.entryToJSON(ctx, entry, anonFunc)
|
||||
data = append(data, jsonEntry)
|
||||
}
|
||||
|
||||
@@ -44,7 +46,11 @@ func entriesToJSON(
|
||||
}
|
||||
|
||||
// entryToJSON converts a log entry's data into an entry for the JSON API.
|
||||
func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) {
|
||||
func (l *queryLog) entryToJSON(
|
||||
ctx context.Context,
|
||||
entry *logEntry,
|
||||
anonFunc aghnet.IPMutFunc,
|
||||
) (jsonEntry jobject) {
|
||||
hostname := entry.QHost
|
||||
question := jobject{
|
||||
"type": entry.QType,
|
||||
@@ -53,7 +59,12 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject)
|
||||
}
|
||||
|
||||
if qhost, err := idna.ToUnicode(hostname); err != nil {
|
||||
log.Debug("querylog: translating %q into unicode: %s", hostname, err)
|
||||
l.logger.DebugContext(
|
||||
ctx,
|
||||
"translating into unicode",
|
||||
"hostname", hostname,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
} else if qhost != hostname && qhost != "" {
|
||||
question["unicode_name"] = qhost
|
||||
}
|
||||
@@ -96,21 +107,26 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject)
|
||||
jsonEntry["service_name"] = entry.Result.ServiceName
|
||||
}
|
||||
|
||||
setMsgData(entry, jsonEntry)
|
||||
setOrigAns(entry, jsonEntry)
|
||||
l.setMsgData(ctx, entry, jsonEntry)
|
||||
l.setOrigAns(ctx, entry, jsonEntry)
|
||||
|
||||
return jsonEntry
|
||||
}
|
||||
|
||||
// setMsgData sets the message data in jsonEntry.
|
||||
func setMsgData(entry *logEntry, jsonEntry jobject) {
|
||||
func (l *queryLog) setMsgData(ctx context.Context, entry *logEntry, jsonEntry jobject) {
|
||||
if len(entry.Answer) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
msg := &dns.Msg{}
|
||||
if err := msg.Unpack(entry.Answer); err != nil {
|
||||
log.Debug("querylog: failed to unpack dns msg answer: %v: %s", entry.Answer, err)
|
||||
l.logger.DebugContext(
|
||||
ctx,
|
||||
"unpacking dns message",
|
||||
"answer", entry.Answer,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -126,7 +142,7 @@ func setMsgData(entry *logEntry, jsonEntry jobject) {
|
||||
}
|
||||
|
||||
// setOrigAns sets the original answer data in jsonEntry.
|
||||
func setOrigAns(entry *logEntry, jsonEntry jobject) {
|
||||
func (l *queryLog) setOrigAns(ctx context.Context, entry *logEntry, jsonEntry jobject) {
|
||||
if len(entry.OrigAnswer) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -134,7 +150,12 @@ func setOrigAns(entry *logEntry, jsonEntry jobject) {
|
||||
orig := &dns.Msg{}
|
||||
err := orig.Unpack(entry.OrigAnswer)
|
||||
if err != nil {
|
||||
log.Debug("querylog: orig.Unpack(entry.OrigAnswer): %v: %s", entry.OrigAnswer, err)
|
||||
l.logger.DebugContext(
|
||||
ctx,
|
||||
"setting original answer",
|
||||
"answer", entry.OrigAnswer,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -11,7 +13,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -22,6 +24,10 @@ const queryLogFileName = "querylog.json"
|
||||
|
||||
// queryLog is a structure that writes and reads the DNS query log.
|
||||
type queryLog struct {
|
||||
// logger is used for logging the operation of the query log. It must not
|
||||
// be nil.
|
||||
logger *slog.Logger
|
||||
|
||||
// confMu protects conf.
|
||||
confMu *sync.RWMutex
|
||||
|
||||
@@ -76,24 +82,34 @@ func NewClientProto(s string) (cp ClientProto, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *queryLog) Start() {
|
||||
// type check
|
||||
var _ QueryLog = (*queryLog)(nil)
|
||||
|
||||
// Start implements the [QueryLog] interface for *queryLog.
|
||||
func (l *queryLog) Start(ctx context.Context) (err error) {
|
||||
if l.conf.HTTPRegister != nil {
|
||||
l.initWeb()
|
||||
}
|
||||
|
||||
go l.periodicRotate()
|
||||
go l.periodicRotate(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *queryLog) Close() {
|
||||
// Shutdown implements the [QueryLog] interface for *queryLog.
|
||||
func (l *queryLog) Shutdown(ctx context.Context) (err error) {
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
|
||||
if l.conf.FileEnabled {
|
||||
err := l.flushLogBuffer()
|
||||
err = l.flushLogBuffer(ctx)
|
||||
if err != nil {
|
||||
log.Error("querylog: closing: %s", err)
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkInterval(ivl time.Duration) (ok bool) {
|
||||
@@ -123,6 +139,7 @@ func validateIvl(ivl time.Duration) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteDiskConfig implements the [QueryLog] interface for *queryLog.
|
||||
func (l *queryLog) WriteDiskConfig(c *Config) {
|
||||
l.confMu.RLock()
|
||||
defer l.confMu.RUnlock()
|
||||
@@ -131,7 +148,7 @@ func (l *queryLog) WriteDiskConfig(c *Config) {
|
||||
}
|
||||
|
||||
// Clear memory buffer and remove log files
|
||||
func (l *queryLog) clear() {
|
||||
func (l *queryLog) clear(ctx context.Context) {
|
||||
l.fileFlushLock.Lock()
|
||||
defer l.fileFlushLock.Unlock()
|
||||
|
||||
@@ -146,19 +163,24 @@ func (l *queryLog) clear() {
|
||||
oldLogFile := l.logFile + ".1"
|
||||
err := os.Remove(oldLogFile)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
log.Error("removing old log file %q: %s", oldLogFile, err)
|
||||
l.logger.ErrorContext(
|
||||
ctx,
|
||||
"removing old log file",
|
||||
"file", oldLogFile,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
}
|
||||
|
||||
err = os.Remove(l.logFile)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
log.Error("removing log file %q: %s", l.logFile, err)
|
||||
l.logger.ErrorContext(ctx, "removing log file", "file", l.logFile, slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
log.Debug("querylog: cleared")
|
||||
l.logger.DebugContext(ctx, "cleared")
|
||||
}
|
||||
|
||||
// newLogEntry creates an instance of logEntry from parameters.
|
||||
func newLogEntry(params *AddParams) (entry *logEntry) {
|
||||
func newLogEntry(ctx context.Context, logger *slog.Logger, params *AddParams) (entry *logEntry) {
|
||||
q := params.Question.Question[0]
|
||||
qHost := aghnet.NormalizeDomain(q.Name)
|
||||
|
||||
@@ -187,8 +209,8 @@ func newLogEntry(params *AddParams) (entry *logEntry) {
|
||||
entry.ReqECS = params.ReqECS.String()
|
||||
}
|
||||
|
||||
entry.addResponse(params.Answer, false)
|
||||
entry.addResponse(params.OrigAnswer, true)
|
||||
entry.addResponse(ctx, logger, params.Answer, false)
|
||||
entry.addResponse(ctx, logger, params.OrigAnswer, true)
|
||||
|
||||
return entry
|
||||
}
|
||||
@@ -209,9 +231,12 @@ func (l *queryLog) Add(params *AddParams) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(s.chzhen): Pass context.
|
||||
ctx := context.TODO()
|
||||
|
||||
err := params.validate()
|
||||
if err != nil {
|
||||
log.Error("querylog: adding record: %s, skipping", err)
|
||||
l.logger.ErrorContext(ctx, "adding record", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -220,7 +245,7 @@ func (l *queryLog) Add(params *AddParams) {
|
||||
params.Result = &filtering.Result{}
|
||||
}
|
||||
|
||||
entry := newLogEntry(params)
|
||||
entry := newLogEntry(ctx, l.logger, params)
|
||||
|
||||
l.bufferLock.Lock()
|
||||
defer l.bufferLock.Unlock()
|
||||
@@ -232,9 +257,9 @@ func (l *queryLog) Add(params *AddParams) {
|
||||
|
||||
// TODO(s.chzhen): Fix occasional rewrite of entires.
|
||||
go func() {
|
||||
flushErr := l.flushLogBuffer()
|
||||
flushErr := l.flushLogBuffer(ctx)
|
||||
if flushErr != nil {
|
||||
log.Error("querylog: flushing after adding: %s", flushErr)
|
||||
l.logger.ErrorContext(ctx, "flushing after adding", slogutil.KeyError, flushErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -247,7 +272,8 @@ func (l *queryLog) ShouldLog(host string, _, _ uint16, ids []string) bool {
|
||||
|
||||
c, err := l.findClient(ids)
|
||||
if err != nil {
|
||||
log.Error("querylog: finding client: %s", err)
|
||||
// TODO(s.chzhen): Pass context.
|
||||
l.logger.ErrorContext(context.TODO(), "finding client", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
if c != nil && c.IgnoreQueryLog {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
@@ -14,14 +15,11 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// TestQueryLog tests adding and loading (with filtering) entries from disk and
|
||||
// memory.
|
||||
func TestQueryLog(t *testing.T) {
|
||||
l, err := newQueryLog(Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Enabled: true,
|
||||
FileEnabled: true,
|
||||
RotationIvl: timeutil.Day,
|
||||
@@ -30,16 +28,21 @@ func TestQueryLog(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
// Add disk entries.
|
||||
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
// Write to disk (first file).
|
||||
require.NoError(t, l.flushLogBuffer())
|
||||
require.NoError(t, l.flushLogBuffer(ctx))
|
||||
|
||||
// Start writing to the second file.
|
||||
require.NoError(t, l.rotate())
|
||||
require.NoError(t, l.rotate(ctx))
|
||||
|
||||
// Add disk entries.
|
||||
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
|
||||
// Write to disk.
|
||||
require.NoError(t, l.flushLogBuffer())
|
||||
require.NoError(t, l.flushLogBuffer(ctx))
|
||||
|
||||
// Add memory entries.
|
||||
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
|
||||
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
|
||||
@@ -119,8 +122,9 @@ func TestQueryLog(t *testing.T) {
|
||||
params := newSearchParams()
|
||||
params.searchCriteria = tc.sCr
|
||||
|
||||
entries, _ := l.search(params)
|
||||
entries, _ := l.search(ctx, params)
|
||||
require.Len(t, entries, len(tc.want))
|
||||
|
||||
for _, want := range tc.want {
|
||||
assertLogEntry(t, entries[want.num], want.host, want.answer, want.client)
|
||||
}
|
||||
@@ -130,6 +134,7 @@ func TestQueryLog(t *testing.T) {
|
||||
|
||||
func TestQueryLogOffsetLimit(t *testing.T) {
|
||||
l, err := newQueryLog(Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Enabled: true,
|
||||
RotationIvl: timeutil.Day,
|
||||
MemSize: 100,
|
||||
@@ -142,12 +147,16 @@ func TestQueryLogOffsetLimit(t *testing.T) {
|
||||
firstPageDomain = "first.example.org"
|
||||
secondPageDomain = "second.example.org"
|
||||
)
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
// Add entries to the log.
|
||||
for range entNum {
|
||||
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
}
|
||||
// Write them to the first file.
|
||||
require.NoError(t, l.flushLogBuffer())
|
||||
require.NoError(t, l.flushLogBuffer(ctx))
|
||||
|
||||
// Add more to the in-memory part of log.
|
||||
for range entNum {
|
||||
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
@@ -191,8 +200,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
params.offset = tc.offset
|
||||
params.limit = tc.limit
|
||||
entries, _ := l.search(params)
|
||||
|
||||
entries, _ := l.search(ctx, params)
|
||||
require.Len(t, entries, tc.wantLen)
|
||||
|
||||
if tc.wantLen > 0 {
|
||||
@@ -205,6 +213,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
|
||||
|
||||
func TestQueryLogMaxFileScanEntries(t *testing.T) {
|
||||
l, err := newQueryLog(Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Enabled: true,
|
||||
FileEnabled: true,
|
||||
RotationIvl: timeutil.Day,
|
||||
@@ -213,20 +222,21 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
const entNum = 10
|
||||
// Add entries to the log.
|
||||
for range entNum {
|
||||
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
}
|
||||
// Write them to disk.
|
||||
require.NoError(t, l.flushLogBuffer())
|
||||
require.NoError(t, l.flushLogBuffer(ctx))
|
||||
|
||||
params := newSearchParams()
|
||||
|
||||
for _, maxFileScanEntries := range []int{5, 0} {
|
||||
t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) {
|
||||
params.maxFileScanEntries = maxFileScanEntries
|
||||
entries, _ := l.search(params)
|
||||
entries, _ := l.search(ctx, params)
|
||||
assert.Len(t, entries, entNum-maxFileScanEntries)
|
||||
})
|
||||
}
|
||||
@@ -234,6 +244,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
|
||||
|
||||
func TestQueryLogFileDisabled(t *testing.T) {
|
||||
l, err := newQueryLog(Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Enabled: true,
|
||||
FileEnabled: false,
|
||||
RotationIvl: timeutil.Day,
|
||||
@@ -248,8 +259,10 @@ func TestQueryLogFileDisabled(t *testing.T) {
|
||||
addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
|
||||
params := newSearchParams()
|
||||
ll, _ := l.search(params)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
ll, _ := l.search(ctx, params)
|
||||
require.Len(t, ll, 2)
|
||||
|
||||
assert.Equal(t, "example3.org", ll[0].QHost)
|
||||
assert.Equal(t, "example2.org", ll[1].QHost)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -10,7 +12,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -57,7 +59,6 @@ type qLogFile struct {
|
||||
|
||||
// newQLogFile initializes a new instance of the qLogFile.
|
||||
func newQLogFile(path string) (qf *qLogFile, err error) {
|
||||
// Don't use [aghos.OpenFile] here, because the file is expected to exist.
|
||||
f, err := os.OpenFile(path, os.O_RDONLY, aghos.DefaultPermFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -102,7 +103,11 @@ func (q *qLogFile) validateQLogLineIdx(lineIdx, lastProbeLineIdx, ts, fSize int6
|
||||
// for so that when we call "ReadNext" this line was returned.
|
||||
// - Depth of the search (how many times we compared timestamps).
|
||||
// - If we could not find it, it returns one of the errors described above.
|
||||
func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) {
|
||||
func (q *qLogFile) seekTS(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
timestamp int64,
|
||||
) (pos int64, depth int, err error) {
|
||||
q.lock.Lock()
|
||||
defer q.lock.Unlock()
|
||||
|
||||
@@ -151,7 +156,7 @@ func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) {
|
||||
lastProbeLineIdx = lineIdx
|
||||
|
||||
// Get the timestamp from the query log record.
|
||||
ts := readQLogTimestamp(line)
|
||||
ts := readQLogTimestamp(ctx, logger, line)
|
||||
if ts == 0 {
|
||||
return 0, depth, fmt.Errorf(
|
||||
"looking up timestamp %d in %q: record %q has empty timestamp",
|
||||
@@ -385,20 +390,22 @@ func readJSONValue(s, prefix string) string {
|
||||
}
|
||||
|
||||
// readQLogTimestamp reads the timestamp field from the query log line.
|
||||
func readQLogTimestamp(str string) int64 {
|
||||
func readQLogTimestamp(ctx context.Context, logger *slog.Logger, str string) int64 {
|
||||
val := readJSONValue(str, `"T":"`)
|
||||
if len(val) == 0 {
|
||||
val = readJSONValue(str, `"Time":"`)
|
||||
}
|
||||
|
||||
if len(val) == 0 {
|
||||
log.Error("Couldn't find timestamp: %s", str)
|
||||
logger.ErrorContext(ctx, "couldn't find timestamp", "line", str)
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
tm, err := time.Parse(time.RFC3339Nano, val)
|
||||
if err != nil {
|
||||
log.Error("Couldn't parse timestamp: %s", val)
|
||||
logger.ErrorContext(ctx, "couldn't parse timestamp", "value", val, slogutil.KeyError, err)
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -24,6 +25,7 @@ func prepareTestFile(t *testing.T, dir string, linesNum int) (name string) {
|
||||
|
||||
f, err := os.CreateTemp(dir, "*.txt")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use defer and not t.Cleanup to make sure that the file is closed
|
||||
// after this function is done.
|
||||
defer func() {
|
||||
@@ -108,6 +110,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
|
||||
// Calculate the expected position.
|
||||
fileInfo, err := q.file.Stat()
|
||||
require.NoError(t, err)
|
||||
|
||||
var expPos int64
|
||||
if expPos = fileInfo.Size(); expPos > 0 {
|
||||
expPos--
|
||||
@@ -129,6 +132,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
|
||||
}
|
||||
|
||||
require.Equal(t, io.EOF, err)
|
||||
|
||||
assert.Equal(t, tc.linesNum, read)
|
||||
})
|
||||
}
|
||||
@@ -146,6 +150,9 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
|
||||
num: 10,
|
||||
}}
|
||||
|
||||
logger := slogutil.NewDiscardLogger()
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
for _, l := range linesCases {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -171,16 +178,19 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
|
||||
t.Run(l.name+"_"+tc.name, func(t *testing.T) {
|
||||
line, err := getQLogFileLine(q, tc.line)
|
||||
require.NoError(t, err)
|
||||
ts := readQLogTimestamp(line)
|
||||
|
||||
ts := readQLogTimestamp(ctx, logger, line)
|
||||
assert.NotEqualValues(t, 0, ts)
|
||||
|
||||
// Try seeking to that line now.
|
||||
pos, _, err := q.seekTS(ts)
|
||||
pos, _, err := q.seekTS(ctx, logger, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqualValues(t, 0, pos)
|
||||
|
||||
testLine, err := q.ReadNext()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, line, testLine)
|
||||
})
|
||||
}
|
||||
@@ -199,6 +209,9 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
|
||||
num: 10,
|
||||
}}
|
||||
|
||||
logger := slogutil.NewDiscardLogger()
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
for _, l := range linesCases {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -221,14 +234,14 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
|
||||
|
||||
line, err := getQLogFileLine(q, l.num/2)
|
||||
require.NoError(t, err)
|
||||
testCases[2].ts = readQLogTimestamp(line) - 1
|
||||
|
||||
testCases[2].ts = readQLogTimestamp(ctx, logger, line) - 1
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.NotEqualValues(t, 0, tc.ts)
|
||||
|
||||
var depth int
|
||||
_, depth, err = q.seekTS(tc.ts)
|
||||
_, depth, err = q.seekTS(ctx, logger, tc.ts)
|
||||
assert.NotEmpty(t, l.num)
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -262,11 +275,13 @@ func TestQLogFile(t *testing.T) {
|
||||
// Seek to the start.
|
||||
pos, err := q.SeekStart()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Greater(t, pos, int64(0))
|
||||
|
||||
// Read first line.
|
||||
line, err := q.ReadNext()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, line, "0.0.0.2")
|
||||
assert.True(t, strings.HasPrefix(line, "{"), line)
|
||||
assert.True(t, strings.HasSuffix(line, "}"), line)
|
||||
@@ -274,6 +289,7 @@ func TestQLogFile(t *testing.T) {
|
||||
// Read second line.
|
||||
line, err = q.ReadNext()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.EqualValues(t, 0, q.position)
|
||||
assert.Contains(t, line, "0.0.0.1")
|
||||
assert.True(t, strings.HasPrefix(line, "{"), line)
|
||||
@@ -282,12 +298,14 @@ func TestQLogFile(t *testing.T) {
|
||||
// Try reading again (there's nothing to read anymore).
|
||||
line, err = q.ReadNext()
|
||||
require.Equal(t, io.EOF, err)
|
||||
|
||||
assert.Empty(t, line)
|
||||
}
|
||||
|
||||
func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) {
|
||||
f, err := os.CreateTemp(t.TempDir(), "*.txt")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, f.Close)
|
||||
|
||||
_, err = f.WriteString(data)
|
||||
@@ -295,6 +313,7 @@ func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) {
|
||||
|
||||
file, err = newQLogFile(f.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, file.Close)
|
||||
|
||||
return file
|
||||
@@ -308,6 +327,9 @@ func TestQLog_Seek(t *testing.T) {
|
||||
`{"T":"` + strV + `"}` + nl
|
||||
timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00")
|
||||
|
||||
logger := slogutil.NewDiscardLogger()
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
testCases := []struct {
|
||||
wantErr error
|
||||
name string
|
||||
@@ -340,8 +362,10 @@ func TestQLog_Seek(t *testing.T) {
|
||||
|
||||
q := newTestQLogFileData(t, data)
|
||||
|
||||
_, depth, err := q.seekTS(timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano())
|
||||
ts := timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano()
|
||||
_, depth, err := q.seekTS(ctx, logger, ts)
|
||||
require.Truef(t, errors.Is(err, tc.wantErr), "%v", err)
|
||||
|
||||
assert.Equal(t, tc.wantDepth, depth)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// qLogReader allows reading from multiple query log files in the reverse
|
||||
@@ -16,6 +18,10 @@ import (
|
||||
// pointer to a particular query log file, and to a specific position in this
|
||||
// file, and it reads lines in reverse order starting from that position.
|
||||
type qLogReader struct {
|
||||
// logger is used for logging the operation of the query log reader. It
|
||||
// must not be nil.
|
||||
logger *slog.Logger
|
||||
|
||||
// qFiles is an array with the query log files. The order is from oldest
|
||||
// to newest.
|
||||
qFiles []*qLogFile
|
||||
@@ -25,7 +31,7 @@ type qLogReader struct {
|
||||
}
|
||||
|
||||
// newQLogReader initializes a qLogReader instance with the specified files.
|
||||
func newQLogReader(files []string) (*qLogReader, error) {
|
||||
func newQLogReader(ctx context.Context, logger *slog.Logger, files []string) (*qLogReader, error) {
|
||||
qFiles := make([]*qLogFile, 0)
|
||||
|
||||
for _, f := range files {
|
||||
@@ -38,7 +44,7 @@ func newQLogReader(files []string) (*qLogReader, error) {
|
||||
// Close what we've already opened.
|
||||
cErr := closeQFiles(qFiles)
|
||||
if cErr != nil {
|
||||
log.Debug("querylog: closing files: %s", cErr)
|
||||
logger.DebugContext(ctx, "closing files", slogutil.KeyError, cErr)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
@@ -47,16 +53,20 @@ func newQLogReader(files []string) (*qLogReader, error) {
|
||||
qFiles = append(qFiles, q)
|
||||
}
|
||||
|
||||
return &qLogReader{qFiles: qFiles, currentFile: len(qFiles) - 1}, nil
|
||||
return &qLogReader{
|
||||
logger: logger,
|
||||
qFiles: qFiles,
|
||||
currentFile: len(qFiles) - 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// seekTS performs binary search of a query log record with the specified
|
||||
// timestamp. If the record is found, it sets qLogReader's position to point
|
||||
// to that line, so that the next ReadNext call returned this line.
|
||||
func (r *qLogReader) seekTS(timestamp int64) (err error) {
|
||||
func (r *qLogReader) seekTS(ctx context.Context, timestamp int64) (err error) {
|
||||
for i := len(r.qFiles) - 1; i >= 0; i-- {
|
||||
q := r.qFiles[i]
|
||||
_, _, err = q.seekTS(timestamp)
|
||||
_, _, err = q.seekTS(ctx, r.logger, timestamp)
|
||||
if err != nil {
|
||||
if errors.Is(err, errTSTooEarly) {
|
||||
// Look at the next file, since we've reached the end of this
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -17,8 +18,11 @@ func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *qLogReader
|
||||
|
||||
testFiles := prepareTestFiles(t, filesNum, linesNum)
|
||||
|
||||
logger := slogutil.NewDiscardLogger()
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
// Create the new qLogReader instance.
|
||||
reader, err := newQLogReader(testFiles)
|
||||
reader, err := newQLogReader(ctx, logger, testFiles)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, reader)
|
||||
@@ -73,6 +77,7 @@ func TestQLogReader(t *testing.T) {
|
||||
|
||||
func TestQLogReader_Seek(t *testing.T) {
|
||||
r := newTestQLogReader(t, 2, 10000)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
testCases := []struct {
|
||||
want error
|
||||
@@ -113,7 +118,7 @@ func TestQLogReader_Seek(t *testing.T) {
|
||||
ts, err := time.Parse(time.RFC3339Nano, tc.time)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = r.seekTS(ts.UnixNano())
|
||||
err = r.seekTS(ctx, ts.UnixNano())
|
||||
assert.ErrorIs(t, err, tc.want)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package querylog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
@@ -12,20 +13,19 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/service"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// QueryLog - main interface
|
||||
// QueryLog is the query log interface for use by other packages.
|
||||
type QueryLog interface {
|
||||
Start()
|
||||
// Interface starts and stops the query log.
|
||||
service.Interface
|
||||
|
||||
// Close query log object
|
||||
Close()
|
||||
|
||||
// Add a log entry
|
||||
// Add adds a log entry.
|
||||
Add(params *AddParams)
|
||||
|
||||
// WriteDiskConfig - write configuration
|
||||
// WriteDiskConfig writes the query log configuration to c.
|
||||
WriteDiskConfig(c *Config)
|
||||
|
||||
// ShouldLog returns true if request for the host should be logged.
|
||||
@@ -36,6 +36,10 @@ type QueryLog interface {
|
||||
//
|
||||
// Do not alter any fields of this structure after using it.
|
||||
type Config struct {
|
||||
// Logger is used for logging the operation of the query log. It must not
|
||||
// be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// Ignored contains the list of host names, which should not be written to
|
||||
// log, and matches them.
|
||||
Ignored *aghnet.IgnoreEngine
|
||||
@@ -151,6 +155,7 @@ func newQueryLog(conf Config) (l *queryLog, err error) {
|
||||
}
|
||||
|
||||
l = &queryLog{
|
||||
logger: conf.Logger,
|
||||
findClient: findClient,
|
||||
|
||||
buffer: container.NewRingBuffer[*logEntry](memSize),
|
||||
|
||||
@@ -2,6 +2,7 @@ package querylog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -9,28 +10,30 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/c2h5oh/datasize"
|
||||
)
|
||||
|
||||
// flushLogBuffer flushes the current buffer to file and resets the current
|
||||
// buffer.
|
||||
func (l *queryLog) flushLogBuffer() (err error) {
|
||||
func (l *queryLog) flushLogBuffer(ctx context.Context) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "flushing log buffer: %w") }()
|
||||
|
||||
l.fileFlushLock.Lock()
|
||||
defer l.fileFlushLock.Unlock()
|
||||
|
||||
b, err := l.encodeEntries()
|
||||
b, err := l.encodeEntries(ctx)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
return l.flushToFile(b)
|
||||
return l.flushToFile(ctx, b)
|
||||
}
|
||||
|
||||
// encodeEntries returns JSON encoded log entries, logs estimated time, clears
|
||||
// the log buffer.
|
||||
func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
|
||||
func (l *queryLog) encodeEntries(ctx context.Context) (b *bytes.Buffer, err error) {
|
||||
l.bufferLock.Lock()
|
||||
defer l.bufferLock.Unlock()
|
||||
|
||||
@@ -55,8 +58,17 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
size := b.Len()
|
||||
elapsed := time.Since(start)
|
||||
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", bufLen, elapsed, b.Len()/1024, float64(b.Len())/float64(bufLen), elapsed/time.Duration(bufLen))
|
||||
l.logger.DebugContext(
|
||||
ctx,
|
||||
"serialized elements via json",
|
||||
"count", bufLen,
|
||||
"elapsed", elapsed,
|
||||
"size", datasize.ByteSize(size),
|
||||
"size_per_entry", datasize.ByteSize(float64(size)/float64(bufLen)),
|
||||
"time_per_entry", elapsed/time.Duration(bufLen),
|
||||
)
|
||||
|
||||
l.buffer.Clear()
|
||||
l.flushPending = false
|
||||
@@ -65,13 +77,13 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
|
||||
}
|
||||
|
||||
// flushToFile saves the encoded log entries to the query log file.
|
||||
func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) {
|
||||
func (l *queryLog) flushToFile(ctx context.Context, b *bytes.Buffer) (err error) {
|
||||
l.fileWriteLock.Lock()
|
||||
defer l.fileWriteLock.Unlock()
|
||||
|
||||
filename := l.logFile
|
||||
|
||||
f, err := aghos.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, aghos.DefaultPermFile)
|
||||
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, aghos.DefaultPermFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating file %q: %w", filename, err)
|
||||
}
|
||||
@@ -83,19 +95,19 @@ func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) {
|
||||
return fmt.Errorf("writing to file %q: %w", filename, err)
|
||||
}
|
||||
|
||||
log.Debug("querylog: ok %q: %v bytes written", filename, n)
|
||||
l.logger.DebugContext(ctx, "flushed to file", "file", filename, "size", datasize.ByteSize(n))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *queryLog) rotate() error {
|
||||
func (l *queryLog) rotate(ctx context.Context) error {
|
||||
from := l.logFile
|
||||
to := l.logFile + ".1"
|
||||
|
||||
err := os.Rename(from, to)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
log.Debug("querylog: no log to rotate")
|
||||
l.logger.DebugContext(ctx, "no log to rotate")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -103,12 +115,12 @@ func (l *queryLog) rotate() error {
|
||||
return fmt.Errorf("failed to rename old file: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("querylog: renamed %s into %s", from, to)
|
||||
l.logger.DebugContext(ctx, "renamed log file", "from", from, "to", to)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) {
|
||||
func (l *queryLog) readFileFirstTimeValue(ctx context.Context) (first time.Time, err error) {
|
||||
var f *os.File
|
||||
f, err = os.Open(l.logFile)
|
||||
if err != nil {
|
||||
@@ -130,15 +142,15 @@ func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
log.Debug("querylog: the oldest log entry: %s", val)
|
||||
l.logger.DebugContext(ctx, "oldest log entry", "entry_time", val)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (l *queryLog) periodicRotate() {
|
||||
defer log.OnPanic("querylog: rotating")
|
||||
func (l *queryLog) periodicRotate(ctx context.Context) {
|
||||
defer slogutil.RecoverAndLog(ctx, l.logger)
|
||||
|
||||
l.checkAndRotate()
|
||||
l.checkAndRotate(ctx)
|
||||
|
||||
// rotationCheckIvl is the period of time between checking the need for
|
||||
// rotating log files. It's smaller of any available rotation interval to
|
||||
@@ -151,13 +163,13 @@ func (l *queryLog) periodicRotate() {
|
||||
defer rotations.Stop()
|
||||
|
||||
for range rotations.C {
|
||||
l.checkAndRotate()
|
||||
l.checkAndRotate(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// checkAndRotate rotates log files if those are older than the specified
|
||||
// rotation interval.
|
||||
func (l *queryLog) checkAndRotate() {
|
||||
func (l *queryLog) checkAndRotate(ctx context.Context) {
|
||||
var rotationIvl time.Duration
|
||||
func() {
|
||||
l.confMu.RLock()
|
||||
@@ -166,29 +178,30 @@ func (l *queryLog) checkAndRotate() {
|
||||
rotationIvl = l.conf.RotationIvl
|
||||
}()
|
||||
|
||||
oldest, err := l.readFileFirstTimeValue()
|
||||
oldest, err := l.readFileFirstTimeValue(ctx)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
log.Error("querylog: reading oldest record for rotation: %s", err)
|
||||
l.logger.ErrorContext(ctx, "reading oldest record for rotation", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) {
|
||||
log.Debug(
|
||||
"querylog: %s <= %s, not rotating",
|
||||
now.Format(time.RFC3339),
|
||||
rotTime.Format(time.RFC3339),
|
||||
l.logger.DebugContext(
|
||||
ctx,
|
||||
"not rotating",
|
||||
"now", now.Format(time.RFC3339),
|
||||
"rotate_time", rotTime.Format(time.RFC3339),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = l.rotate()
|
||||
err = l.rotate(ctx)
|
||||
if err != nil {
|
||||
log.Error("querylog: rotating: %s", err)
|
||||
l.logger.ErrorContext(ctx, "rotating", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("querylog: rotated successfully")
|
||||
l.logger.DebugContext(ctx, "rotated successfully")
|
||||
}
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// client finds the client info, if any, by its ClientID and IP address,
|
||||
@@ -48,7 +50,11 @@ func (l *queryLog) client(clientID, ip string, cache clientCache) (c *Client, er
|
||||
// buffer. It optionally uses the client cache, if provided. It also returns
|
||||
// the total amount of records in the buffer at the moment of searching.
|
||||
// l.confMu is expected to be locked.
|
||||
func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entries []*logEntry, total int) {
|
||||
func (l *queryLog) searchMemory(
|
||||
ctx context.Context,
|
||||
params *searchParams,
|
||||
cache clientCache,
|
||||
) (entries []*logEntry, total int) {
|
||||
// Check memory size, as the buffer can contain a single log record. See
|
||||
// [newQueryLog].
|
||||
if l.conf.MemSize == 0 {
|
||||
@@ -66,9 +72,14 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
|
||||
var err error
|
||||
e.client, err = l.client(e.ClientID, e.IP.String(), cache)
|
||||
if err != nil {
|
||||
msg := "querylog: enriching memory record at time %s" +
|
||||
" for client %q (clientid %q): %s"
|
||||
log.Error(msg, e.Time, e.IP, e.ClientID, err)
|
||||
l.logger.ErrorContext(
|
||||
ctx,
|
||||
"enriching memory record",
|
||||
"at", e.Time,
|
||||
"client_ip", e.IP,
|
||||
"client_id", e.ClientID,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
// Go on and try to match anyway.
|
||||
}
|
||||
@@ -86,7 +97,10 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
|
||||
// search searches log entries in memory buffer and log file using specified
|
||||
// parameters and returns the list of entries found and the time of the oldest
|
||||
// entry. l.confMu is expected to be locked.
|
||||
func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest time.Time) {
|
||||
func (l *queryLog) search(
|
||||
ctx context.Context,
|
||||
params *searchParams,
|
||||
) (entries []*logEntry, oldest time.Time) {
|
||||
start := time.Now()
|
||||
|
||||
if params.limit == 0 {
|
||||
@@ -95,11 +109,11 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
|
||||
|
||||
cache := clientCache{}
|
||||
|
||||
memoryEntries, bufLen := l.searchMemory(params, cache)
|
||||
log.Debug("querylog: got %d entries from memory", len(memoryEntries))
|
||||
memoryEntries, bufLen := l.searchMemory(ctx, params, cache)
|
||||
l.logger.DebugContext(ctx, "got entries from memory", "count", len(memoryEntries))
|
||||
|
||||
fileEntries, oldest, total := l.searchFiles(params, cache)
|
||||
log.Debug("querylog: got %d entries from files", len(fileEntries))
|
||||
fileEntries, oldest, total := l.searchFiles(ctx, params, cache)
|
||||
l.logger.DebugContext(ctx, "got entries from files", "count", len(fileEntries))
|
||||
|
||||
total += bufLen
|
||||
|
||||
@@ -134,12 +148,13 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
|
||||
oldest = entries[len(entries)-1].Time
|
||||
}
|
||||
|
||||
log.Debug(
|
||||
"querylog: prepared data (%d/%d) older than %s in %s",
|
||||
len(entries),
|
||||
total,
|
||||
params.olderThan,
|
||||
time.Since(start),
|
||||
l.logger.DebugContext(
|
||||
ctx,
|
||||
"prepared data",
|
||||
"count", len(entries),
|
||||
"total", total,
|
||||
"older_than", params.olderThan,
|
||||
"elapsed", time.Since(start),
|
||||
)
|
||||
|
||||
return entries, oldest
|
||||
@@ -147,12 +162,12 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
|
||||
|
||||
// seekRecord changes the current position to the next record older than the
|
||||
// provided parameter.
|
||||
func (r *qLogReader) seekRecord(olderThan time.Time) (err error) {
|
||||
func (r *qLogReader) seekRecord(ctx context.Context, olderThan time.Time) (err error) {
|
||||
if olderThan.IsZero() {
|
||||
return r.SeekStart()
|
||||
}
|
||||
|
||||
err = r.seekTS(olderThan.UnixNano())
|
||||
err = r.seekTS(ctx, olderThan.UnixNano())
|
||||
if err == nil {
|
||||
// Read to the next record, because we only need the one that goes
|
||||
// after it.
|
||||
@@ -164,21 +179,24 @@ func (r *qLogReader) seekRecord(olderThan time.Time) (err error) {
|
||||
|
||||
// setQLogReader creates a reader with the specified files and sets the
|
||||
// position to the next record older than the provided parameter.
|
||||
func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error) {
|
||||
func (l *queryLog) setQLogReader(
|
||||
ctx context.Context,
|
||||
olderThan time.Time,
|
||||
) (qr *qLogReader, err error) {
|
||||
files := []string{
|
||||
l.logFile + ".1",
|
||||
l.logFile,
|
||||
}
|
||||
|
||||
r, err := newQLogReader(files)
|
||||
r, err := newQLogReader(ctx, l.logger, files)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening qlog reader: %s", err)
|
||||
return nil, fmt.Errorf("opening qlog reader: %w", err)
|
||||
}
|
||||
|
||||
err = r.seekRecord(olderThan)
|
||||
err = r.seekRecord(ctx, olderThan)
|
||||
if err != nil {
|
||||
defer func() { err = errors.WithDeferred(err, r.Close()) }()
|
||||
log.Debug("querylog: cannot seek to %s: %s", olderThan, err)
|
||||
l.logger.DebugContext(ctx, "cannot seek", "older_than", olderThan, slogutil.KeyError, err)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@@ -191,13 +209,14 @@ func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error
|
||||
// calls faster so that the UI could handle it and show something quicker.
|
||||
// This behavior can be overridden if maxFileScanEntries is set to 0.
|
||||
func (l *queryLog) readEntries(
|
||||
ctx context.Context,
|
||||
r *qLogReader,
|
||||
params *searchParams,
|
||||
cache clientCache,
|
||||
totalLimit int,
|
||||
) (entries []*logEntry, oldestNano int64, total int) {
|
||||
for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 {
|
||||
ent, ts, rErr := l.readNextEntry(r, params, cache)
|
||||
ent, ts, rErr := l.readNextEntry(ctx, r, params, cache)
|
||||
if rErr != nil {
|
||||
if rErr == io.EOF {
|
||||
oldestNano = 0
|
||||
@@ -205,7 +224,7 @@ func (l *queryLog) readEntries(
|
||||
break
|
||||
}
|
||||
|
||||
log.Error("querylog: reading next entry: %s", rErr)
|
||||
l.logger.ErrorContext(ctx, "reading next entry", slogutil.KeyError, rErr)
|
||||
}
|
||||
|
||||
oldestNano = ts
|
||||
@@ -231,12 +250,13 @@ func (l *queryLog) readEntries(
|
||||
// and the total number of processed entries, including discarded ones,
|
||||
// correspondingly.
|
||||
func (l *queryLog) searchFiles(
|
||||
ctx context.Context,
|
||||
params *searchParams,
|
||||
cache clientCache,
|
||||
) (entries []*logEntry, oldest time.Time, total int) {
|
||||
r, err := l.setQLogReader(params.olderThan)
|
||||
r, err := l.setQLogReader(ctx, params.olderThan)
|
||||
if err != nil {
|
||||
log.Error("querylog: %s", err)
|
||||
l.logger.ErrorContext(ctx, "searching files", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
if r == nil {
|
||||
@@ -245,12 +265,12 @@ func (l *queryLog) searchFiles(
|
||||
|
||||
defer func() {
|
||||
if closeErr := r.Close(); closeErr != nil {
|
||||
log.Error("querylog: closing file: %s", closeErr)
|
||||
l.logger.ErrorContext(ctx, "closing files", slogutil.KeyError, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
totalLimit := params.offset + params.limit
|
||||
entries, oldestNano, total := l.readEntries(r, params, cache, totalLimit)
|
||||
entries, oldestNano, total := l.readEntries(ctx, r, params, cache, totalLimit)
|
||||
if oldestNano != 0 {
|
||||
oldest = time.Unix(0, oldestNano)
|
||||
}
|
||||
@@ -266,15 +286,21 @@ type quickMatchClientFinder struct {
|
||||
}
|
||||
|
||||
// findClient is a method that can be used as a quickMatchClientFinder.
|
||||
func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) {
|
||||
func (f quickMatchClientFinder) findClient(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
clientID string,
|
||||
ip string,
|
||||
) (c *Client) {
|
||||
var err error
|
||||
c, err = f.client(clientID, ip, f.cache)
|
||||
if err != nil {
|
||||
log.Error(
|
||||
"querylog: enriching file record for quick search: for client %q (clientid %q): %s",
|
||||
ip,
|
||||
clientID,
|
||||
err,
|
||||
logger.ErrorContext(
|
||||
ctx,
|
||||
"enriching file record for quick search",
|
||||
"client_ip", ip,
|
||||
"client_id", clientID,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -286,6 +312,7 @@ func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) {
|
||||
// the entry doesn't match the search criteria. ts is the timestamp of the
|
||||
// processed entry.
|
||||
func (l *queryLog) readNextEntry(
|
||||
ctx context.Context,
|
||||
r *qLogReader,
|
||||
params *searchParams,
|
||||
cache clientCache,
|
||||
@@ -301,14 +328,14 @@ func (l *queryLog) readNextEntry(
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
if !params.quickMatch(line, clientFinder.findClient) {
|
||||
ts = readQLogTimestamp(line)
|
||||
if !params.quickMatch(ctx, l.logger, line, clientFinder.findClient) {
|
||||
ts = readQLogTimestamp(ctx, l.logger, line)
|
||||
|
||||
return nil, ts, nil
|
||||
}
|
||||
|
||||
e = &logEntry{}
|
||||
decodeLogEntry(e, line)
|
||||
l.decodeLogEntry(ctx, e, line)
|
||||
|
||||
if l.isIgnored(e.QHost) {
|
||||
return nil, ts, nil
|
||||
@@ -316,12 +343,13 @@ func (l *queryLog) readNextEntry(
|
||||
|
||||
e.client, err = l.client(e.ClientID, e.IP.String(), cache)
|
||||
if err != nil {
|
||||
log.Error(
|
||||
"querylog: enriching file record at time %s for client %q (clientid %q): %s",
|
||||
e.Time,
|
||||
e.IP,
|
||||
e.ClientID,
|
||||
err,
|
||||
l.logger.ErrorContext(
|
||||
ctx,
|
||||
"enriching file record",
|
||||
"at", e.Time,
|
||||
"client_ip", e.IP,
|
||||
"client_id", e.ClientID,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
// Go on and try to match anyway.
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -36,6 +38,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
|
||||
}
|
||||
|
||||
l, err := newQueryLog(Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
FindClient: findClient,
|
||||
BaseDir: t.TempDir(),
|
||||
RotationIvl: timeutil.Day,
|
||||
@@ -45,7 +48,11 @@ func TestQueryLog_Search_findClient(t *testing.T) {
|
||||
AnonymizeClientIP: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(l.Close)
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return l.Shutdown(ctx)
|
||||
})
|
||||
|
||||
q := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
@@ -81,7 +88,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
|
||||
olderThan: time.Now().Add(10 * time.Second),
|
||||
limit: 3,
|
||||
}
|
||||
entries, _ := l.search(sp)
|
||||
entries, _ := l.search(ctx, sp)
|
||||
assert.Equal(t, 2, findClientCalls)
|
||||
|
||||
require.Len(t, entries, 3)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -87,7 +89,12 @@ func ctDomainOrClientCaseNonStrict(
|
||||
// quickMatch quickly checks if the line matches the given search criterion.
|
||||
// It returns false if the like doesn't match. This method is only here for
|
||||
// optimization purposes.
|
||||
func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) {
|
||||
func (c *searchCriterion) quickMatch(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
line string,
|
||||
findClient quickMatchClientFunc,
|
||||
) (ok bool) {
|
||||
switch c.criterionType {
|
||||
case ctTerm:
|
||||
host := readJSONValue(line, `"QH":"`)
|
||||
@@ -95,7 +102,7 @@ func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFun
|
||||
clientID := readJSONValue(line, `"CID":"`)
|
||||
|
||||
var name string
|
||||
if cli := findClient(clientID, ip); cli != nil {
|
||||
if cli := findClient(ctx, logger, clientID, ip); cli != nil {
|
||||
name = cli.Name
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package querylog
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
// searchParams represent the search query sent by the client.
|
||||
type searchParams struct {
|
||||
@@ -35,14 +39,23 @@ func newSearchParams() *searchParams {
|
||||
}
|
||||
|
||||
// quickMatchClientFunc is a simplified client finder for quick matches.
|
||||
type quickMatchClientFunc = func(clientID, ip string) (c *Client)
|
||||
type quickMatchClientFunc = func(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
clientID, ip string,
|
||||
) (c *Client)
|
||||
|
||||
// quickMatch quickly checks if the line matches the given search parameters.
|
||||
// It returns false if the line doesn't match. This method is only here for
|
||||
// optimization purposes.
|
||||
func (s *searchParams) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) {
|
||||
func (s *searchParams) quickMatch(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
line string,
|
||||
findClient quickMatchClientFunc,
|
||||
) (ok bool) {
|
||||
for _, c := range s.searchCriteria {
|
||||
if !c.quickMatch(line, findClient) {
|
||||
if !c.quickMatch(ctx, logger, line, findClient) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,12 +385,7 @@ func (s *StatsCtx) openDB() (err error) {
|
||||
|
||||
var db *bbolt.DB
|
||||
|
||||
opts := *bbolt.DefaultOptions
|
||||
// Use the custom OpenFile function to properly handle access rights on
|
||||
// Windows.
|
||||
opts.OpenFile = aghos.OpenFile
|
||||
|
||||
db, err = bbolt.Open(s.filename, aghos.DefaultPermFile, &opts)
|
||||
db, err = bbolt.Open(s.filename, aghos.DefaultPermFile, nil)
|
||||
if err != nil {
|
||||
if err.Error() == "invalid argument" {
|
||||
const lines = `AdGuard Home cannot be initialized due to an incompatible file system.
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
@@ -12,7 +13,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"go.etcd.io/bbolt"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -234,18 +234,15 @@ func (a countPair) compareCount(b countPair) (res int) {
|
||||
}
|
||||
}
|
||||
|
||||
func convertMapToSlice(m map[string]uint64, max int) (s []countPair) {
|
||||
func convertMapToSlice(m map[string]uint64, maxVal int) (s []countPair) {
|
||||
s = make([]countPair, 0, len(m))
|
||||
for k, v := range m {
|
||||
s = append(s, countPair{Name: k, Count: v})
|
||||
}
|
||||
|
||||
slices.SortFunc(s, countPair.compareCount)
|
||||
if max > len(s) {
|
||||
max = len(s)
|
||||
}
|
||||
|
||||
return s[:max]
|
||||
return s[:min(maxVal, len(s))]
|
||||
}
|
||||
|
||||
func convertSliceToMap(a []countPair) (m map[string]uint64) {
|
||||
@@ -611,9 +608,7 @@ func microsecondsToSeconds(n float64) (r float64) {
|
||||
func prepareTopUpstreamsAvgTime(
|
||||
upstreamsAvgTime topAddrsFloat,
|
||||
) (topUpstreamsAvgTime []topAddrsFloat) {
|
||||
keys := maps.Keys(upstreamsAvgTime)
|
||||
|
||||
slices.SortFunc(keys, func(a, b string) (res int) {
|
||||
keys := slices.SortedStableFunc(maps.Keys(upstreamsAvgTime), func(a, b string) (res int) {
|
||||
switch x, y := upstreamsAvgTime[a], upstreamsAvgTime[b]; {
|
||||
case x > y:
|
||||
return -1
|
||||
|
||||
@@ -1,65 +1,69 @@
|
||||
module github.com/AdguardTeam/AdGuardHome/internal/tools
|
||||
|
||||
go 1.23.2
|
||||
go 1.23.4
|
||||
|
||||
require (
|
||||
github.com/fzipp/gocyclo v0.6.0
|
||||
github.com/golangci/misspell v0.6.0
|
||||
github.com/gordonklaus/ineffassign v0.1.0
|
||||
github.com/jstemmer/go-junit-report/v2 v2.1.0
|
||||
github.com/kisielk/errcheck v1.8.0
|
||||
github.com/kyoh86/looppointer v0.2.1
|
||||
github.com/securego/gosec/v2 v2.21.4
|
||||
github.com/uudashr/gocognit v1.1.3
|
||||
golang.org/x/tools v0.26.0
|
||||
golang.org/x/tools v0.27.0
|
||||
golang.org/x/vuln v1.1.3
|
||||
honnef.co/go/tools v0.5.1
|
||||
mvdan.cc/gofumpt v0.7.0
|
||||
mvdan.cc/sh/v3 v3.10.0
|
||||
mvdan.cc/unparam v0.0.0-20240917084806-57a3b4290ba3
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.116.0 // indirect
|
||||
cloud.google.com/go/ai v0.8.2 // indirect
|
||||
cloud.google.com/go/auth v0.9.9 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
|
||||
cloud.google.com/go/ai v0.9.0 // indirect
|
||||
cloud.google.com/go/auth v0.11.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.5.2 // indirect
|
||||
cloud.google.com/go/longrunning v0.6.2 // indirect
|
||||
cloud.google.com/go/longrunning v0.6.3 // indirect
|
||||
github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c // indirect
|
||||
github.com/ccojocar/zxcvbn-go v1.0.2 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
||||
github.com/google/generative-ai-go v0.18.0 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/renameio/v2 v2.0.0 // indirect
|
||||
github.com/google/s2a-go v0.1.8 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.13.0 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.14.0 // indirect
|
||||
github.com/gookit/color v1.5.4 // indirect
|
||||
github.com/kyoh86/nolint v0.0.1 // indirect
|
||||
github.com/rogpeppe/go-internal v1.13.1 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
|
||||
go.opentelemetry.io/otel v1.31.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.31.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.31.0 // indirect
|
||||
golang.org/x/crypto v0.28.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20241009180824-f66d83c29e7c // indirect
|
||||
golang.org/x/mod v0.21.0 // indirect
|
||||
golang.org/x/net v0.30.0 // indirect
|
||||
golang.org/x/oauth2 v0.23.0 // indirect
|
||||
golang.org/x/sync v0.8.0 // indirect
|
||||
golang.org/x/sys v0.26.0 // indirect
|
||||
golang.org/x/telemetry v0.0.0-20241028140143-9c0d19e65ba0 // indirect
|
||||
golang.org/x/text v0.19.0 // indirect
|
||||
golang.org/x/time v0.7.0 // indirect
|
||||
google.golang.org/api v0.203.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241021214115-324edc3d5d38 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect
|
||||
google.golang.org/grpc v1.67.1 // indirect
|
||||
google.golang.org/protobuf v1.35.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.57.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.57.0 // indirect
|
||||
go.opentelemetry.io/otel v1.32.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.32.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.32.0 // indirect
|
||||
golang.org/x/crypto v0.29.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20241108190413-2d47ceb2692f // indirect
|
||||
golang.org/x/mod v0.22.0 // indirect
|
||||
golang.org/x/net v0.31.0 // indirect
|
||||
golang.org/x/oauth2 v0.24.0 // indirect
|
||||
golang.org/x/sync v0.9.0 // indirect
|
||||
golang.org/x/sys v0.27.0 // indirect
|
||||
golang.org/x/telemetry v0.0.0-20241108154256-525ce2e96f55 // indirect
|
||||
golang.org/x/term v0.26.0 // indirect
|
||||
golang.org/x/text v0.20.0 // indirect
|
||||
golang.org/x/time v0.8.0 // indirect
|
||||
google.golang.org/api v0.209.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241202173237-19429a94021a // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a // indirect
|
||||
google.golang.org/grpc v1.68.0 // indirect
|
||||
google.golang.org/protobuf v1.35.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
mvdan.cc/editorconfig v0.3.0 // indirect
|
||||
)
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
|
||||
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
|
||||
cloud.google.com/go/ai v0.8.2 h1:LEaQwqBv+k2ybrcdTtCTc9OPZXoEdcQaGrfvDYS6Bnk=
|
||||
cloud.google.com/go/ai v0.8.2/go.mod h1:Wb3EUUGWwB6yHBaUf/+oxUq/6XbCaU1yh0GrwUS8lr4=
|
||||
cloud.google.com/go/auth v0.9.9 h1:BmtbpNQozo8ZwW2t7QJjnrQtdganSdmqeIBxHxNkEZQ=
|
||||
cloud.google.com/go/auth v0.9.9/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.4 h1:0GWE/FUsXhf6C+jAkWgYm7X9tK8cuEIfy19DBn6B6bY=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.4/go.mod h1:jC/jOpwFP6JBxhB3P5Rr0a9HLMC/Pe3eaL4NmdvqPtc=
|
||||
cloud.google.com/go/ai v0.9.0 h1:r1Ig8O8+Qr3Ia3WfoO+gokD0fxB2Rk4quppuKjmGMsY=
|
||||
cloud.google.com/go/ai v0.9.0/go.mod h1:28bKM/oxmRgxmRgI1GLumFv+NSkt+DscAg/gF+54zzY=
|
||||
cloud.google.com/go/auth v0.11.0 h1:Ic5SZz2lsvbYcWT5dfjNWgw6tTlGi2Wc8hyQSC9BstA=
|
||||
cloud.google.com/go/auth v0.11.0/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8=
|
||||
cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo=
|
||||
cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k=
|
||||
cloud.google.com/go/longrunning v0.6.2 h1:xjDfh1pQcWPEvnfjZmwjKQEcHnpz6lHjfy7Fo0MK+hc=
|
||||
cloud.google.com/go/longrunning v0.6.2/go.mod h1:k/vIs83RN4bE3YCswdXC5PFfWVILjm3hpEUlSko4PiI=
|
||||
cloud.google.com/go/longrunning v0.6.3 h1:A2q2vuyXysRcwzqDpMMLSI6mb6o39miS52UEG/Rd2ng=
|
||||
cloud.google.com/go/longrunning v0.6.3/go.mod h1:k/vIs83RN4bE3YCswdXC5PFfWVILjm3hpEUlSko4PiI=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c h1:pxW6RcqyfI9/kWtOwnv/G+AzdKuy2ZrqINhenH4HyNs=
|
||||
github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
@@ -41,8 +41,8 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
|
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
@@ -67,12 +67,15 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5 h1:5iH8iuqE5apketRbSFBy+X1V0o+l+8NF1avt4HWl7cA=
|
||||
github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
|
||||
github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
|
||||
github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4=
|
||||
github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
|
||||
github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -80,22 +83,20 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
|
||||
github.com/googleapis/gax-go/v2 v2.13.0 h1:yitjD5f7jQHhyDsnhKEBU52NdvvdSeGzlAnDPT0hH1s=
|
||||
github.com/googleapis/gax-go/v2 v2.13.0/go.mod h1:Z/fvTZXF8/uw7Xu5GuslPw+bplx6SS338j1Is2S+B7A=
|
||||
github.com/googleapis/gax-go/v2 v2.14.0 h1:f+jMrjBPl+DL9nI4IQzLUxMq7XrAqFYB7hBPqMNIe8o=
|
||||
github.com/googleapis/gax-go/v2 v2.14.0/go.mod h1:lhBCnjdLrWRaPvLWhmc8IS24m9mr07qSYnHncrgo+zk=
|
||||
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
|
||||
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
|
||||
github.com/gordonklaus/ineffassign v0.1.0 h1:y2Gd/9I7MdY1oEIt+n+rowjBNDcLQq3RsH5hwJd0f9s=
|
||||
github.com/gordonklaus/ineffassign v0.1.0/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0=
|
||||
github.com/jstemmer/go-junit-report/v2 v2.1.0 h1:X3+hPYlSczH9IMIpSC9CQSZA0L+BipYafciZUWHEmsc=
|
||||
github.com/jstemmer/go-junit-report/v2 v2.1.0/go.mod h1:mgHVr7VUo5Tn8OLVr1cKnLuEy0M92wdRntM99h7RkgQ=
|
||||
github.com/kisielk/errcheck v1.8.0 h1:ZX/URYa7ilESY19ik/vBmCn6zdGQLxACwjAcWbHlYlg=
|
||||
github.com/kisielk/errcheck v1.8.0/go.mod h1:1kLL+jV4e+CFfueBmI1dSK2ADDyQnlrnrY/FqKluHJQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kyoh86/looppointer v0.2.1 h1:Jx9fnkBj/JrIryBLMTYNTj9rvc2SrPS98Dg0w7fxdJg=
|
||||
github.com/kyoh86/looppointer v0.2.1/go.mod h1:q358WcM8cMWU+5vzqukvaZtnJi1kw/MpRHQm3xvTrjw=
|
||||
github.com/kyoh86/nolint v0.0.1 h1:GjNxDEkVn2wAxKHtP7iNTrRxytRZ1wXxLV5j4XzGfRU=
|
||||
github.com/kyoh86/nolint v0.0.1/go.mod h1:1ZiZZ7qqrZ9dZegU96phwVcdQOMKIqRzFJL3ewq9gtI=
|
||||
github.com/onsi/ginkgo/v2 v2.20.2 h1:7NVCeyIWROIAheY21RLS+3j2bb52W0W82tkberYytp4=
|
||||
github.com/onsi/ginkgo/v2 v2.20.2/go.mod h1:K9gyxPIlb+aIvnZ8bd9Ak+YP18w3APlR+5coaZoE2ag=
|
||||
github.com/onsi/gomega v1.34.2 h1:pNCwDkzrsv7MS9kpaQvVb1aVLahQXyJ/Tv5oAZMI3i8=
|
||||
@@ -103,8 +104,8 @@ github.com/onsi/gomega v1.34.2/go.mod h1:v1xfxRgk0KIsG+QOdm7p8UosrOzPYRo60fd3B/1
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/securego/gosec/v2 v2.21.4 h1:Le8MSj0PDmOnHJgUATjD96PaXRvCpKC+DGJvwyy0Mlk=
|
||||
github.com/securego/gosec/v2 v2.21.4/go.mod h1:Jtb/MwRQfRxCXyCm1rfM1BEiiiTfUOdyzzAhlr6lUTA=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@@ -121,111 +122,107 @@ github.com/uudashr/gocognit v1.1.3 h1:l+a111VcDbKfynh+airAy/DJQKaXh2m9vkoysMPSZy
|
||||
github.com/uudashr/gocognit v1.1.3/go.mod h1:aKH8/e8xbTRBwjbCkwZ8qt4l2EpKXl31KMHgSS+lZ2U=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
||||
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0 h1:yMkBS9yViCc7U7yeLzJPM2XizlfdVvBRSmsQDWu6qc0=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0/go.mod h1:n8MR6/liuGB5EmTETUBeU5ZgqMOlqKRxUaqPQBOANZ8=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 h1:UP6IpuHFkUgOQL9FFQFrZ+5LiwhhYRbi7VZSIx6Nj5s=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0/go.mod h1:qxuZLtbq5QDtdeSHsS7bcf6EH6uO6jUAgk764zd3rhM=
|
||||
go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY=
|
||||
go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE=
|
||||
go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE=
|
||||
go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
|
||||
go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
|
||||
go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.57.0 h1:qtFISDHKolvIxzSs0gIaiPUPR0Cucb0F2coHC7ZLdps=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.57.0/go.mod h1:Y+Pop1Q6hCOnETWTW4NROK/q1hv50hM7yDaUTjG8lp8=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.57.0 h1:DheMAlT6POBP+gh8RUH19EOTnQIor5QE0uSRPtzCpSw=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.57.0/go.mod h1:wZcGmeVO9nzP67aYSLDqXNWK87EZWhi7JWj1v7ZXf94=
|
||||
go.opentelemetry.io/otel v1.32.0 h1:WnBN+Xjcteh0zdk01SVqV55d/m62NJLJdIyb4y/WO5U=
|
||||
go.opentelemetry.io/otel v1.32.0/go.mod h1:00DCVSB0RQcnzlwyTfqtxSm+DRr9hpYrHjNGiBHVQIg=
|
||||
go.opentelemetry.io/otel/metric v1.32.0 h1:xV2umtmNcThh2/a/aCP+h64Xx5wsj8qqnkYZktzNa0M=
|
||||
go.opentelemetry.io/otel/metric v1.32.0/go.mod h1:jH7CIbbK6SH2V2wE16W05BHCtIDzauciCRLoc/SyMv8=
|
||||
go.opentelemetry.io/otel/trace v1.32.0 h1:WIC9mYrXf8TmY/EXuULKc8hR17vE+Hjv2cssQDe03fM=
|
||||
go.opentelemetry.io/otel/trace v1.32.0/go.mod h1:+i4rkvCraA+tG6AzwloGaCtkx53Fa+L+V8e9a7YvhT8=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||
golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ=
|
||||
golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
|
||||
golang.org/x/exp/typeparams v0.0.0-20241009180824-f66d83c29e7c h1:F/15/6p7LyGUSoP0GE5CB/U9+TNEER1foNOP5sWLLnI=
|
||||
golang.org/x/exp/typeparams v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo=
|
||||
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak=
|
||||
golang.org/x/exp/typeparams v0.0.0-20241108190413-2d47ceb2692f h1:WTyX8eCCyfdqiPYkRGm0MqElSfYFH3yR1+rl/mct9sA=
|
||||
golang.org/x/exp/typeparams v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
|
||||
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
||||
golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo=
|
||||
golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs=
|
||||
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
|
||||
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ=
|
||||
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/telemetry v0.0.0-20241028140143-9c0d19e65ba0 h1:od0RE4kmouF+x/o4zkTTSvBnGPZ2azGgCUmZdrbnEXM=
|
||||
golang.org/x/telemetry v0.0.0-20241028140143-9c0d19e65ba0/go.mod h1:8nZWdGp9pq73ZI//QJyckMQab3yq7hoWi7SI0UIusVI=
|
||||
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
|
||||
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/telemetry v0.0.0-20241108154256-525ce2e96f55 h1:ZZOVC4W26kVZSAW314SD81pWtiRgWNMbZsgLqKXx9lE=
|
||||
golang.org/x/telemetry v0.0.0-20241108154256-525ce2e96f55/go.mod h1:7Vh679jcBo81KQrd4wo0gKov7BE6IHwu1tEhHxHNM30=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.26.0 h1:WEQa6V3Gja/BhNxg540hBip/kkaYtRg3cxg4oXSw4AU=
|
||||
golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
|
||||
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
|
||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
|
||||
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20201007032633-0806396f153e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
|
||||
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ=
|
||||
golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0=
|
||||
golang.org/x/tools v0.27.0 h1:qEKojBykQkQ4EynWy4S8Weg69NumxKdn40Fce3uc/8o=
|
||||
golang.org/x/tools v0.27.0/go.mod h1:sUi0ZgbwW9ZPAq26Ekut+weQPR5eIM6GQLQ1Yjm1H0Q=
|
||||
golang.org/x/vuln v1.1.3 h1:NPGnvPOTgnjBc9HTaUx+nj+EaUYxl5SJOWqaDYGaFYw=
|
||||
golang.org/x/vuln v1.1.3/go.mod h1:7Le6Fadm5FOqE9C926BCD0g12NWyhg7cxV4BwcPFuNY=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/api v0.203.0 h1:SrEeuwU3S11Wlscsn+LA1kb/Y5xT8uggJSkIhD08NAU=
|
||||
google.golang.org/api v0.203.0/go.mod h1:BuOVyCSYEPwJb3npWvDnNmFI92f3GeRnHNkETneT3SI=
|
||||
google.golang.org/api v0.209.0 h1:Ja2OXNlyRlWCWu8o+GgI4yUn/wz9h/5ZfFbKz+dQX+w=
|
||||
google.golang.org/api v0.209.0/go.mod h1:I53S168Yr/PNDNMi5yPnDc0/LGRZO6o7PoEbl/HY3CM=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241021214115-324edc3d5d38 h1:2oV8dfuIkM1Ti7DwXc0BJfnwr9csz4TDXI9EmiI+Rbw=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241021214115-324edc3d5d38/go.mod h1:vuAjtvlwkDKF6L1GQ0SokiRLCGFfeBUXWr/aFFkHACc=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 h1:zciRKQ4kBpFgpfC5QQCVtnnNAcLIqweL7plyZRQHVpI=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241202173237-19429a94021a h1:OAiGFfOiA0v9MRYsSidp3ubZaBnteRUyn3xB2ZQ5G/E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241202173237-19429a94021a/go.mod h1:jehYqy3+AhJU9ve55aNOaSml7wUXjF9x6z2LcCfpAhY=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a h1:hgh8P4EuoxpsuKMXX/To36nOFD7vixReXgn8lPGnt+o=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
|
||||
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
|
||||
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
|
||||
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
|
||||
google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E=
|
||||
google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
|
||||
google.golang.org/grpc v1.68.0 h1:aHQeeJbo8zAkAa3pRzrVjZlbz6uSfeOXlJNQM0RAbz0=
|
||||
google.golang.org/grpc v1.68.0/go.mod h1:fmSPC5AsjSBCK54MyHRx48kpOti1/jRfOlwEWywNjWA=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||
@@ -235,8 +232,8 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
||||
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
|
||||
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io=
|
||||
google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -246,7 +243,11 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh
|
||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.5.1 h1:4bH5o3b5ZULQ4UrBmP+63W9r7qIkqJClEA9ko5YKx+I=
|
||||
honnef.co/go/tools v0.5.1/go.mod h1:e9irvo83WDG9/irijV44wr3tbhcFeRnfpVlRqVwpzMs=
|
||||
mvdan.cc/editorconfig v0.3.0 h1:D1D2wLYEYGpawWT5SpM5pRivgEgXjtEXwC9MWhEY0gQ=
|
||||
mvdan.cc/editorconfig v0.3.0/go.mod h1:NcJHuDtNOTEJ6251indKiWuzK6+VcrMuLzGMLKBFupQ=
|
||||
mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU=
|
||||
mvdan.cc/gofumpt v0.7.0/go.mod h1:txVFJy/Sc/mvaycET54pV8SW8gWxTlUuGHVEcncmNUo=
|
||||
mvdan.cc/sh/v3 v3.10.0 h1:v9z7N1DLZ7owyLM/SXZQkBSXcwr2IGMm2LY2pmhVXj4=
|
||||
mvdan.cc/sh/v3 v3.10.0/go.mod h1:z/mSSVyLFGZzqb3ZIKojjyqIx/xbmz/UHdCSv9HmqXY=
|
||||
mvdan.cc/unparam v0.0.0-20240917084806-57a3b4290ba3 h1:YkmTN1n5U60NM02j7TCSWRlW3fqNiuXe/eVXf0dLFN8=
|
||||
mvdan.cc/unparam v0.0.0-20240917084806-57a3b4290ba3/go.mod h1:z5yboO1sP1Q9pcfvS597TpfbNXQjphDlkCJHzt13ybc=
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
_ "github.com/fzipp/gocyclo/cmd/gocyclo"
|
||||
_ "github.com/golangci/misspell/cmd/misspell"
|
||||
_ "github.com/gordonklaus/ineffassign"
|
||||
_ "github.com/jstemmer/go-junit-report/v2"
|
||||
_ "github.com/kisielk/errcheck"
|
||||
_ "github.com/kyoh86/looppointer"
|
||||
_ "github.com/securego/gosec/v2/cmd/gosec"
|
||||
_ "github.com/uudashr/gocognit/cmd/gocognit"
|
||||
_ "golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness"
|
||||
@@ -19,5 +19,6 @@ import (
|
||||
_ "golang.org/x/vuln/cmd/govulncheck"
|
||||
_ "honnef.co/go/tools/cmd/staticcheck"
|
||||
_ "mvdan.cc/gofumpt"
|
||||
_ "mvdan.cc/sh/v3/cmd/shfmt"
|
||||
_ "mvdan.cc/unparam"
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"golang.org/x/exp/maps"
|
||||
"github.com/c2h5oh/datasize"
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Make configurable.
|
||||
@@ -28,8 +29,9 @@ type VersionInfo struct {
|
||||
CanAutoUpdate aghalg.NullBool `json:"can_autoupdate,omitempty"`
|
||||
}
|
||||
|
||||
// MaxResponseSize is responses on server's requests maximum length in bytes.
|
||||
const MaxResponseSize = 64 * 1024
|
||||
// maxVersionRespSize is the maximum length in bytes for version information
|
||||
// response.
|
||||
const maxVersionRespSize datasize.ByteSize = 64 * datasize.KB
|
||||
|
||||
// VersionInfo downloads the latest version information. If forceRecheck is
|
||||
// false and there are cached results, those results are returned.
|
||||
@@ -51,7 +53,7 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||
|
||||
r := ioutil.LimitReader(resp.Body, MaxResponseSize)
|
||||
r := ioutil.LimitReader(resp.Body, maxVersionRespSize.Bytes())
|
||||
|
||||
// This use of ReadAll is safe, because we just limited the appropriate
|
||||
// ReadCloser.
|
||||
@@ -120,8 +122,8 @@ func (u *Updater) downloadURL(versionObj map[string]string) (dlURL, key string,
|
||||
return dlURL, key, true
|
||||
}
|
||||
|
||||
keys := maps.Keys(versionObj)
|
||||
slices.Sort(keys)
|
||||
keys := slices.Sorted(maps.Keys(versionObj))
|
||||
|
||||
log.Error("updater: key %q not found; got keys %q", key, keys)
|
||||
|
||||
return "", key, false
|
||||
|
||||
@@ -51,9 +51,11 @@ func TestUpdater_VersionInfo(t *testing.T) {
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
fakeURL, err := url.JoinPath(srv.URL, "adguardhome", version.ChannelBeta, "version.json")
|
||||
srvURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
fakeURL := srvURL.JoinPath("adguardhome", version.ChannelBeta, "version.json")
|
||||
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
Client: srv.Client(),
|
||||
Version: "v0.103.0-beta.1",
|
||||
@@ -134,7 +136,7 @@ func TestUpdater_VersionInfo_others(t *testing.T) {
|
||||
GOARCH: tc.arch,
|
||||
GOARM: tc.arm,
|
||||
GOMIPS: tc.mips,
|
||||
VersionCheckURL: fakeURL.String(),
|
||||
VersionCheckURL: fakeURL,
|
||||
})
|
||||
|
||||
info, err := u.VersionInfo(false)
|
||||
|
||||
@@ -9,8 +9,10 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -21,6 +23,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
)
|
||||
|
||||
// Updater is the AdGuard Home updater.
|
||||
@@ -61,10 +64,23 @@ type Updater struct {
|
||||
prevCheckResult VersionInfo
|
||||
}
|
||||
|
||||
// DefaultVersionURL returns the default URL for the version announcement.
|
||||
func DefaultVersionURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: urlutil.SchemeHTTPS,
|
||||
Host: "static.adtidy.org",
|
||||
Path: path.Join("adguardhome", version.Channel(), "version.json"),
|
||||
}
|
||||
}
|
||||
|
||||
// Config is the AdGuard Home updater configuration.
|
||||
type Config struct {
|
||||
Client *http.Client
|
||||
|
||||
// VersionCheckURL is URL to the latest version announcement. It must not
|
||||
// be nil, see [DefaultVersionURL].
|
||||
VersionCheckURL *url.URL
|
||||
|
||||
Version string
|
||||
Channel string
|
||||
GOARCH string
|
||||
@@ -81,12 +97,9 @@ type Config struct {
|
||||
|
||||
// ExecPath is path to the executable file.
|
||||
ExecPath string
|
||||
|
||||
// VersionCheckURL is url to the latest version announcement.
|
||||
VersionCheckURL string
|
||||
}
|
||||
|
||||
// NewUpdater creates a new Updater.
|
||||
// NewUpdater creates a new Updater. conf must not be nil.
|
||||
func NewUpdater(conf *Config) *Updater {
|
||||
return &Updater{
|
||||
client: conf.Client,
|
||||
@@ -101,7 +114,7 @@ func NewUpdater(conf *Config) *Updater {
|
||||
confName: conf.ConfName,
|
||||
workDir: conf.WorkDir,
|
||||
execPath: conf.ExecPath,
|
||||
versionCheckURL: conf.VersionCheckURL,
|
||||
versionCheckURL: conf.VersionCheckURL.String(),
|
||||
|
||||
mu: &sync.RWMutex{},
|
||||
}
|
||||
@@ -167,14 +180,6 @@ func (u *Updater) NewVersion() (nv string) {
|
||||
return u.newVersion
|
||||
}
|
||||
|
||||
// VersionCheckURL returns the version check URL.
|
||||
func (u *Updater) VersionCheckURL() (vcu string) {
|
||||
u.mu.RLock()
|
||||
defer u.mu.RUnlock()
|
||||
|
||||
return u.versionCheckURL
|
||||
}
|
||||
|
||||
// prepare fills all necessary fields in Updater object.
|
||||
func (u *Updater) prepare() (err error) {
|
||||
u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion))
|
||||
@@ -265,7 +270,7 @@ func (u *Updater) check() (err error) {
|
||||
// ignores the configuration file if firstRun is true.
|
||||
func (u *Updater) backup(firstRun bool) (err error) {
|
||||
log.Debug("updater: backing up current configuration")
|
||||
_ = aghos.Mkdir(u.backupDir, aghos.DefaultPermDir)
|
||||
_ = os.Mkdir(u.backupDir, aghos.DefaultPermDir)
|
||||
if !firstRun {
|
||||
err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"), aghos.DefaultPermFile)
|
||||
if err != nil {
|
||||
@@ -339,10 +344,10 @@ func (u *Updater) downloadPackageFile() (err error) {
|
||||
return fmt.Errorf("io.ReadAll() failed: %w", err)
|
||||
}
|
||||
|
||||
_ = aghos.Mkdir(u.updateDir, aghos.DefaultPermDir)
|
||||
_ = os.Mkdir(u.updateDir, aghos.DefaultPermDir)
|
||||
|
||||
log.Debug("updater: saving package to file")
|
||||
err = aghos.WriteFile(u.packageName, body, aghos.DefaultPermFile)
|
||||
err = os.WriteFile(u.packageName, body, aghos.DefaultPermFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing package file: %w", err)
|
||||
}
|
||||
@@ -355,7 +360,7 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
return "", nil
|
||||
}
|
||||
|
||||
outputName := filepath.Join(outDir, name)
|
||||
outName := filepath.Join(outDir, name)
|
||||
|
||||
if hdr.Typeflag == tar.TypeDir {
|
||||
if name == "AdGuardHome" {
|
||||
@@ -367,12 +372,12 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
return "", nil
|
||||
}
|
||||
|
||||
err = aghos.Mkdir(outputName, os.FileMode(hdr.Mode&0o755))
|
||||
err = os.Mkdir(outName, os.FileMode(hdr.Mode&0o755))
|
||||
if err != nil && !errors.Is(err, os.ErrExist) {
|
||||
return "", fmt.Errorf("creating directory %q: %w", outputName, err)
|
||||
return "", fmt.Errorf("creating directory %q: %w", outName, err)
|
||||
}
|
||||
|
||||
log.Debug("updater: created directory %q", outputName)
|
||||
log.Debug("updater: created directory %q", outName)
|
||||
|
||||
return "", nil
|
||||
}
|
||||
@@ -384,13 +389,9 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
}
|
||||
|
||||
var wc io.WriteCloser
|
||||
wc, err = aghos.OpenFile(
|
||||
outputName,
|
||||
os.O_WRONLY|os.O_CREATE|os.O_TRUNC,
|
||||
os.FileMode(hdr.Mode&0o755),
|
||||
)
|
||||
wc, err = os.OpenFile(outName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(hdr.Mode)&0o755)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("os.OpenFile(%s): %w", outputName, err)
|
||||
return "", fmt.Errorf("os.OpenFile(%s): %w", outName, err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, wc.Close()) }()
|
||||
|
||||
@@ -399,7 +400,7 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
return "", fmt.Errorf("io.Copy(): %w", err)
|
||||
}
|
||||
|
||||
log.Debug("updater: created file %q", outputName)
|
||||
log.Debug("updater: created file %q", outName)
|
||||
|
||||
return name, nil
|
||||
}
|
||||
@@ -469,7 +470,7 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
err = aghos.Mkdir(outputName, fi.Mode())
|
||||
err = os.Mkdir(outputName, fi.Mode())
|
||||
if err != nil && !errors.Is(err, os.ErrExist) {
|
||||
return "", fmt.Errorf("creating directory %q: %w", outputName, err)
|
||||
}
|
||||
@@ -480,7 +481,7 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||
}
|
||||
|
||||
var wc io.WriteCloser
|
||||
wc, err = aghos.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
|
||||
wc, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("os.OpenFile(): %w", err)
|
||||
}
|
||||
@@ -530,7 +531,7 @@ func copyFile(src, dst string, perm fs.FileMode) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = aghos.WriteFile(dst, d, perm)
|
||||
err = os.WriteFile(dst, d, perm)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -45,7 +46,7 @@ func TestUpdater_internal(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
exePath := filepath.Join(wd, tc.exeName)
|
||||
|
||||
// start server for returning package file
|
||||
// Start server for returning package file.
|
||||
pkgData, err := os.ReadFile(filepath.Join("testdata", tc.archiveName))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -59,6 +60,9 @@ func TestUpdater_internal(t *testing.T) {
|
||||
ExecPath: exePath,
|
||||
WorkDir: wd,
|
||||
ConfName: yamlPath,
|
||||
// TODO(e.burkov): Rewrite the test to use a fake version check
|
||||
// URL with a fake URLs for the package files.
|
||||
VersionCheckURL: &url.URL{},
|
||||
})
|
||||
|
||||
u.newVersion = "v0.103.1"
|
||||
@@ -72,36 +76,40 @@ func TestUpdater_internal(t *testing.T) {
|
||||
|
||||
u.clean()
|
||||
|
||||
// check backup files
|
||||
d, err := os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
|
||||
require.NoError(t, err)
|
||||
require.True(t, t.Run("backup", func(t *testing.T) {
|
||||
var d []byte
|
||||
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", "AdGuardHome.yaml"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
|
||||
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", tc.exeName))
|
||||
require.NoError(t, err)
|
||||
d, err = os.ReadFile(filepath.Join(wd, "agh-backup", tc.exeName))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.exeName, string(d))
|
||||
assert.Equal(t, tc.exeName, string(d))
|
||||
}))
|
||||
|
||||
// check updated files
|
||||
d, err = os.ReadFile(exePath)
|
||||
require.NoError(t, err)
|
||||
require.True(t, t.Run("updated", func(t *testing.T) {
|
||||
var d []byte
|
||||
d, err = os.ReadFile(exePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "1", string(d))
|
||||
assert.Equal(t, "1", string(d))
|
||||
|
||||
d, err = os.ReadFile(readmePath)
|
||||
require.NoError(t, err)
|
||||
d, err = os.ReadFile(readmePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "2", string(d))
|
||||
assert.Equal(t, "2", string(d))
|
||||
|
||||
d, err = os.ReadFile(licensePath)
|
||||
require.NoError(t, err)
|
||||
d, err = os.ReadFile(licensePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "3", string(d))
|
||||
assert.Equal(t, "3", string(d))
|
||||
|
||||
d, err = os.ReadFile(yamlPath)
|
||||
require.NoError(t, err)
|
||||
d, err = os.ReadFile(yamlPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
assert.Equal(t, "AdGuardHome.yaml", string(d))
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,7 +65,10 @@ func TestUpdater_Update(t *testing.T) {
|
||||
srv := httptest.NewServer(mux)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
versionCheckURL, err := url.JoinPath(srv.URL, versionPath)
|
||||
srvURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
versionCheckURL := srvURL.JoinPath(versionPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
|
||||
Reference in New Issue
Block a user