all: sync with master

This commit is contained in:
Eugene Burkov
2024-10-29 18:53:56 +03:00
parent 2aaf8ab3c1
commit 6affa96490
51 changed files with 1361 additions and 456 deletions

View File

@@ -0,0 +1,50 @@
package aghos
import (
"io/fs"
"os"
)
// TODO(e.burkov): Add platform-independent tests.
// Chmod is an extension for [os.Chmod] that properly handles Windows access
// rights.
func Chmod(name string, perm fs.FileMode) (err error) {
return chmod(name, perm)
}
// Mkdir is an extension for [os.Mkdir] that properly handles Windows access
// rights.
func Mkdir(name string, perm fs.FileMode) (err error) {
return mkdir(name, perm)
}
// MkdirAll is an extension for [os.MkdirAll] that properly handles Windows
// access rights.
func MkdirAll(path string, perm fs.FileMode) (err error) {
return mkdirAll(path, perm)
}
// WriteFile is an extension for [os.WriteFile] that properly handles Windows
// access rights.
func WriteFile(filename string, data []byte, perm fs.FileMode) (err error) {
return writeFile(filename, data, perm)
}
// OpenFile is an extension for [os.OpenFile] that properly handles Windows
// access rights.
func OpenFile(name string, flag int, perm fs.FileMode) (file *os.File, err error) {
return openFile(name, flag, perm)
}
// Stat is an extension for [os.Stat] that properly handles Windows access
// rights.
//
// Note that on Windows the "other" permission bits combines the access rights
// of any trustee that is neither the owner nor the owning group for the file.
//
// TODO(e.burkov): Inspect the behavior for the World (everyone) well-known
// SID and, perhaps, use it.
func Stat(name string) (fi fs.FileInfo, err error) {
return stat(name)
}

View File

@@ -0,0 +1,42 @@
//go:build unix
package aghos
import (
"io/fs"
"os"
"github.com/google/renameio/v2/maybe"
)
// chmod is a Unix implementation of [Chmod].
func chmod(name string, perm fs.FileMode) (err error) {
return os.Chmod(name, perm)
}
// mkdir is a Unix implementation of [Mkdir].
func mkdir(name string, perm fs.FileMode) (err error) {
return os.Mkdir(name, perm)
}
// mkdirAll is a Unix implementation of [MkdirAll].
func mkdirAll(path string, perm fs.FileMode) (err error) {
return os.MkdirAll(path, perm)
}
// writeFile is a Unix implementation of [WriteFile].
func writeFile(filename string, data []byte, perm fs.FileMode) (err error) {
return maybe.WriteFile(filename, data, perm)
}
// openFile is a Unix implementation of [OpenFile].
func openFile(name string, flag int, perm fs.FileMode) (file *os.File, err error) {
// #nosec G304 -- This function simply wraps the [os.OpenFile] function, so
// the security concerns should be addressed to the [OpenFile] calls.
return os.OpenFile(name, flag, perm)
}
// stat is a Unix implementation of [Stat].
func stat(name string) (fi os.FileInfo, err error) {
return os.Stat(name)
}

View File

@@ -0,0 +1,392 @@
//go:build windows
package aghos
import (
"fmt"
"io/fs"
"os"
"path/filepath"
"unsafe"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/sys/windows"
)
// fileInfo is a Windows implementation of [fs.FileInfo], that contains the
// filemode converted from the security descriptor.
type fileInfo struct {
// fs.FileInfo is embedded to provide the default implementations and data
// successfully retrieved by [os.Stat].
fs.FileInfo
// mode is the file mode converted from the security descriptor.
mode fs.FileMode
}
// type check
var _ fs.FileInfo = (*fileInfo)(nil)
// Mode implements [fs.FileInfo.Mode] for [*fileInfo].
func (fi *fileInfo) Mode() (mode fs.FileMode) { return fi.mode }
// stat is a Windows implementation of [Stat].
func stat(name string) (fi os.FileInfo, err error) {
absName, err := filepath.Abs(name)
if err != nil {
return nil, fmt.Errorf("computing absolute path: %w", err)
}
fi, err = os.Stat(absName)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return nil, err
}
dacl, owner, group, err := retrieveDACL(absName)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return nil, err
}
var ownerMask, groupMask, otherMask windows.ACCESS_MASK
for i := range uint32(dacl.AceCount) {
var ace *windows.ACCESS_ALLOWED_ACE
err = windows.GetAce(dacl, i, &ace)
if err != nil {
return nil, fmt.Errorf("getting access control entry at index %d: %w", i, err)
}
entrySid := (*windows.SID)(unsafe.Pointer(&ace.SidStart))
switch {
case entrySid.Equals(owner):
ownerMask |= ace.Mask
case entrySid.Equals(group):
groupMask |= ace.Mask
default:
otherMask |= ace.Mask
}
}
mode := fi.Mode()
perm := masksToPerm(ownerMask, groupMask, otherMask, mode.IsDir())
return &fileInfo{
FileInfo: fi,
// Use the file mode from the security descriptor, but use the
// calculated permission bits.
mode: perm | mode&^fs.FileMode(0o777),
}, nil
}
// retrieveDACL retrieves the discretionary access control list, owner, and
// group from the security descriptor of the file with the specified absolute
// name.
func retrieveDACL(absName string) (dacl *windows.ACL, owner, group *windows.SID, err error) {
// desiredSecInfo defines the parts of a security descriptor to retrieve.
const desiredSecInfo windows.SECURITY_INFORMATION = windows.OWNER_SECURITY_INFORMATION |
windows.GROUP_SECURITY_INFORMATION |
windows.DACL_SECURITY_INFORMATION |
windows.PROTECTED_DACL_SECURITY_INFORMATION |
windows.UNPROTECTED_DACL_SECURITY_INFORMATION
sd, err := windows.GetNamedSecurityInfo(absName, windows.SE_FILE_OBJECT, desiredSecInfo)
if err != nil {
return nil, nil, nil, fmt.Errorf("getting security descriptor: %w", err)
}
dacl, _, err = sd.DACL()
if err != nil {
return nil, nil, nil, fmt.Errorf("getting discretionary access control list: %w", err)
}
owner, _, err = sd.Owner()
if err != nil {
return nil, nil, nil, fmt.Errorf("getting owner sid: %w", err)
}
group, _, err = sd.Group()
if err != nil {
return nil, nil, nil, fmt.Errorf("getting group sid: %w", err)
}
return dacl, owner, group, nil
}
// chmod is a Windows implementation of [Chmod].
func chmod(name string, perm fs.FileMode) (err error) {
fi, err := os.Stat(name)
if err != nil {
return fmt.Errorf("getting file info: %w", err)
}
entries := make([]windows.EXPLICIT_ACCESS, 0, 3)
creatorMask, groupMask, worldMask := permToMasks(perm, fi.IsDir())
sidMasks := []struct {
Key windows.WELL_KNOWN_SID_TYPE
Value windows.ACCESS_MASK
}{{
Key: windows.WinCreatorOwnerSid,
Value: creatorMask,
}, {
Key: windows.WinCreatorGroupSid,
Value: groupMask,
}, {
Key: windows.WinWorldSid,
Value: worldMask,
}}
var errs []error
for _, sidMask := range sidMasks {
if sidMask.Value == 0 {
continue
}
var trustee windows.TRUSTEE
trustee, err = newWellKnownTrustee(sidMask.Key)
if err != nil {
errs = append(errs, err)
continue
}
entries = append(entries, windows.EXPLICIT_ACCESS{
AccessPermissions: sidMask.Value,
AccessMode: windows.GRANT_ACCESS,
Inheritance: windows.NO_INHERITANCE,
Trustee: trustee,
})
}
if err = errors.Join(errs...); err != nil {
return fmt.Errorf("creating access control entries: %w", err)
}
acl, err := windows.ACLFromEntries(entries, nil)
if err != nil {
return fmt.Errorf("creating access control list: %w", err)
}
// secInfo defines the parts of a security descriptor to set.
const secInfo windows.SECURITY_INFORMATION = windows.DACL_SECURITY_INFORMATION |
windows.PROTECTED_DACL_SECURITY_INFORMATION
err = windows.SetNamedSecurityInfo(name, windows.SE_FILE_OBJECT, secInfo, nil, nil, acl, nil)
if err != nil {
return fmt.Errorf("setting security descriptor: %w", err)
}
return nil
}
// mkdir is a Windows implementation of [Mkdir].
//
// TODO(e.burkov): Consider using [windows.CreateDirectory] instead of
// [os.Mkdir] to reduce the number of syscalls.
func mkdir(name string, perm os.FileMode) (err error) {
name, err = filepath.Abs(name)
if err != nil {
return fmt.Errorf("computing absolute path: %w", err)
}
err = os.Mkdir(name, perm)
if err != nil {
return fmt.Errorf("creating directory: %w", err)
}
defer func() {
if err != nil {
err = errors.WithDeferred(err, os.Remove(name))
}
}()
return chmod(name, perm)
}
// mkdirAll is a Windows implementation of [MkdirAll].
func mkdirAll(path string, perm os.FileMode) (err error) {
parent, _ := filepath.Split(path)
if parent != "" {
err = os.MkdirAll(parent, perm)
if err != nil && !errors.Is(err, os.ErrExist) {
return fmt.Errorf("creating parent directories: %w", err)
}
}
err = mkdir(path, perm)
if errors.Is(err, os.ErrExist) {
return nil
}
return err
}
// writeFile is a Windows implementation of [WriteFile].
func writeFile(filename string, data []byte, perm os.FileMode) (err error) {
file, err := openFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm)
if err != nil {
return fmt.Errorf("opening file: %w", err)
}
defer func() { err = errors.WithDeferred(err, file.Close()) }()
_, err = file.Write(data)
if err != nil {
return fmt.Errorf("writing data: %w", err)
}
return nil
}
// openFile is a Windows implementation of [OpenFile].
func openFile(name string, flag int, perm os.FileMode) (file *os.File, err error) {
// Only change permissions if the file not yet exists, but should be
// created.
if flag&os.O_CREATE == 0 {
return os.OpenFile(name, flag, perm)
}
_, err = stat(name)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
defer func() { err = errors.WithDeferred(err, chmod(name, perm)) }()
} else {
return nil, fmt.Errorf("getting file info: %w", err)
}
}
return os.OpenFile(name, flag, perm)
}
// newWellKnownTrustee returns a trustee for a well-known SID.
func newWellKnownTrustee(stype windows.WELL_KNOWN_SID_TYPE) (t windows.TRUSTEE, err error) {
sid, err := windows.CreateWellKnownSid(stype)
if err != nil {
return windows.TRUSTEE{}, fmt.Errorf("creating sid for type %d: %w", stype, err)
}
return windows.TRUSTEE{
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeValue: windows.TrusteeValueFromSID(sid),
}, nil
}
// UNIX file mode permission bits.
const (
permRead = 0b100
permWrite = 0b010
permExecute = 0b001
)
// Windows access masks for appropriate UNIX file mode permission bits and
// file types.
const (
fileReadRights windows.ACCESS_MASK = windows.READ_CONTROL |
windows.FILE_READ_DATA |
windows.FILE_READ_ATTRIBUTES |
windows.FILE_READ_EA |
windows.SYNCHRONIZE |
windows.ACCESS_SYSTEM_SECURITY
fileWriteRights windows.ACCESS_MASK = windows.WRITE_DAC |
windows.WRITE_OWNER |
windows.FILE_WRITE_DATA |
windows.FILE_WRITE_ATTRIBUTES |
windows.FILE_WRITE_EA |
windows.DELETE |
windows.FILE_APPEND_DATA |
windows.SYNCHRONIZE |
windows.ACCESS_SYSTEM_SECURITY
fileExecuteRights windows.ACCESS_MASK = windows.FILE_EXECUTE
dirReadRights windows.ACCESS_MASK = windows.READ_CONTROL |
windows.FILE_LIST_DIRECTORY |
windows.FILE_READ_EA |
windows.FILE_READ_ATTRIBUTES<<1 |
windows.SYNCHRONIZE |
windows.ACCESS_SYSTEM_SECURITY
dirWriteRights windows.ACCESS_MASK = windows.WRITE_DAC |
windows.WRITE_OWNER |
windows.DELETE |
windows.FILE_WRITE_DATA |
windows.FILE_APPEND_DATA |
windows.FILE_WRITE_EA |
windows.FILE_WRITE_ATTRIBUTES<<1 |
windows.SYNCHRONIZE |
windows.ACCESS_SYSTEM_SECURITY
dirExecuteRights windows.ACCESS_MASK = windows.FILE_TRAVERSE
)
// permToMasks converts a UNIX file mode permissions to the corresponding
// Windows access masks. The [isDir] argument is used to set specific access
// bits for directories.
func permToMasks(fm os.FileMode, isDir bool) (owner, group, world windows.ACCESS_MASK) {
mask := fm.Perm()
owner = permToMask(byte((mask>>6)&0b111), isDir)
group = permToMask(byte((mask>>3)&0b111), isDir)
world = permToMask(byte(mask&0b111), isDir)
return owner, group, world
}
// permToMask converts a UNIX file mode permission bits within p byte to the
// corresponding Windows access mask. The [isDir] argument is used to set
// specific access bits for directories.
func permToMask(p byte, isDir bool) (mask windows.ACCESS_MASK) {
readRights, writeRights, executeRights := fileReadRights, fileWriteRights, fileExecuteRights
if isDir {
readRights, writeRights, executeRights = dirReadRights, dirWriteRights, dirExecuteRights
}
if p&permRead != 0 {
mask |= readRights
}
if p&permWrite != 0 {
mask |= writeRights
}
if p&permExecute != 0 {
mask |= executeRights
}
return mask
}
// masksToPerm converts Windows access masks to the corresponding UNIX file
// mode permission bits.
func masksToPerm(u, g, o windows.ACCESS_MASK, isDir bool) (perm fs.FileMode) {
perm |= fs.FileMode(maskToPerm(u, isDir)) << 6
perm |= fs.FileMode(maskToPerm(g, isDir)) << 3
perm |= fs.FileMode(maskToPerm(o, isDir))
return perm
}
// maskToPerm converts a Windows access mask to the corresponding UNIX file
// mode permission bits.
func maskToPerm(mask windows.ACCESS_MASK, isDir bool) (perm byte) {
readMask, writeMask, executeMask := fileReadRights, fileWriteRights, fileExecuteRights
if isDir {
readMask, writeMask, executeMask = dirReadRights, dirWriteRights, dirExecuteRights
}
// Remove common bits to avoid false positive detection of unset rights.
readMask ^= windows.SYNCHRONIZE | windows.ACCESS_SYSTEM_SECURITY
writeMask ^= windows.SYNCHRONIZE | windows.ACCESS_SYSTEM_SECURITY
if mask&readMask != 0 {
perm |= permRead
}
if mask&writeMask != 0 {
perm |= permWrite
}
if mask&executeMask != 0 {
perm |= permExecute
}
return perm
}

View File

@@ -0,0 +1,135 @@
//go:build windows
package aghos
import (
"io/fs"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/sys/windows"
)
func TestPermToMasks(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
perm fs.FileMode
wantUser windows.ACCESS_MASK
wantGroup windows.ACCESS_MASK
wantOther windows.ACCESS_MASK
isDir bool
}{{
name: "all",
perm: 0b111_111_111,
wantUser: fileReadRights | fileWriteRights | fileExecuteRights,
wantGroup: fileReadRights | fileWriteRights | fileExecuteRights,
wantOther: fileReadRights | fileWriteRights | fileExecuteRights,
isDir: false,
}, {
name: "user_write",
perm: 0b010_000_000,
wantUser: fileWriteRights,
wantGroup: 0,
wantOther: 0,
isDir: false,
}, {
name: "group_read",
perm: 0b000_100_000,
wantUser: 0,
wantGroup: fileReadRights,
wantOther: 0,
isDir: false,
}, {
name: "all_dir",
perm: 0b111_111_111,
wantUser: dirReadRights | dirWriteRights | dirExecuteRights,
wantGroup: dirReadRights | dirWriteRights | dirExecuteRights,
wantOther: dirReadRights | dirWriteRights | dirExecuteRights,
isDir: true,
}, {
name: "user_write_dir",
perm: 0b010_000_000,
wantUser: dirWriteRights,
wantGroup: 0,
wantOther: 0,
isDir: true,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
user, group, other := permToMasks(tc.perm, tc.isDir)
assert.Equal(t, tc.wantUser, user)
assert.Equal(t, tc.wantGroup, group)
assert.Equal(t, tc.wantOther, other)
})
}
}
func TestMasksToPerm(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
user windows.ACCESS_MASK
group windows.ACCESS_MASK
other windows.ACCESS_MASK
wantPerm fs.FileMode
isDir bool
}{{
name: "all",
user: fileReadRights | fileWriteRights | fileExecuteRights,
group: fileReadRights | fileWriteRights | fileExecuteRights,
other: fileReadRights | fileWriteRights | fileExecuteRights,
wantPerm: 0b111_111_111,
isDir: false,
}, {
name: "user_write",
user: fileWriteRights,
group: 0,
other: 0,
wantPerm: 0b010_000_000,
isDir: false,
}, {
name: "group_read",
user: 0,
group: fileReadRights,
other: 0,
wantPerm: 0b000_100_000,
isDir: false,
}, {
name: "no_access",
user: 0,
group: 0,
other: 0,
wantPerm: 0,
isDir: false,
}, {
name: "all_dir",
user: dirReadRights | dirWriteRights | dirExecuteRights,
group: dirReadRights | dirWriteRights | dirExecuteRights,
other: dirReadRights | dirWriteRights | dirExecuteRights,
wantPerm: 0b111_111_111,
isDir: true,
}, {
name: "user_write_dir",
user: dirWriteRights,
group: 0,
other: 0,
wantPerm: 0b010_000_000,
isDir: true,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Don't call [fs.FileMode.Perm] since the result is expected to
// contain only the permission bits.
assert.Equal(t, tc.wantPerm, masksToPerm(tc.user, tc.group, tc.other, tc.isDir))
})
}
}

View File

@@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
)
@@ -62,7 +63,7 @@ func newPendingFile(filePath string, mode fs.FileMode) (f PendingFile, err error
return nil, fmt.Errorf("opening pending file: %w", err)
}
err = file.Chmod(mode)
err = aghos.Chmod(file.Name(), mode)
if err != nil {
return nil, fmt.Errorf("preparing pending file: %w", err)
}

View File

@@ -89,14 +89,14 @@ func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) {
// AddressProcessor is a fake [client.AddressProcessor] implementation for
// tests.
type AddressProcessor struct {
OnProcess func(ip netip.Addr)
OnProcess func(ctx context.Context, ip netip.Addr)
OnClose func() (err error)
}
// Process implements the [client.AddressProcessor] interface for
// *AddressProcessor.
func (p *AddressProcessor) Process(ip netip.Addr) {
p.OnProcess(ip)
func (p *AddressProcessor) Process(ctx context.Context, ip netip.Addr) {
p.OnProcess(ctx, ip)
}
// Close implements the [client.AddressProcessor] interface for
@@ -107,13 +107,18 @@ func (p *AddressProcessor) Close() (err error) {
// AddressUpdater is a fake [client.AddressUpdater] implementation for tests.
type AddressUpdater struct {
OnUpdateAddress func(ip netip.Addr, host string, info *whois.Info)
OnUpdateAddress func(ctx context.Context, ip netip.Addr, host string, info *whois.Info)
}
// UpdateAddress implements the [client.AddressUpdater] interface for
// *AddressUpdater.
func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
p.OnUpdateAddress(ip, host, info)
func (p *AddressUpdater) UpdateAddress(
ctx context.Context,
ip netip.Addr,
host string,
info *whois.Info,
) {
p.OnUpdateAddress(ctx, ip, host, info)
}
// Package dnsforward

View File

@@ -11,7 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
)
@@ -22,7 +21,7 @@ const ErrClosed errors.Error = "use of closed address processor"
// AddressProcessor is the interface for types that can process clients.
type AddressProcessor interface {
Process(ip netip.Addr)
Process(ctx context.Context, ip netip.Addr)
Close() (err error)
}
@@ -33,7 +32,7 @@ type EmptyAddrProc struct{}
var _ AddressProcessor = EmptyAddrProc{}
// Process implements the [AddressProcessor] interface for EmptyAddrProc.
func (EmptyAddrProc) Process(_ netip.Addr) {}
func (EmptyAddrProc) Process(_ context.Context, _ netip.Addr) {}
// Close implements the [AddressProcessor] interface for EmptyAddrProc.
func (EmptyAddrProc) Close() (_ error) { return nil }
@@ -90,12 +89,15 @@ type DefaultAddrProcConfig struct {
type AddressUpdater interface {
// UpdateAddress updates information about an IP address, setting host (if
// not empty) and WHOIS information (if not nil).
UpdateAddress(ip netip.Addr, host string, info *whois.Info)
UpdateAddress(ctx context.Context, ip netip.Addr, host string, info *whois.Info)
}
// DefaultAddrProc processes incoming client addresses with rDNS and WHOIS, if
// configured, and updates that information in a client storage.
type DefaultAddrProc struct {
// logger is used to log the operation of address processor.
logger *slog.Logger
// clientIPsMu serializes closure of clientIPs and access to isClosed.
clientIPsMu *sync.Mutex
@@ -142,6 +144,7 @@ const (
// not be nil.
func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
p = &DefaultAddrProc{
logger: c.BaseLogger.With(slogutil.KeyPrefix, "addrproc"),
clientIPsMu: &sync.Mutex{},
clientIPs: make(chan netip.Addr, defaultQueueSize),
rdns: &rdns.Empty{},
@@ -164,10 +167,13 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
p.whois = newWHOIS(c.BaseLogger.With(slogutil.KeyPrefix, "whois"), c.DialContext)
}
go p.process(c.CatchPanics)
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
go p.process(ctx, c.CatchPanics)
for _, ip := range c.InitialAddresses {
p.Process(ip)
p.Process(ctx, ip)
}
return p
@@ -210,7 +216,7 @@ func newWHOIS(logger *slog.Logger, dialFunc aghnet.DialContextFunc) (w whois.Int
var _ AddressProcessor = (*DefaultAddrProc)(nil)
// Process implements the [AddressProcessor] interface for *DefaultAddrProc.
func (p *DefaultAddrProc) Process(ip netip.Addr) {
func (p *DefaultAddrProc) Process(ctx context.Context, ip netip.Addr) {
p.clientIPsMu.Lock()
defer p.clientIPsMu.Unlock()
@@ -222,38 +228,42 @@ func (p *DefaultAddrProc) Process(ip netip.Addr) {
case p.clientIPs <- ip:
// Go on.
default:
log.Debug("clients: ip channel is full; len: %d", len(p.clientIPs))
p.logger.DebugContext(ctx, "ip channel is full", "len", len(p.clientIPs))
}
}
// process processes the incoming client IP-address information. It is intended
// to be used as a goroutine. Once clientIPs is closed, process exits.
func (p *DefaultAddrProc) process(catchPanics bool) {
func (p *DefaultAddrProc) process(ctx context.Context, catchPanics bool) {
if catchPanics {
defer log.OnPanic("addrProcessor.process")
defer slogutil.RecoverAndLog(ctx, p.logger)
}
log.Info("clients: processing addresses")
ctx := context.TODO()
p.logger.InfoContext(ctx, "processing addresses")
for ip := range p.clientIPs {
host := p.processRDNS(ctx, ip)
info := p.processWHOIS(ctx, ip)
p.addrUpdater.UpdateAddress(ip, host, info)
p.addrUpdater.UpdateAddress(ctx, ip, host, info)
}
log.Info("clients: finished processing addresses")
p.logger.InfoContext(ctx, "finished processing addresses")
}
// processRDNS resolves the clients' IP addresses using reverse DNS. host is
// empty if there were errors or if the information hasn't changed.
func (p *DefaultAddrProc) processRDNS(ctx context.Context, ip netip.Addr) (host string) {
start := time.Now()
log.Debug("clients: processing %s with rdns", ip)
p.logger.DebugContext(ctx, "processing rdns", "ip", ip)
defer func() {
log.Debug("clients: finished processing %s with rdns in %s", ip, time.Since(start))
p.logger.DebugContext(
ctx,
"finished processing rdns",
"ip", ip,
"host", host,
"elapsed", time.Since(start),
)
}()
ok := p.shouldResolve(ip)
@@ -280,9 +290,15 @@ func (p *DefaultAddrProc) shouldResolve(ip netip.Addr) (ok bool) {
// hasn't changed.
func (p *DefaultAddrProc) processWHOIS(ctx context.Context, ip netip.Addr) (info *whois.Info) {
start := time.Now()
log.Debug("clients: processing %s with whois", ip)
p.logger.DebugContext(ctx, "processing whois", "ip", ip)
defer func() {
log.Debug("clients: finished processing %s with whois in %s", ip, time.Since(start))
p.logger.DebugContext(
ctx,
"finished processing whois",
"ip", ip,
"whois", info,
"elapsed", time.Since(start),
)
}()
// TODO(s.chzhen): Move the timeout logic from WHOIS configuration to the

View File

@@ -26,7 +26,8 @@ func TestEmptyAddrProc(t *testing.T) {
p := client.EmptyAddrProc{}
assert.NotPanics(t, func() {
p.Process(testIP)
ctx := testutil.ContextWithTimeout(t, testTimeout)
p.Process(ctx, testIP)
})
assert.NotPanics(t, func() {
@@ -120,7 +121,8 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
})
testutil.CleanupAndRequireSuccess(t, p.Close)
p.Process(tc.ip)
ctx := testutil.ContextWithTimeout(t, testTimeout)
p.Process(ctx, tc.ip)
if !tc.wantUpd {
return
@@ -146,8 +148,8 @@ func newOnUpdateAddress(
ips chan<- netip.Addr,
hosts chan<- string,
infos chan<- *whois.Info,
) (f func(ip netip.Addr, host string, info *whois.Info)) {
return func(ip netip.Addr, host string, info *whois.Info) {
) (f func(ctx context.Context, ip netip.Addr, host string, info *whois.Info)) {
return func(ctx context.Context, ip netip.Addr, host string, info *whois.Info) {
if !want && (host != "" || info != nil) {
panic(fmt.Errorf("got unexpected update for %v with %q and %v", ip, host, info))
}
@@ -230,7 +232,8 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
})
testutil.CleanupAndRequireSuccess(t, p.Close)
p.Process(testIP)
ctx := testutil.ContextWithTimeout(t, testTimeout)
p.Process(ctx, testIP)
if !tc.wantUpd {
return
@@ -251,7 +254,9 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
func TestDefaultAddrProc_Close(t *testing.T) {
t.Parallel()
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{})
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{
BaseLogger: slogutil.NewDiscardLogger(),
})
err := p.Close()
assert.NoError(t, err)

View File

@@ -1,20 +1,20 @@
package client
import (
"context"
"encoding"
"fmt"
"log/slog"
"net"
"net/netip"
"slices"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/google/uuid"
)
@@ -136,7 +136,7 @@ type Persistent struct {
// validate returns an error if persistent client information contains errors.
// allTags must be sorted.
func (c *Persistent) validate(allTags []string) (err error) {
func (c *Persistent) validate(ctx context.Context, l *slog.Logger, allTags []string) (err error) {
switch {
case c.Name == "":
return errors.Error("empty name")
@@ -153,7 +153,7 @@ func (c *Persistent) validate(allTags []string) (err error) {
err = conf.Close()
if err != nil {
log.Error("client: closing upstream config: %s", err)
l.ErrorContext(ctx, "client: closing upstream config", slogutil.KeyError, err)
}
for _, t := range c.Tags {
@@ -323,20 +323,3 @@ func (c *Persistent) CloseUpstreams() (err error) {
return nil
}
// SetSafeSearch initializes and sets the safe search filter for this client.
func (c *Persistent) SetSafeSearch(
conf filtering.SafeSearchConfig,
cacheSize uint,
cacheTTL time.Duration,
) (err error) {
ss, err := safesearch.NewDefault(conf, fmt.Sprintf("client %q", c.Name), cacheSize, cacheTTL)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
c.SafeSearch = ss
return nil
}

View File

@@ -3,6 +3,7 @@ package client
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"slices"
@@ -15,6 +16,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// allowedTags is the list of available client tags.
@@ -83,6 +85,10 @@ type HostsContainer interface {
// StorageConfig is the client storage configuration structure.
type StorageConfig struct {
// Logger is used for logging the operation of the client storage. It must
// not be nil.
Logger *slog.Logger
// DHCP is used to match IPs against MACs of persistent clients and update
// [SourceDHCP] runtime client information. It must not be nil.
DHCP DHCP
@@ -108,6 +114,10 @@ type StorageConfig struct {
// Storage contains information about persistent and runtime clients.
type Storage struct {
// logger is used for logging the operation of the client storage. It must
// not be nil.
logger *slog.Logger
// mu protects indexes of persistent and runtime clients.
mu *sync.Mutex
@@ -145,12 +155,12 @@ type Storage struct {
}
// NewStorage returns initialized client storage. conf must not be nil.
func NewStorage(conf *StorageConfig) (s *Storage, err error) {
func NewStorage(ctx context.Context, conf *StorageConfig) (s *Storage, err error) {
tags := slices.Clone(allowedTags)
slices.Sort(tags)
s = &Storage{
allowedTags: tags,
logger: conf.Logger,
mu: &sync.Mutex{},
index: newIndex(),
runtimeIndex: newRuntimeIndex(),
@@ -158,18 +168,19 @@ func NewStorage(conf *StorageConfig) (s *Storage, err error) {
etcHosts: conf.EtcHosts,
arpDB: conf.ARPDB,
done: make(chan struct{}),
allowedTags: tags,
arpClientsUpdatePeriod: conf.ARPClientsUpdatePeriod,
runtimeSourceDHCP: conf.RuntimeSourceDHCP,
}
for i, p := range conf.InitialClients {
err = s.Add(p)
err = s.Add(ctx, p)
if err != nil {
return nil, fmt.Errorf("adding client %q at index %d: %w", p.Name, i, err)
}
}
s.ReloadARP()
s.ReloadARP(ctx)
return s, nil
}
@@ -177,9 +188,9 @@ func NewStorage(conf *StorageConfig) (s *Storage, err error) {
// Start starts the goroutines for updating the runtime client information.
//
// TODO(s.chzhen): Pass context.
func (s *Storage) Start(_ context.Context) (err error) {
go s.periodicARPUpdate()
go s.handleHostsUpdates()
func (s *Storage) Start(ctx context.Context) (err error) {
go s.periodicARPUpdate(ctx)
go s.handleHostsUpdates(ctx)
return nil
}
@@ -195,15 +206,15 @@ func (s *Storage) Shutdown(_ context.Context) (err error) {
// periodicARPUpdate periodically reloads runtime clients from ARP. It is
// intended to be used as a goroutine.
func (s *Storage) periodicARPUpdate() {
defer log.OnPanic("storage")
func (s *Storage) periodicARPUpdate(ctx context.Context) {
defer slogutil.RecoverAndLog(ctx, s.logger)
t := time.NewTicker(s.arpClientsUpdatePeriod)
for {
select {
case <-t.C:
s.ReloadARP()
s.ReloadARP(ctx)
case <-s.done:
return
}
@@ -211,28 +222,28 @@ func (s *Storage) periodicARPUpdate() {
}
// ReloadARP reloads runtime clients from ARP, if configured.
func (s *Storage) ReloadARP() {
func (s *Storage) ReloadARP(ctx context.Context) {
if s.arpDB != nil {
s.addFromSystemARP()
s.addFromSystemARP(ctx)
}
}
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
// command.
func (s *Storage) addFromSystemARP() {
func (s *Storage) addFromSystemARP(ctx context.Context) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.arpDB.Refresh(); err != nil {
s.arpDB = arpdb.Empty{}
log.Error("refreshing arp container: %s", err)
s.logger.ErrorContext(ctx, "refreshing arp container", slogutil.KeyError, err)
return
}
ns := s.arpDB.Neighbors()
if len(ns) == 0 {
log.Debug("refreshing arp container: the update is empty")
s.logger.DebugContext(ctx, "refreshing arp container: the update is empty")
return
}
@@ -246,17 +257,22 @@ func (s *Storage) addFromSystemARP() {
removed := s.runtimeIndex.removeEmpty()
log.Debug("storage: added %d, removed %d client aliases from arp neighborhood", len(ns), removed)
s.logger.DebugContext(
ctx,
"updating client aliases from arp neighborhood",
"added", len(ns),
"removed", removed,
)
}
// handleHostsUpdates receives the updates from the hosts container and adds
// them to the clients storage. It is intended to be used as a goroutine.
func (s *Storage) handleHostsUpdates() {
func (s *Storage) handleHostsUpdates(ctx context.Context) {
if s.etcHosts == nil {
return
}
defer log.OnPanic("storage")
defer slogutil.RecoverAndLog(ctx, s.logger)
for {
select {
@@ -265,7 +281,7 @@ func (s *Storage) handleHostsUpdates() {
return
}
s.addFromHostsFile(upd)
s.addFromHostsFile(ctx, upd)
case <-s.done:
return
}
@@ -274,7 +290,7 @@ func (s *Storage) handleHostsUpdates() {
// addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files.
func (s *Storage) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
func (s *Storage) addFromHostsFile(ctx context.Context, hosts *hostsfile.DefaultStorage) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -294,14 +310,19 @@ func (s *Storage) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
})
removed := s.runtimeIndex.removeEmpty()
log.Debug("storage: added %d, removed %d client aliases from system hosts file", added, removed)
s.logger.DebugContext(
ctx,
"updating client aliases from system hosts file",
"added", added,
"removed", removed,
)
}
// type check
var _ AddressUpdater = (*Storage)(nil)
// UpdateAddress implements the [AddressUpdater] interface for *Storage
func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
func (s *Storage) UpdateAddress(ctx context.Context, ip netip.Addr, host string, info *whois.Info) {
// Common fast path optimization.
if host == "" && info == nil {
return
@@ -315,12 +336,12 @@ func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
}
if info != nil {
s.setWHOISInfo(ip, info)
s.setWHOISInfo(ctx, ip, info)
}
}
// UpdateDHCP updates [SourceDHCP] runtime client information.
func (s *Storage) UpdateDHCP() {
func (s *Storage) UpdateDHCP(ctx context.Context) {
if s.dhcp == nil || !s.runtimeSourceDHCP {
return
}
@@ -338,14 +359,23 @@ func (s *Storage) UpdateDHCP() {
}
removed := s.runtimeIndex.removeEmpty()
log.Debug("storage: added %d, removed %d client aliases from dhcp", added, removed)
s.logger.DebugContext(
ctx,
"updating client aliases from dhcp",
"added", added,
"removed", removed,
)
}
// setWHOISInfo sets the WHOIS information for a runtime client.
func (s *Storage) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
func (s *Storage) setWHOISInfo(ctx context.Context, ip netip.Addr, wi *whois.Info) {
_, ok := s.index.findByIP(ip)
if ok {
log.Debug("storage: client for %s is already created, ignore whois info", ip)
s.logger.DebugContext(
ctx,
"persistent client is already created, ignore whois info",
"ip", ip,
)
return
}
@@ -358,14 +388,14 @@ func (s *Storage) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
rc.setWHOIS(wi)
log.Debug("storage: set whois info for runtime client with ip %s: %+v", ip, wi)
s.logger.DebugContext(ctx, "set whois info for runtime client", "ip", ip, "whois", wi)
}
// Add stores persistent client information or returns an error.
func (s *Storage) Add(p *Persistent) (err error) {
func (s *Storage) Add(ctx context.Context, p *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "adding client: %w") }()
err = p.validate(s.allowedTags)
err = p.validate(ctx, s.logger, s.allowedTags)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
@@ -388,7 +418,13 @@ func (s *Storage) Add(p *Persistent) (err error) {
s.index.add(p)
log.Debug("client storage: added %q: IDs: %q [%d]", p.Name, p.IDs(), s.index.size())
s.logger.DebugContext(
ctx,
"client added",
"name", p.Name,
"ids", p.IDs(),
"clients_count", s.index.size(),
)
return nil
}
@@ -490,10 +526,10 @@ func (s *Storage) RemoveByName(name string) (ok bool) {
// Update finds the stored persistent client by its name and updates its
// information from p.
func (s *Storage) Update(name string, p *Persistent) (err error) {
func (s *Storage) Update(ctx context.Context, name string, p *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "updating client: %w") }()
err = p.validate(s.allowedTags)
err = p.validate(ctx, s.logger, s.allowedTags)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err

View File

@@ -15,11 +15,25 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestStorage is a helper function that returns initialized storage.
func newTestStorage(tb testing.TB) (s *client.Storage) {
tb.Helper()
ctx := testutil.ContextWithTimeout(tb, testTimeout)
s, err := client.NewStorage(ctx, &client.StorageConfig{
Logger: slogutil.NewDiscardLogger(),
})
require.NoError(tb, err)
return s
}
// testHostsContainer is a mock implementation of the [client.HostsContainer]
// interface.
type testHostsContainer struct {
@@ -110,7 +124,9 @@ func TestStorage_Add_hostsfile(t *testing.T) {
onUpd: func() (updates <-chan *hostsfile.DefaultStorage) { return hostCh },
}
storage, err := client.NewStorage(&client.StorageConfig{
ctx := testutil.ContextWithTimeout(t, testTimeout)
storage, err := client.NewStorage(ctx, &client.StorageConfig{
Logger: slogutil.NewDiscardLogger(),
DHCP: client.EmptyDHCP{},
EtcHosts: h,
ARPClientsUpdatePeriod: testTimeout / 10,
@@ -198,7 +214,9 @@ func TestStorage_Add_arp(t *testing.T) {
},
}
storage, err := client.NewStorage(&client.StorageConfig{
ctx := testutil.ContextWithTimeout(t, testTimeout)
storage, err := client.NewStorage(ctx, &client.StorageConfig{
Logger: slogutil.NewDiscardLogger(),
DHCP: client.EmptyDHCP{},
ARPDB: a,
ARPClientsUpdatePeriod: testTimeout / 10,
@@ -273,8 +291,10 @@ func TestStorage_Add_whois(t *testing.T) {
cliName3 = "client_three"
)
storage, err := client.NewStorage(&client.StorageConfig{
DHCP: client.EmptyDHCP{},
ctx := testutil.ContextWithTimeout(t, testTimeout)
storage, err := client.NewStorage(ctx, &client.StorageConfig{
Logger: slogutil.NewDiscardLogger(),
DHCP: client.EmptyDHCP{},
})
require.NoError(t, err)
@@ -284,7 +304,7 @@ func TestStorage_Add_whois(t *testing.T) {
}
t.Run("new_client", func(t *testing.T) {
storage.UpdateAddress(cliIP1, "", whois)
storage.UpdateAddress(ctx, cliIP1, "", whois)
cli1 := storage.ClientRuntime(cliIP1)
require.NotNil(t, cli1)
@@ -292,8 +312,8 @@ func TestStorage_Add_whois(t *testing.T) {
})
t.Run("existing_runtime_client", func(t *testing.T) {
storage.UpdateAddress(cliIP2, cliName2, nil)
storage.UpdateAddress(cliIP2, "", whois)
storage.UpdateAddress(ctx, cliIP2, cliName2, nil)
storage.UpdateAddress(ctx, cliIP2, "", whois)
cli2 := storage.ClientRuntime(cliIP2)
require.NotNil(t, cli2)
@@ -304,14 +324,14 @@ func TestStorage_Add_whois(t *testing.T) {
})
t.Run("can't_set_persistent_client", func(t *testing.T) {
err = storage.Add(&client.Persistent{
err = storage.Add(ctx, &client.Persistent{
Name: cliName3,
UID: client.MustNewUID(),
IPs: []netip.Addr{cliIP3},
})
require.NoError(t, err)
storage.UpdateAddress(cliIP3, "", whois)
storage.UpdateAddress(ctx, cliIP3, "", whois)
rc := storage.ClientRuntime(cliIP3)
require.Nil(t, rc)
})
@@ -364,7 +384,9 @@ func TestClientsDHCP(t *testing.T) {
},
}
storage, err := client.NewStorage(&client.StorageConfig{
ctx := testutil.ContextWithTimeout(t, testTimeout)
storage, err := client.NewStorage(ctx, &client.StorageConfig{
Logger: slogutil.NewDiscardLogger(),
DHCP: d,
RuntimeSourceDHCP: true,
})
@@ -378,7 +400,7 @@ func TestClientsDHCP(t *testing.T) {
})
t.Run("find_persistent", func(t *testing.T) {
err = storage.Add(&client.Persistent{
err = storage.Add(ctx, &client.Persistent{
Name: prsCliName,
UID: client.MustNewUID(),
MACs: []net.HardwareAddr{prsCliMAC},
@@ -393,7 +415,7 @@ func TestClientsDHCP(t *testing.T) {
t.Run("leases", func(t *testing.T) {
delete(ipToHost, cliIP1)
storage.UpdateDHCP()
storage.UpdateDHCP(ctx)
cli1 := storage.ClientRuntime(cliIP1)
require.Nil(t, cli1)
@@ -421,16 +443,19 @@ func TestClientsDHCP(t *testing.T) {
}
func TestClientsAddExisting(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout)
t.Run("simple", func(t *testing.T) {
storage, err := client.NewStorage(&client.StorageConfig{
DHCP: client.EmptyDHCP{},
storage, err := client.NewStorage(ctx, &client.StorageConfig{
Logger: slogutil.NewDiscardLogger(),
DHCP: client.EmptyDHCP{},
})
require.NoError(t, err)
ip := netip.MustParseAddr("1.1.1.1")
// Add a client.
err = storage.Add(&client.Persistent{
err = storage.Add(ctx, &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
@@ -440,7 +465,7 @@ func TestClientsAddExisting(t *testing.T) {
require.NoError(t, err)
// Now add an auto-client with the same IP.
storage.UpdateAddress(ip, "test", nil)
storage.UpdateAddress(ctx, ip, "test", nil)
rc := storage.ClientRuntime(ip)
assert.True(t, compareRuntimeInfo(rc, client.SourceRDNS, "test"))
})
@@ -468,8 +493,9 @@ func TestClientsAddExisting(t *testing.T) {
dhcpServer, err := dhcpd.Create(config)
require.NoError(t, err)
storage, err := client.NewStorage(&client.StorageConfig{
DHCP: dhcpServer,
storage, err := client.NewStorage(ctx, &client.StorageConfig{
Logger: slogutil.NewDiscardLogger(),
DHCP: dhcpServer,
})
require.NoError(t, err)
@@ -484,7 +510,7 @@ func TestClientsAddExisting(t *testing.T) {
require.NoError(t, err)
// Add a new client with the same IP as for a client with MAC.
err = storage.Add(&client.Persistent{
err = storage.Add(ctx, &client.Persistent{
Name: "client2",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip},
@@ -492,7 +518,7 @@ func TestClientsAddExisting(t *testing.T) {
require.NoError(t, err)
// Add a new client with the IP from the first client's IP range.
err = storage.Add(&client.Persistent{
err = storage.Add(ctx, &client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
@@ -506,14 +532,16 @@ func TestClientsAddExisting(t *testing.T) {
func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
tb.Helper()
s, err := client.NewStorage(&client.StorageConfig{
DHCP: client.EmptyDHCP{},
ctx := testutil.ContextWithTimeout(tb, testTimeout)
s, err := client.NewStorage(ctx, &client.StorageConfig{
Logger: slogutil.NewDiscardLogger(),
DHCP: client.EmptyDHCP{},
})
require.NoError(tb, err)
for _, c := range m {
c.UID = client.MustNewUID()
require.NoError(tb, s.Add(c))
require.NoError(tb, s.Add(ctx, c))
}
require.Equal(tb, len(m), s.Size())
@@ -555,9 +583,8 @@ func TestStorage_Add(t *testing.T) {
UID: existingClientUID,
}
s, err := client.NewStorage(&client.StorageConfig{})
require.NoError(t, err)
ctx := testutil.ContextWithTimeout(t, testTimeout)
s := newTestStorage(t)
tags := s.AllowedTags()
require.NotZero(t, len(tags))
require.True(t, slices.IsSorted(tags))
@@ -568,7 +595,7 @@ func TestStorage_Add(t *testing.T) {
_, ok = slices.BinarySearch(tags, notAllowedTag)
require.False(t, ok)
err = s.Add(existingClient)
err := s.Add(ctx, existingClient)
require.NoError(t, err)
testCases := []struct {
@@ -669,7 +696,7 @@ func TestStorage_Add(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = s.Add(tc.cli)
err = s.Add(ctx, tc.cli)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
@@ -687,10 +714,9 @@ func TestStorage_RemoveByName(t *testing.T) {
UID: client.MustNewUID(),
}
s, err := client.NewStorage(&client.StorageConfig{})
require.NoError(t, err)
err = s.Add(existingClient)
ctx := testutil.ContextWithTimeout(t, testTimeout)
s := newTestStorage(t)
err := s.Add(ctx, existingClient)
require.NoError(t, err)
testCases := []struct {
@@ -714,10 +740,8 @@ func TestStorage_RemoveByName(t *testing.T) {
}
t.Run("duplicate_remove", func(t *testing.T) {
s, err = client.NewStorage(&client.StorageConfig{})
require.NoError(t, err)
err = s.Add(existingClient)
s = newTestStorage(t)
err = s.Add(ctx, existingClient)
require.NoError(t, err)
assert.True(t, s.RemoveByName(existingName))
@@ -1080,6 +1104,7 @@ func TestStorage_Update(t *testing.T) {
`uses the same ClientID "obstructing_client_id"`,
}}
ctx := testutil.ContextWithTimeout(t, testTimeout)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := newStorage(
@@ -1090,7 +1115,7 @@ func TestStorage_Update(t *testing.T) {
},
)
err := s.Update(clientName, tc.cli)
err := s.Update(ctx, clientName, tc.cli)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}

View File

@@ -3,11 +3,12 @@ package dhcpsvc
import (
"fmt"
"log/slog"
"maps"
"os"
"slices"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mapsutil"
"github.com/AdguardTeam/golibs/netutil"
)
@@ -78,14 +79,13 @@ func (conf *Config) Validate() (err error) {
return errors.Join(errs...)
}
mapsutil.SortedRange(conf.Interfaces, func(iface string, ic *InterfaceConfig) (ok bool) {
for _, iface := range slices.Sorted(maps.Keys(conf.Interfaces)) {
ic := conf.Interfaces[iface]
err = ic.validate()
if err != nil {
errs = append(errs, fmt.Errorf("interface %q: %w", iface, err))
}
return true
})
}
return errors.Join(errs...)
}

View File

@@ -4,14 +4,15 @@ import (
"context"
"fmt"
"log/slog"
"maps"
"net"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mapsutil"
)
// DHCPServer is a DHCP server for both IPv4 and IPv6 address families.
@@ -107,7 +108,8 @@ func newInterfaces(
v6 = make(dhcpInterfacesV6, 0, len(ifaces))
var errs []error
mapsutil.SortedRange(ifaces, func(name string, iface *InterfaceConfig) (cont bool) {
for _, name := range slices.Sorted(maps.Keys(ifaces)) {
iface := ifaces[name]
var i4 *dhcpInterfaceV4
i4, err = newDHCPInterfaceV4(ctx, l, name, iface.IPv4)
if err != nil {
@@ -120,9 +122,8 @@ func newInterfaces(
if i6 != nil {
v6 = append(v6, i6)
}
}
return true
})
if err = errors.Join(errs...); err != nil {
return nil, nil, err
}

View File

@@ -513,12 +513,14 @@ func TestSafeSearch(t *testing.T) {
SafeSearchCacheSize: 1000,
CacheTime: 30,
}
safeSearch, err := safesearch.NewDefault(
safeSearchConf,
"",
filterConf.SafeSearchCacheSize,
time.Minute*time.Duration(filterConf.CacheTime),
)
ctx := testutil.ContextWithTimeout(t, testTimeout)
safeSearch, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
ServicesConfig: safeSearchConf,
CacheSize: filterConf.SafeSearchCacheSize,
CacheTTL: time.Minute * time.Duration(filterConf.CacheTime),
})
require.NoError(t, err)
filterConf.SafeSearch = safeSearch
@@ -583,13 +585,15 @@ func TestSafeSearch(t *testing.T) {
t.Run(tc.host, func(t *testing.T) {
req := createTestMessage(tc.host)
// TODO(a.garipov): Create our own helper for this.
var reply *dns.Msg
require.Eventually(t, func() (ok bool) {
reply, _, err = client.Exchange(req, addr)
return err == nil
once := &sync.Once{}
require.EventuallyWithT(t, func(c *assert.CollectT) {
r, _, errExch := client.Exchange(req, addr)
if assert.NoError(c, errExch) {
once.Do(func() { reply = r })
}
}, testTimeout*10, testTimeout)
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
if tc.wantCNAME != "" {
require.Len(t, reply.Answer, 2)

View File

@@ -2,6 +2,7 @@ package dnsforward
import (
"cmp"
"context"
"encoding/binary"
"net"
"net/netip"
@@ -203,7 +204,8 @@ func (s *Server) processClientIP(addr netip.Addr) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
s.addrProc.Process(addr)
// TODO(s.chzhen): Pass context.
s.addrProc.Process(context.TODO(), addr)
}
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB

View File

@@ -2,6 +2,7 @@ package dnsforward
import (
"cmp"
"context"
"net"
"net/netip"
"testing"
@@ -90,7 +91,7 @@ func TestServer_ProcessInitial(t *testing.T) {
var gotAddr netip.Addr
s.addrProc = &aghtest.AddressProcessor{
OnProcess: func(ip netip.Addr) { gotAddr = ip },
OnProcess: func(ctx context.Context, ip netip.Addr) { gotAddr = ip },
OnClose: func() (err error) { panic("not implemented") },
}

View File

@@ -1057,7 +1057,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
}
}
err = os.MkdirAll(filepath.Join(d.conf.DataDir, filterDir), aghos.DefaultPermDir)
err = aghos.MkdirAll(filepath.Join(d.conf.DataDir, filterDir), aghos.DefaultPermDir)
if err != nil {
d.Close()

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -25,7 +26,7 @@ func (d *DNSFilter) validateFilterURL(urlStr string) (err error) {
if filepath.IsAbs(urlStr) {
urlStr = filepath.Clean(urlStr)
_, err = os.Stat(urlStr)
_, err = aghos.Stat(urlStr)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err

View File

@@ -1,15 +1,17 @@
package filtering
import "context"
// SafeSearch interface describes a service for search engines hosts rewrites.
type SafeSearch interface {
// CheckHost checks host with safe search filter. CheckHost must be safe
// for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA].
CheckHost(host string, qtype uint16) (res Result, err error)
CheckHost(ctx context.Context, host string, qtype uint16) (res Result, err error)
// Update updates the configuration of the safe search filter. Update must
// be safe for concurrent use. An implementation of Update may ignore some
// fields, but it must document which.
Update(conf SafeSearchConfig) (err error)
Update(ctx context.Context, conf SafeSearchConfig) (err error)
}
// SafeSearchConfig is a struct with safe search related settings.
@@ -40,10 +42,13 @@ func (d *DNSFilter) checkSafeSearch(
return Result{}, nil
}
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
clientSafeSearch := setts.ClientSafeSearch
if clientSafeSearch != nil {
return clientSafeSearch.CheckHost(host, qtype)
return clientSafeSearch.CheckHost(ctx, host, qtype)
}
return d.safeSearch.CheckHost(host, qtype)
return d.safeSearch.CheckHost(ctx, host, qtype)
}

View File

@@ -3,9 +3,11 @@ package safesearch
import (
"bytes"
"context"
"encoding/binary"
"encoding/gob"
"fmt"
"log/slog"
"net/netip"
"strings"
"sync"
@@ -14,13 +16,20 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/c2h5oh/datasize"
"github.com/miekg/dns"
)
// Attribute keys and values for logging.
const (
LogPrefix = "safesearch"
LogKeyClient = "client"
)
// Service is a enum with service names used as search providers.
type Service string
@@ -57,9 +66,32 @@ func isServiceProtected(s filtering.SafeSearchConfig, service Service) (ok bool)
}
}
// DefaultConfig is the configuration structure for [Default].
type DefaultConfig struct {
// Logger is used for logging the operation of the safe search filter.
Logger *slog.Logger
// ClientName is the name of the persistent client associated with the safe
// search filter, if there is one.
ClientName string
// CacheSize is the size of the filter results cache.
CacheSize uint
// CacheTTL is the Time to Live duration for cached items.
CacheTTL time.Duration
// ServicesConfig contains safe search settings for services. It must not
// be nil.
ServicesConfig filtering.SafeSearchConfig
}
// Default is the default safe search filter that uses filtering rules with the
// dnsrewrite modifier.
type Default struct {
// logger is used for logging the operation of the safe search filter.
logger *slog.Logger
// mu protects engine.
mu *sync.RWMutex
@@ -67,33 +99,28 @@ type Default struct {
// engine may be nil, which means that this safe search filter is disabled.
engine *urlfilter.DNSEngine
cache cache.Cache
logPrefix string
cacheTTL time.Duration
// cache stores safe search filtering results.
cache cache.Cache
// cacheTTL is the Time to Live duration for cached items.
cacheTTL time.Duration
}
// NewDefault returns an initialized default safe search filter. name is used
// for logging.
func NewDefault(
conf filtering.SafeSearchConfig,
name string,
cacheSize uint,
cacheTTL time.Duration,
) (ss *Default, err error) {
// NewDefault returns an initialized default safe search filter. ctx is used
// to log the initial refresh.
func NewDefault(ctx context.Context, conf *DefaultConfig) (ss *Default, err error) {
ss = &Default{
mu: &sync.RWMutex{},
logger: conf.Logger,
mu: &sync.RWMutex{},
cache: cache.New(cache.Config{
EnableLRU: true,
MaxSize: cacheSize,
MaxSize: conf.CacheSize,
}),
// Use %s, because the client safe-search names already contain double
// quotes.
logPrefix: fmt.Sprintf("safesearch %s: ", name),
cacheTTL: cacheTTL,
cacheTTL: conf.CacheTTL,
}
err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf)
// TODO(s.chzhen): Move to [Default.InitialRefresh].
err = ss.resetEngine(ctx, rulelist.URLFilterIDSafeSearch, conf.ServicesConfig)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
@@ -102,29 +129,15 @@ func NewDefault(
return ss, nil
}
// log is a helper for logging that includes the name of the safe search
// filter. level must be one of [log.DEBUG], [log.INFO], and [log.ERROR].
func (ss *Default) log(level log.Level, msg string, args ...any) {
switch level {
case log.DEBUG:
log.Debug(ss.logPrefix+msg, args...)
case log.INFO:
log.Info(ss.logPrefix+msg, args...)
case log.ERROR:
log.Error(ss.logPrefix+msg, args...)
default:
panic(fmt.Errorf("safesearch: unsupported logging level %d", level))
}
}
// resetEngine creates new engine for provided safe search configuration and
// sets it in ss.
func (ss *Default) resetEngine(
ctx context.Context,
listID int,
conf filtering.SafeSearchConfig,
) (err error) {
if !conf.Enabled {
ss.log(log.INFO, "disabled")
ss.logger.DebugContext(ctx, "disabled")
return nil
}
@@ -149,7 +162,7 @@ func (ss *Default) resetEngine(
ss.engine = urlfilter.NewDNSEngine(rs)
ss.log(log.INFO, "reset %d rules", ss.engine.RulesCount)
ss.logger.InfoContext(ctx, "reset rules", "count", ss.engine.RulesCount)
return nil
}
@@ -158,10 +171,14 @@ func (ss *Default) resetEngine(
var _ filtering.SafeSearch = (*Default)(nil)
// CheckHost implements the [filtering.SafeSearch] interface for *Default.
func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Result, err error) {
func (ss *Default) CheckHost(
ctx context.Context,
host string,
qtype rules.RRType,
) (res filtering.Result, err error) {
start := time.Now()
defer func() {
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
ss.logger.DebugContext(ctx, "lookup finished", "host", host, "elapsed", time.Since(start))
}()
switch qtype {
@@ -172,9 +189,9 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
}
// Check cache. Return cached result if it was found
cachedValue, isFound := ss.getCachedResult(host, qtype)
cachedValue, isFound := ss.getCachedResult(ctx, host, qtype)
if isFound {
ss.log(log.DEBUG, "found in cache: %q", host)
ss.logger.DebugContext(ctx, "found in cache", "host", host)
return cachedValue, nil
}
@@ -186,7 +203,7 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
fltRes, err := ss.newResult(rewrite, qtype)
if err != nil {
ss.log(log.DEBUG, "looking up addresses for %q: %s", host, err)
ss.logger.ErrorContext(ctx, "looking up addresses", "host", host, slogutil.KeyError, err)
return filtering.Result{}, err
}
@@ -195,7 +212,7 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
// TODO(a.garipov): Consider switch back to resolving CNAME records IPs and
// saving results to cache.
ss.setCacheResult(host, qtype, res)
ss.setCacheResult(ctx, host, qtype, res)
return res, nil
}
@@ -255,7 +272,12 @@ func (ss *Default) newResult(
// setCacheResult stores data in cache for host. qtype is expected to be either
// [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {
func (ss *Default) setCacheResult(
ctx context.Context,
host string,
qtype rules.RRType,
res filtering.Result,
) {
expire := uint32(time.Now().Add(ss.cacheTTL).Unix())
exp := make([]byte, 4)
binary.BigEndian.PutUint32(exp, expire)
@@ -263,7 +285,7 @@ func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering
err := gob.NewEncoder(buf).Encode(res)
if err != nil {
ss.log(log.ERROR, "cache encoding: %s", err)
ss.logger.ErrorContext(ctx, "cache encoding", slogutil.KeyError, err)
return
}
@@ -271,12 +293,18 @@ func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering
val := buf.Bytes()
_ = ss.cache.Set([]byte(dns.Type(qtype).String()+" "+host), val)
ss.log(log.DEBUG, "stored in cache: %q, %d bytes", host, len(val))
ss.logger.DebugContext(
ctx,
"stored in cache",
"host", host,
"entry_size", datasize.ByteSize(len(val)),
)
}
// getCachedResult returns stored data from cache for host. qtype is expected
// to be either [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) getCachedResult(
ctx context.Context,
host string,
qtype rules.RRType,
) (res filtering.Result, ok bool) {
@@ -298,7 +326,7 @@ func (ss *Default) getCachedResult(
err := gob.NewDecoder(buf).Decode(&res)
if err != nil {
ss.log(log.ERROR, "cache decoding: %s", err)
ss.logger.ErrorContext(ctx, "cache decoding", slogutil.KeyError, err)
return filtering.Result{}, false
}
@@ -308,11 +336,11 @@ func (ss *Default) getCachedResult(
// Update implements the [filtering.SafeSearch] interface for *Default. Update
// ignores the CustomResolver and Enabled fields.
func (ss *Default) Update(conf filtering.SafeSearchConfig) (err error) {
func (ss *Default) Update(ctx context.Context, conf filtering.SafeSearchConfig) (err error) {
ss.mu.Lock()
defer ss.mu.Unlock()
err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf)
err = ss.resetEngine(ctx, rulelist.URLFilterIDSafeSearch, conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err

View File

@@ -6,6 +6,8 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -21,6 +23,9 @@ const (
testCacheTTL = 30 * time.Minute
)
// testTimeout is the common timeout for tests and contexts.
const testTimeout = 1 * time.Second
var defaultSafeSearchConf = filtering.SafeSearchConfig{
Enabled: true,
Bing: true,
@@ -35,7 +40,12 @@ var defaultSafeSearchConf = filtering.SafeSearchConfig{
var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56})
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *Default) {
ss, err := NewDefault(ssConf, "", testCacheSize, testCacheTTL)
ss, err := NewDefault(testutil.ContextWithTimeout(t, testTimeout), &DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
ServicesConfig: ssConf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
})
require.NoError(t, err)
return ss
@@ -52,16 +62,17 @@ func TestSafeSearchCacheYandex(t *testing.T) {
const domain = "yandex.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Check host with disabled safesearch.
res, err := ss.CheckHost(domain, testQType)
res, err := ss.CheckHost(ctx, domain, testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
ss = newForTest(t, defaultSafeSearchConf)
res, err = ss.CheckHost(domain, testQType)
res, err = ss.CheckHost(ctx, domain, testQType)
require.NoError(t, err)
// For yandex we already know valid IP.
@@ -70,7 +81,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
assert.Equal(t, res.Rules[0].IP, yandexIP)
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
cachedValue, isFound := ss.getCachedResult(ctx, domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)

View File

@@ -10,15 +10,15 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests and contexts.
const testTimeout = 1 * time.Second
// Common test constants.
const (
@@ -47,7 +47,13 @@ var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56})
func TestDefault_CheckHost_yandex(t *testing.T) {
conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
ServicesConfig: conf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
})
require.NoError(t, err)
hosts := []string{
@@ -82,7 +88,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
for _, host := range hosts {
// Check host for each domain.
var res filtering.Result
res, err = ss.CheckHost(host, tc.qt)
res, err = ss.CheckHost(ctx, host, tc.qt)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -103,7 +109,13 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
}
func TestDefault_CheckHost_google(t *testing.T) {
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
ServicesConfig: testConf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
})
require.NoError(t, err)
// Check host for each domain.
@@ -118,7 +130,7 @@ func TestDefault_CheckHost_google(t *testing.T) {
} {
t.Run(host, func(t *testing.T) {
var res filtering.Result
res, err = ss.CheckHost(host, testQType)
res, err = ss.CheckHost(ctx, host, testQType)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -149,13 +161,19 @@ func (r *testResolver) LookupIP(
}
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
ServicesConfig: testConf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
})
require.NoError(t, err)
// The DuckDuckGo safe-search addresses are resolved through CNAMEs, but
// DuckDuckGo doesn't have a safe-search IPv6 address. The result should be
// the same as the one for Yandex IPv6. That is, a NODATA response.
res, err := ss.CheckHost("www.duckduckgo.com", dns.TypeAAAA)
res, err := ss.CheckHost(ctx, "www.duckduckgo.com", dns.TypeAAAA)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -166,32 +184,38 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
func TestDefault_Update(t *testing.T) {
conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
ServicesConfig: conf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
})
require.NoError(t, err)
res, err := ss.CheckHost("www.yandex.com", testQType)
res, err := ss.CheckHost(ctx, "www.yandex.com", testQType)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
err = ss.Update(filtering.SafeSearchConfig{
err = ss.Update(ctx, filtering.SafeSearchConfig{
Enabled: true,
Google: false,
})
require.NoError(t, err)
res, err = ss.CheckHost("www.yandex.com", testQType)
res, err = ss.CheckHost(ctx, "www.yandex.com", testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
err = ss.Update(filtering.SafeSearchConfig{
err = ss.Update(ctx, filtering.SafeSearchConfig{
Enabled: false,
Google: true,
})
require.NoError(t, err)
res, err = ss.CheckHost("www.yandex.com", testQType)
res, err = ss.CheckHost(ctx, "www.yandex.com", testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)

View File

@@ -51,7 +51,7 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ
}
conf := *req
err = d.safeSearch.Update(conf)
err = d.safeSearch.Update(r.Context(), conf)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err)

View File

@@ -2450,7 +2450,7 @@ var blockedServices = []blockedService{{
},
}, {
ID: "telegram",
Name: "Telegram",
Name: "Telegram (Web)",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M46.137,6.552c-0.75-0.636-1.928-0.727-3.146-0.238l-0.002,0C41.708,6.828,6.728,21.832,5.304,22.445 c-0.259,0.09-2.521,0.934-2.288,2.814c0.208,1.695,2.026,2.397,2.248,2.478l8.893,3.045c0.59,1.964,2.765,9.21,3.246,10.758 c0.3,0.965,0.789,2.233,1.646,2.494c0.752,0.29,1.5,0.025,1.984-0.355l5.437-5.043l8.777,6.845l0.209,0.125 c0.596,0.264,1.167,0.396,1.712,0.396c0.421,0,0.825-0.079,1.211-0.237c1.315-0.54,1.841-1.793,1.896-1.935l6.556-34.077 C47.231,7.933,46.675,7.007,46.137,6.552z M22,32l-3,8l-3-10l23-17L22,32z\" /></svg>"),
Rules: []string{
"||comments.app^",

View File

@@ -90,7 +90,11 @@ func InitAuth(
trustedProxies: trustedProxies,
}
var err error
a.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, nil)
opts := *bbolt.DefaultOptions
opts.OpenFile = aghos.OpenFile
a.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, &opts)
if err != nil {
log.Error("auth: open DB: %s: %s", dbFilename, err)
if err.Error() == "invalid argument" {

View File

@@ -3,6 +3,7 @@ package home
import (
"context"
"fmt"
"log/slog"
"net/netip"
"slices"
"sync"
@@ -13,17 +14,23 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/stringutil"
)
// clientsContainer is the storage of all runtime and persistent clients.
type clientsContainer struct {
// baseLogger is used to create loggers with custom prefixes for safe search
// filter. It must not be nil.
baseLogger *slog.Logger
// storage stores information about persistent clients.
storage *client.Storage
@@ -61,6 +68,8 @@ type BlockedClientChecker interface {
// dhcpServer: optional
// Note: this function must be called only once
func (clients *clientsContainer) Init(
ctx context.Context,
baseLogger *slog.Logger,
objects []*clientObject,
dhcpServer client.DHCP,
etcHosts *aghnet.HostsContainer,
@@ -72,13 +81,14 @@ func (clients *clientsContainer) Init(
return errors.Error("clients container already initialized")
}
clients.baseLogger = baseLogger
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
confClients := make([]*client.Persistent, 0, len(objects))
for i, o := range objects {
var p *client.Persistent
p, err = o.toPersistent(clients.safeSearchCacheSize, clients.safeSearchCacheTTL)
p, err = o.toPersistent(ctx, baseLogger, clients.safeSearchCacheSize, clients.safeSearchCacheTTL)
if err != nil {
return fmt.Errorf("init persistent client at index %d: %w", i, err)
}
@@ -92,12 +102,13 @@ func (clients *clientsContainer) Init(
// TODO(e.burkov): The option should probably be returned, since hosts file
// currently used not only for clients' information enrichment, but also in
// the filtering module and upstream addresses resolution.
var hosts client.HostsContainer = etcHosts
if !config.Clients.Sources.HostsFile {
hosts = nil
var hosts client.HostsContainer
if config.Clients.Sources.HostsFile && etcHosts != nil {
hosts = etcHosts
}
clients.storage, err = client.NewStorage(&client.StorageConfig{
clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{
Logger: baseLogger.With(slogutil.KeyPrefix, "client_storage"),
InitialClients: confClients,
DHCP: dhcpServer,
EtcHosts: hosts,
@@ -168,6 +179,8 @@ type clientObject struct {
// toPersistent returns an initialized persistent client if there are no errors.
func (o *clientObject) toPersistent(
ctx context.Context,
baseLogger *slog.Logger,
safeSearchCacheSize uint,
safeSearchCacheTTL time.Duration,
) (cli *client.Persistent, err error) {
@@ -203,14 +216,23 @@ func (o *clientObject) toPersistent(
}
if o.SafeSearchConf.Enabled {
err = cli.SetSafeSearch(
o.SafeSearchConf,
safeSearchCacheSize,
safeSearchCacheTTL,
logger := baseLogger.With(
slogutil.KeyPrefix, safesearch.LogPrefix,
safesearch.LogKeyClient, cli.Name,
)
var ss *safesearch.Default
ss, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: logger,
ServicesConfig: o.SafeSearchConf,
ClientName: cli.Name,
CacheSize: safeSearchCacheSize,
CacheTTL: safeSearchCacheTTL,
})
if err != nil {
return nil, fmt.Errorf("init safesearch %q: %w", cli.Name, err)
}
cli.SafeSearch = ss
}
if o.BlockedServices == nil {
@@ -396,6 +418,12 @@ func (clients *clientsContainer) UpstreamConfigByID(
)
c.UpstreamConfig = conf
// TODO(s.chzhen): Pass context.
err = clients.storage.Update(context.TODO(), c.Name, c)
if err != nil {
return nil, fmt.Errorf("setting upstream config: %w", err)
}
return conf, nil
}
@@ -404,8 +432,13 @@ var _ client.AddressUpdater = (*clientsContainer)(nil)
// UpdateAddress implements the [client.AddressUpdater] interface for
// *clientsContainer
func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
clients.storage.UpdateAddress(ip, host, info)
func (clients *clientsContainer) UpdateAddress(
ctx context.Context,
ip netip.Addr,
host string,
info *whois.Info,
) {
clients.storage.UpdateAddress(ctx, ip, host, info)
}
// close gracefully closes all the client-specific upstream configurations of

View File

@@ -7,6 +7,8 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -20,16 +22,28 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
testing: true,
}
require.NoError(t, c.Init(nil, client.EmptyDHCP{}, nil, nil, &filtering.Config{}))
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := c.Init(
ctx,
slogutil.NewDiscardLogger(),
nil,
client.EmptyDHCP{},
nil,
nil,
&filtering.Config{},
)
require.NoError(t, err)
return c
}
func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t)
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Add client with upstreams.
err := clients.storage.Add(&client.Persistent{
err := clients.storage.Add(ctx, &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},

View File

@@ -1,6 +1,7 @@
package home
import (
"context"
"encoding/json"
"fmt"
"net/http"
@@ -10,8 +11,10 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// clientJSON is a common structure used by several handlers to deal with
@@ -103,7 +106,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
return true
})
clients.storage.UpdateDHCP()
clients.storage.UpdateDHCP(r.Context())
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
src, host := rc.Info()
@@ -181,6 +184,7 @@ func initPrev(cj clientJSON, prev *client.Persistent) (c *client.Persistent, err
// jsonToClient converts JSON object to persistent client object if there are no
// errors.
func (clients *clientsContainer) jsonToClient(
ctx context.Context,
cj clientJSON,
prev *client.Persistent,
) (c *client.Persistent, err error) {
@@ -207,14 +211,23 @@ func (clients *clientsContainer) jsonToClient(
c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices
if c.SafeSearchConf.Enabled {
err = c.SetSafeSearch(
c.SafeSearchConf,
clients.safeSearchCacheSize,
clients.safeSearchCacheTTL,
logger := clients.baseLogger.With(
slogutil.KeyPrefix, safesearch.LogPrefix,
safesearch.LogKeyClient, c.Name,
)
var ss *safesearch.Default
ss, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: logger,
ServicesConfig: c.SafeSearchConf,
ClientName: c.Name,
CacheSize: clients.safeSearchCacheSize,
CacheTTL: clients.safeSearchCacheTTL,
})
if err != nil {
return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err)
}
c.SafeSearch = ss
}
return c, nil
@@ -321,14 +334,14 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
return
}
c, err := clients.jsonToClient(cj, nil)
c, err := clients.jsonToClient(r.Context(), cj, nil)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
err = clients.storage.Add(c)
err = clients.storage.Add(r.Context(), c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -391,14 +404,14 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return
}
c, err := clients.jsonToClient(dj.Data, nil)
c, err := clients.jsonToClient(r.Context(), dj.Data, nil)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
err = clients.storage.Update(dj.Name, c)
err = clients.storage.Update(r.Context(), dj.Name, c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)

View File

@@ -11,14 +11,19 @@ import (
"net/url"
"slices"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testTimeout is the common timeout for tests and contexts.
const testTimeout = 1 * time.Second
const (
testClientIP1 = "1.1.1.1"
testClientIP2 = "2.2.2.2"
@@ -103,9 +108,10 @@ func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*c
require.NoError(tb, err)
var got []*client.Persistent
ctx := testutil.ContextWithTimeout(tb, testTimeout)
for _, cj := range clientList.Clients {
var c *client.Persistent
c, err = clients.jsonToClient(*cj, nil)
c, err = clients.jsonToClient(ctx, *cj, nil)
require.NoError(tb, err)
got = append(got, c)
@@ -125,10 +131,11 @@ func assertPersistentClientsData(
tb.Helper()
var got []*client.Persistent
ctx := testutil.ContextWithTimeout(tb, testTimeout)
for _, cm := range data {
for _, cj := range cm {
var c *client.Persistent
c, err := clients.jsonToClient(*cj, nil)
c, err := clients.jsonToClient(ctx, *cj, nil)
require.NoError(tb, err)
got = append(got, c)
@@ -196,13 +203,14 @@ func TestClientsContainer_HandleAddClient(t *testing.T) {
func TestClientsContainer_HandleDelClient(t *testing.T) {
clients := newClientsContainer(t)
ctx := testutil.ContextWithTimeout(t, testTimeout)
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.storage.Add(clientOne)
err := clients.storage.Add(ctx, clientOne)
require.NoError(t, err)
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
err = clients.storage.Add(clientTwo)
err = clients.storage.Add(ctx, clientTwo)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
@@ -258,9 +266,10 @@ func TestClientsContainer_HandleDelClient(t *testing.T) {
func TestClientsContainer_HandleUpdateClient(t *testing.T) {
clients := newClientsContainer(t)
ctx := testutil.ContextWithTimeout(t, testTimeout)
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.storage.Add(clientOne)
err := clients.storage.Add(ctx, clientOne)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne})
@@ -341,12 +350,14 @@ func TestClientsContainer_HandleFindClient(t *testing.T) {
},
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.storage.Add(clientOne)
err := clients.storage.Add(ctx, clientOne)
require.NoError(t, err)
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
err = clients.storage.Add(clientTwo)
err = clients.storage.Add(ctx, clientTwo)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})

View File

@@ -708,7 +708,7 @@ func (c *configuration) write() (err error) {
return fmt.Errorf("generating config file: %w", err)
}
err = maybe.WriteFile(confPath, buf.Bytes(), aghos.DefaultPermFile)
err = aghos.WriteFile(confPath, buf.Bytes(), aghos.DefaultPermFile)
if err != nil {
return fmt.Errorf("writing config file: %w", err)
}

View File

@@ -7,6 +7,8 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -18,12 +20,15 @@ var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) {
tb.Helper()
s, err := client.NewStorage(&client.StorageConfig{})
ctx := testutil.ContextWithTimeout(tb, testTimeout)
s, err := client.NewStorage(ctx, &client.StorageConfig{
Logger: slogutil.NewDiscardLogger(),
})
require.NoError(tb, err)
for _, p := range clients {
p.UID = client.MustNewUID()
require.NoError(tb, s.Add(p))
require.NoError(tb, s.Add(ctx, p))
}
return s

View File

@@ -115,15 +115,16 @@ func Main(clientBuildFS fs.FS) {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
ctx := context.Background()
for {
sig := <-signals
log.Info("Received signal %q", sig)
switch sig {
case syscall.SIGHUP:
Context.clients.storage.ReloadARP()
Context.clients.storage.ReloadARP(ctx)
Context.tls.reload()
default:
cleanup(context.Background())
cleanup(ctx)
cleanupAlways()
close(done)
}
@@ -148,6 +149,14 @@ func setupContext(opts options) (err error) {
Context.tlsRoots = aghtls.SystemRootCAs()
Context.mux = http.NewServeMux()
if !opts.noEtcHosts {
err = setupHostsContainer()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
}
if Context.firstRun {
log.Info("This is the first time AdGuard Home is launched")
checkPermissions()
@@ -168,14 +177,6 @@ func setupContext(opts options) (err error) {
os.Exit(0)
}
if !opts.noEtcHosts {
err = setupHostsContainer()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
}
return nil
}
@@ -278,8 +279,8 @@ func setupOpts(opts options) (err error) {
}
// initContextClients initializes Context clients and related fields.
func initContextClients(logger *slog.Logger) (err error) {
err = setupDNSFilteringConf(config.Filtering)
func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
err = setupDNSFilteringConf(ctx, logger, config.Filtering)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
@@ -306,6 +307,8 @@ func initContextClients(logger *slog.Logger) (err error) {
}
return Context.clients.Init(
ctx,
logger,
config.Clients.Persistent,
Context.dhcpServer,
Context.etcHosts,
@@ -355,7 +358,11 @@ func setupBindOpts(opts options) (err error) {
}
// setupDNSFilteringConf sets up DNS filtering configuration settings.
func setupDNSFilteringConf(conf *filtering.Config) (err error) {
func setupDNSFilteringConf(
ctx context.Context,
baseLogger *slog.Logger,
conf *filtering.Config,
) (err error) {
const (
dnsTimeout = 3 * time.Second
@@ -446,12 +453,13 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
conf.ParentalBlockHost = host
}
conf.SafeSearch, err = safesearch.NewDefault(
conf.SafeSearchConf,
"default",
conf.SafeSearchCacheSize,
cacheTime,
)
logger := baseLogger.With(slogutil.KeyPrefix, safesearch.LogPrefix)
conf.SafeSearch, err = safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: logger,
ServicesConfig: conf.SafeSearchConf,
CacheSize: conf.SafeSearchCacheSize,
CacheTTL: cacheTime,
})
if err != nil {
return fmt.Errorf("initializing safesearch: %w", err)
}
@@ -584,7 +592,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
// data first, but also to avoid relying on automatic Go init() function.
filtering.InitModule()
err = initContextClients(slogLogger)
// TODO(s.chzhen): Use it for the entire initialization process.
ctx := context.Background()
err = initContextClients(ctx, slogLogger)
fatalOnError(err)
err = setupOpts(opts)
@@ -632,7 +643,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
}
dataDir := Context.getDataDir()
err = os.MkdirAll(dataDir, aghos.DefaultPermDir)
err = aghos.MkdirAll(dataDir, aghos.DefaultPermDir)
fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dataDir))
GLMode = opts.glinetMode

View File

@@ -24,10 +24,15 @@ func newSlogLogger(ls *logSettings) (l *slog.Logger) {
return slogutil.NewDiscardLogger()
}
lvl := slog.LevelInfo
if ls.Verbose {
lvl = slog.LevelDebug
}
return slogutil.New(&slogutil.Config{
Format: slogutil.FormatAdGuardLegacy,
Level: lvl,
AddTimestamp: true,
Verbose: ls.Verbose,
})
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/osutil"
"github.com/google/renameio/v2/maybe"
)
// signalHandler processes incoming signals and shuts services down.
@@ -142,7 +141,7 @@ func (h *signalHandler) writePID() {
data = strconv.AppendInt(data, int64(os.Getpid()), 10)
data = append(data, '\n')
err := maybe.WriteFile(h.pidFile, data, 0o644)
err := aghos.WriteFile(h.pidFile, data, 0o644)
if err != nil {
log.Error("sighdlr: writing pidfile: %s", err)

View File

@@ -21,7 +21,6 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/v2/maybe"
"gopkg.in/yaml.v3"
)
@@ -183,7 +182,7 @@ func (m *Manager) write() (err error) {
return fmt.Errorf("encoding: %w", err)
}
err = maybe.WriteFile(m.fileName, b, aghos.DefaultPermFile)
err = aghos.WriteFile(m.fileName, b, aghos.DefaultPermFile)
if err != nil {
return fmt.Errorf("writing: %w", err)
}

View File

@@ -14,7 +14,7 @@ import (
//
// TODO(a.garipov): Consider ways to detect this better.
func NeedsMigration(confFilePath string) (ok bool) {
s, err := os.Stat(confFilePath)
s, err := aghos.Stat(confFilePath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
// Likely a first run. Don't check.
@@ -70,7 +70,7 @@ func chmodFile(filePath string) {
// chmodPath changes the permissions of a single filesystem entity. The results
// are logged at the appropriate level.
func chmodPath(entPath, fileType string, fm fs.FileMode) {
err := os.Chmod(entPath, fm)
err := aghos.Chmod(entPath, fm)
if err == nil {
log.Info("permcheck: changed permissions for %s %q", fileType, entPath)

View File

@@ -60,7 +60,7 @@ func checkFile(filePath string) {
// checkPath checks the permissions of a single filesystem entity. The results
// are logged at the appropriate level.
func checkPath(entPath, fileType string, want fs.FileMode) {
s, err := os.Stat(entPath)
s, err := aghos.Stat(entPath)
if err != nil {
logFunc := log.Error
if errors.Is(err, os.ErrNotExist) {

View File

@@ -57,6 +57,7 @@ type qLogFile struct {
// newQLogFile initializes a new instance of the qLogFile.
func newQLogFile(path string) (qf *qLogFile, err error) {
// Don't use [aghos.OpenFile] here, because the file is expected to exist.
f, err := os.OpenFile(path, os.O_RDONLY, aghos.DefaultPermFile)
if err != nil {
return nil, err

View File

@@ -71,7 +71,7 @@ func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) {
filename := l.logFile
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, aghos.DefaultPermFile)
f, err := aghos.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, aghos.DefaultPermFile)
if err != nil {
return fmt.Errorf("creating file %q: %w", filename, err)
}

View File

@@ -384,7 +384,13 @@ func (s *StatsCtx) openDB() (err error) {
s.logger.Debug("opening database")
var db *bbolt.DB
db, err = bbolt.Open(s.filename, aghos.DefaultPermFile, nil)
opts := *bbolt.DefaultOptions
// Use the custom OpenFile function to properly handle access rights on
// Windows.
opts.OpenFile = aghos.OpenFile
db, err = bbolt.Open(s.filename, aghos.DefaultPermFile, &opts)
if err != nil {
if err.Error() == "invalid argument" {
const lines = `AdGuard Home cannot be initialized due to an incompatible file system.
@@ -467,7 +473,7 @@ func (s *StatsCtx) flushDB(id, limit uint32, ptr *unit) (cont bool, sleepFor tim
if delErr != nil {
// TODO(e.burkov): Improve the algorithm of deleting the oldest bucket
// to avoid the error.
lvl := slog.LevelWarn
lvl := slog.LevelDebug
if !errors.Is(delErr, bbolt.ErrBucketNotFound) {
isCommitable = false
lvl = slog.LevelError

View File

@@ -6,11 +6,11 @@ require (
github.com/fzipp/gocyclo v0.6.0
github.com/golangci/misspell v0.6.0
github.com/gordonklaus/ineffassign v0.1.0
github.com/kisielk/errcheck v1.7.0
github.com/kisielk/errcheck v1.8.0
github.com/kyoh86/looppointer v0.2.1
github.com/securego/gosec/v2 v2.21.4
github.com/uudashr/gocognit v1.1.3
golang.org/x/tools v0.25.0
golang.org/x/tools v0.26.0
golang.org/x/vuln v1.1.3
honnef.co/go/tools v0.5.1
mvdan.cc/gofumpt v0.7.0
@@ -18,12 +18,12 @@ require (
)
require (
cloud.google.com/go v0.115.1 // indirect
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/ai v0.8.2 // indirect
cloud.google.com/go/auth v0.9.7 // indirect
cloud.google.com/go/auth v0.9.9 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
cloud.google.com/go/compute/metadata v0.5.2 // indirect
cloud.google.com/go/longrunning v0.6.1 // indirect
cloud.google.com/go/longrunning v0.6.2 // indirect
github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c // indirect
github.com/ccojocar/zxcvbn-go v1.0.2 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
@@ -40,26 +40,26 @@ require (
github.com/kyoh86/nolint v0.0.1 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.55.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 // indirect
go.opentelemetry.io/otel v1.30.0 // indirect
go.opentelemetry.io/otel/metric v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.30.0 // indirect
golang.org/x/crypto v0.27.0 // indirect
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect
golang.org/x/exp/typeparams v0.0.0-20240909161429-701f63a606c0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
go.opentelemetry.io/otel v1.31.0 // indirect
go.opentelemetry.io/otel/metric v1.31.0 // indirect
go.opentelemetry.io/otel/trace v1.31.0 // indirect
golang.org/x/crypto v0.28.0 // indirect
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c // indirect
golang.org/x/exp/typeparams v0.0.0-20241009180824-f66d83c29e7c // indirect
golang.org/x/mod v0.21.0 // indirect
golang.org/x/net v0.29.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/oauth2 v0.23.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/telemetry v0.0.0-20240927214544-e9e6960092dd // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.6.0 // indirect
google.golang.org/api v0.199.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240930140551-af27646dc61f // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240930140551-af27646dc61f // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/telemetry v0.0.0-20241028140143-9c0d19e65ba0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/time v0.7.0 // indirect
google.golang.org/api v0.203.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20241021214115-324edc3d5d38 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect
google.golang.org/grpc v1.67.1 // indirect
google.golang.org/protobuf v1.34.2 // indirect
google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -1,16 +1,16 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.115.1 h1:Jo0SM9cQnSkYfp44+v+NQXHpcHqlnRJk2qxh6yvxxxQ=
cloud.google.com/go v0.115.1/go.mod h1:DuujITeaufu3gL68/lOFIirVNJwQeyf5UXyi+Wbgknc=
cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
cloud.google.com/go/ai v0.8.2 h1:LEaQwqBv+k2ybrcdTtCTc9OPZXoEdcQaGrfvDYS6Bnk=
cloud.google.com/go/ai v0.8.2/go.mod h1:Wb3EUUGWwB6yHBaUf/+oxUq/6XbCaU1yh0GrwUS8lr4=
cloud.google.com/go/auth v0.9.7 h1:ha65jNwOfI48YmUzNfMaUDfqt5ykuYIUnSartpU1+BA=
cloud.google.com/go/auth v0.9.7/go.mod h1:Xo0n7n66eHyOWWCnitop6870Ilwo3PiZyodVkkH1xWM=
cloud.google.com/go/auth v0.9.9 h1:BmtbpNQozo8ZwW2t7QJjnrQtdganSdmqeIBxHxNkEZQ=
cloud.google.com/go/auth v0.9.9/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI=
cloud.google.com/go/auth/oauth2adapt v0.2.4 h1:0GWE/FUsXhf6C+jAkWgYm7X9tK8cuEIfy19DBn6B6bY=
cloud.google.com/go/auth/oauth2adapt v0.2.4/go.mod h1:jC/jOpwFP6JBxhB3P5Rr0a9HLMC/Pe3eaL4NmdvqPtc=
cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo=
cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k=
cloud.google.com/go/longrunning v0.6.1 h1:lOLTFxYpr8hcRtcwWir5ITh1PAKUD/sG2lKrTSYjyMc=
cloud.google.com/go/longrunning v0.6.1/go.mod h1:nHISoOZpBcmlwbJmiVk5oDRz0qG/ZxPynEGs1iZ79s0=
cloud.google.com/go/longrunning v0.6.2 h1:xjDfh1pQcWPEvnfjZmwjKQEcHnpz6lHjfy7Fo0MK+hc=
cloud.google.com/go/longrunning v0.6.2/go.mod h1:k/vIs83RN4bE3YCswdXC5PFfWVILjm3hpEUlSko4PiI=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c h1:pxW6RcqyfI9/kWtOwnv/G+AzdKuy2ZrqINhenH4HyNs=
github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
@@ -86,8 +86,8 @@ github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
github.com/gordonklaus/ineffassign v0.1.0 h1:y2Gd/9I7MdY1oEIt+n+rowjBNDcLQq3RsH5hwJd0f9s=
github.com/gordonklaus/ineffassign v0.1.0/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0=
github.com/kisielk/errcheck v1.7.0 h1:+SbscKmWJ5mOK/bO1zS60F5I9WwZDWOfRsC4RwfwRV0=
github.com/kisielk/errcheck v1.7.0/go.mod h1:1kLL+jV4e+CFfueBmI1dSK2ADDyQnlrnrY/FqKluHJQ=
github.com/kisielk/errcheck v1.8.0 h1:ZX/URYa7ilESY19ik/vBmCn6zdGQLxACwjAcWbHlYlg=
github.com/kisielk/errcheck v1.8.0/go.mod h1:1kLL+jV4e+CFfueBmI1dSK2ADDyQnlrnrY/FqKluHJQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -125,26 +125,26 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.55.0 h1:hCq2hNMwsegUvPzI7sPOvtO9cqyy5GbWt/Ybp2xrx8Q=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.55.0/go.mod h1:LqaApwGx/oUmzsbqxkzuBvyoPpkxk3JQWnqfVrJ3wCA=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 h1:ZIg3ZT/aQ7AfKqdwp7ECpOK6vHqquXXuyTjIO8ZdmPs=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0/go.mod h1:DQAwmETtZV00skUwgD6+0U89g80NKsJE3DCKeLLPQMI=
go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts=
go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc=
go.opentelemetry.io/otel/metric v1.30.0 h1:4xNulvn9gjzo4hjg+wzIKG7iNFEaBMX00Qd4QIZs7+w=
go.opentelemetry.io/otel/metric v1.30.0/go.mod h1:aXTfST94tswhWEb+5QjlSqG+cZlmyXy/u8jFpor3WqQ=
go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc=
go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0 h1:yMkBS9yViCc7U7yeLzJPM2XizlfdVvBRSmsQDWu6qc0=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0/go.mod h1:n8MR6/liuGB5EmTETUBeU5ZgqMOlqKRxUaqPQBOANZ8=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 h1:UP6IpuHFkUgOQL9FFQFrZ+5LiwhhYRbi7VZSIx6Nj5s=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0/go.mod h1:qxuZLtbq5QDtdeSHsS7bcf6EH6uO6jUAgk764zd3rhM=
go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY=
go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE=
go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE=
go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY=
golang.org/x/exp/typeparams v0.0.0-20240909161429-701f63a606c0 h1:bVwtbF629Xlyxk6xLQq2TDYmqP0uiWaet5LwRebuY0k=
golang.org/x/exp/typeparams v0.0.0-20240909161429-701f63a606c0/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY=
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/exp/typeparams v0.0.0-20241009180824-f66d83c29e7c h1:F/15/6p7LyGUSoP0GE5CB/U9+TNEER1foNOP5sWLLnI=
golang.org/x/exp/typeparams v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
@@ -161,8 +161,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs=
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
@@ -181,17 +181,17 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/telemetry v0.0.0-20240927214544-e9e6960092dd h1:cSUM3UgI7q2QZ4WBwDOGo5eFhZG4eGUkpdFporYHwpQ=
golang.org/x/telemetry v0.0.0-20240927214544-e9e6960092dd/go.mod h1:PsFMgI0jiuY7j+qwXANuh9a/x5kQESTSnRow3gapUyk=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/telemetry v0.0.0-20241028140143-9c0d19e65ba0 h1:od0RE4kmouF+x/o4zkTTSvBnGPZ2azGgCUmZdrbnEXM=
golang.org/x/telemetry v0.0.0-20241028140143-9c0d19e65ba0/go.mod h1:8nZWdGp9pq73ZI//QJyckMQab3yq7hoWi7SI0UIusVI=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
@@ -200,25 +200,25 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20201007032633-0806396f153e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE=
golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg=
golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ=
golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0=
golang.org/x/vuln v1.1.3 h1:NPGnvPOTgnjBc9HTaUx+nj+EaUYxl5SJOWqaDYGaFYw=
golang.org/x/vuln v1.1.3/go.mod h1:7Le6Fadm5FOqE9C926BCD0g12NWyhg7cxV4BwcPFuNY=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.199.0 h1:aWUXClp+VFJmqE0JPvpZOK3LDQMyFKYIow4etYd9qxs=
google.golang.org/api v0.199.0/go.mod h1:ohG4qSztDJmZdjK/Ar6MhbAmb/Rpi4JHOqagsh90K28=
google.golang.org/api v0.203.0 h1:SrEeuwU3S11Wlscsn+LA1kb/Y5xT8uggJSkIhD08NAU=
google.golang.org/api v0.203.0/go.mod h1:BuOVyCSYEPwJb3npWvDnNmFI92f3GeRnHNkETneT3SI=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto/googleapis/api v0.0.0-20240930140551-af27646dc61f h1:jTm13A2itBi3La6yTGqn8bVSrc3ZZ1r8ENHlIXBfnRA=
google.golang.org/genproto/googleapis/api v0.0.0-20240930140551-af27646dc61f/go.mod h1:CLGoBuH1VHxAUXVPP8FfPwPEVJB6lz3URE5mY2SuayE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240930140551-af27646dc61f h1:cUMEy+8oS78BWIH9OWazBkzbr090Od9tWBNtZHkOhf0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240930140551-af27646dc61f/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/genproto/googleapis/api v0.0.0-20241021214115-324edc3d5d38 h1:2oV8dfuIkM1Ti7DwXc0BJfnwr9csz4TDXI9EmiI+Rbw=
google.golang.org/genproto/googleapis/api v0.0.0-20241021214115-324edc3d5d38/go.mod h1:vuAjtvlwkDKF6L1GQ0SokiRLCGFfeBUXWr/aFFkHACc=
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 h1:zciRKQ4kBpFgpfC5QQCVtnnNAcLIqweL7plyZRQHVpI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
@@ -235,8 +235,8 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -264,7 +264,7 @@ func (u *Updater) check() (err error) {
// ignores the configuration file if firstRun is true.
func (u *Updater) backup(firstRun bool) (err error) {
log.Debug("updater: backing up current configuration")
_ = os.Mkdir(u.backupDir, aghos.DefaultPermDir)
_ = aghos.Mkdir(u.backupDir, aghos.DefaultPermDir)
if !firstRun {
err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"))
if err != nil {
@@ -338,12 +338,12 @@ func (u *Updater) downloadPackageFile() (err error) {
return fmt.Errorf("io.ReadAll() failed: %w", err)
}
_ = os.Mkdir(u.updateDir, aghos.DefaultPermDir)
_ = aghos.Mkdir(u.updateDir, aghos.DefaultPermDir)
log.Debug("updater: saving package to file")
err = os.WriteFile(u.packageName, body, aghos.DefaultPermFile)
err = aghos.WriteFile(u.packageName, body, aghos.DefaultPermFile)
if err != nil {
return fmt.Errorf("os.WriteFile() failed: %w", err)
return fmt.Errorf("writing package file: %w", err)
}
return nil
}
@@ -360,15 +360,15 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
if name == "AdGuardHome" {
// Top-level AdGuardHome/. Skip it.
//
// TODO(a.garipov): This whole package needs to be
// rewritten and covered in more integration tests. It
// has weird assumptions and file mode issues.
// TODO(a.garipov): This whole package needs to be rewritten and
// covered in more integration tests. It has weird assumptions and
// file mode issues.
return "", nil
}
err = os.Mkdir(outputName, os.FileMode(hdr.Mode&0o755))
err = aghos.Mkdir(outputName, os.FileMode(hdr.Mode&0o755))
if err != nil && !errors.Is(err, os.ErrExist) {
return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
return "", fmt.Errorf("creating directory %q: %w", outputName, err)
}
log.Debug("updater: created directory %q", outputName)
@@ -383,7 +383,7 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
}
var wc io.WriteCloser
wc, err = os.OpenFile(
wc, err = aghos.OpenFile(
outputName,
os.O_WRONLY|os.O_CREATE|os.O_TRUNC,
os.FileMode(hdr.Mode&0o755),
@@ -464,14 +464,13 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
if name == "AdGuardHome" {
// Top-level AdGuardHome/. Skip it.
//
// TODO(a.garipov): See the similar todo in
// tarGzFileUnpack.
// TODO(a.garipov): See the similar todo in tarGzFileUnpack.
return "", nil
}
err = os.Mkdir(outputName, fi.Mode())
err = aghos.Mkdir(outputName, fi.Mode())
if err != nil && !errors.Is(err, os.ErrExist) {
return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
return "", fmt.Errorf("creating directory %q: %w", outputName, err)
}
log.Debug("updater: created directory %q", outputName)
@@ -480,7 +479,7 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
}
var wc io.WriteCloser
wc, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
wc, err = aghos.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
if err != nil {
return "", fmt.Errorf("os.OpenFile(): %w", err)
}
@@ -523,15 +522,19 @@ func zipFileUnpack(zipfile, outDir string) (files []string, err error) {
}
// Copy file on disk
func copyFile(src, dst string) error {
d, e := os.ReadFile(src)
if e != nil {
return e
func copyFile(src, dst string) (err error) {
d, err := os.ReadFile(src)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return err
}
e = os.WriteFile(dst, d, aghos.DefaultPermFile)
if e != nil {
return e
err = aghos.WriteFile(dst, d, aghos.DefaultPermFile)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return err
}
return nil
}