all: sync with master; upd chlog

This commit is contained in:
Ainar Garipov
2023-09-11 17:51:50 +03:00
parent 7b93f5d7cf
commit 258eecc55b
48 changed files with 265 additions and 178 deletions

View File

@@ -3,7 +3,7 @@ package confmigrate_test
import (
"io/fs"
"os"
"path/filepath"
"path"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/confmigrate"
@@ -12,6 +12,9 @@ import (
yaml "gopkg.in/yaml.v3"
)
// testdata is a virtual filesystem containing test data.
var testdata = os.DirFS("testdata")
// getField returns the value located at the given indexes in the given object.
// It fails the test if the value is not found or of the expected type. The
// indexes can be either strings or integers, and are interpreted as map keys or
@@ -42,9 +45,6 @@ func getField[T any](t require.TestingT, obj any, indexes ...any) (val T) {
return obj.(T)
}
// testdata is a virtual filesystem containing test data.
var testdata = os.DirFS("testdata")
func TestMigrateConfig_Migrate(t *testing.T) {
const (
inputFileName = "input.yml"
@@ -189,10 +189,10 @@ func TestMigrateConfig_Migrate(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
body, err := fs.ReadFile(testdata, filepath.Join(t.Name(), inputFileName))
body, err := fs.ReadFile(testdata, path.Join(t.Name(), inputFileName))
require.NoError(t, err)
wantBody, err := fs.ReadFile(testdata, filepath.Join(t.Name(), outputFileName))
wantBody, err := fs.ReadFile(testdata, path.Join(t.Name(), outputFileName))
require.NoError(t, err)
migrator := confmigrate.New(&confmigrate.Config{

View File

@@ -54,7 +54,7 @@ func TestServer_handleDHCPStatus(t *testing.T) {
assert.JSONEq(t, b.String(), w.Body.String())
}
// defaultResponse is a helper that returs the response with default
// defaultResponse is a helper that returns the response with default
// configuration.
defaultResponse := func() *dhcpStatusResponse {
conf4 := defaultV4ServerConf()

View File

@@ -614,6 +614,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
// setupFallbackDNS initializes the fallback DNS servers.
func (s *Server) setupFallbackDNS() (err error) {
fallbacks := s.conf.FallbackDNS
fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty)
if len(fallbacks) == 0 {
return nil
}

View File

@@ -372,6 +372,27 @@ func TestServer_timeout(t *testing.T) {
})
}
func TestServer_Prepare_fallbacks(t *testing.T) {
srvConf := &ServerConfig{
Config: Config{
FallbackDNS: []string{
"#tls://1.1.1.1",
"8.8.8.8",
},
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}
s, err := NewServer(DNSCreateParams{})
require.NoError(t, err)
err = s.Prepare(srvConf)
require.NoError(t, err)
require.NotNil(t, s.dnsProxy.Fallbacks)
assert.Len(t, s.dnsProxy.Fallbacks.Upstreams, 1)
}
func TestServerWithProtectionDisabled(t *testing.T) {
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,

View File

@@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
@@ -140,15 +141,15 @@ func (s *Server) filterRewritten(
// checkHostRules checks the host against filters. It is safe for concurrent
// use.
func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) (
r *filtering.Result,
err error,
) {
func (s *Server) checkHostRules(
host string,
rrtype rules.RRType,
setts *filtering.Settings,
) (r *filtering.Result, err error) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
var res filtering.Result
res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
res, err := s.dnsFilter.CheckHostRules(host, rrtype, setts)
if err != nil {
return nil, err
}
@@ -156,20 +157,21 @@ func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Set
return &res, err
}
// filterDNSResponse checks each resource record of the response's answer
// section from pctx and returns a non-nil res if at least one of canonical
// names or IP addresses in it matches the filtering rules.
func (s *Server) filterDNSResponse(
pctx *proxy.DNSContext,
setts *filtering.Settings,
) (res *filtering.Result, err error) {
// filterDNSResponse checks each resource record of answer section of
// dctx.proxyCtx.Res. It sets dctx.result and dctx.origResp if at least one of
// canonical names, IP addresses, or HTTPS RR hints in it matches the filtering
// rules, as well as sets dctx.proxyCtx.Res to the filtered response.
func (s *Server) filterDNSResponse(dctx *dnsContext) (err error) {
setts := dctx.setts
if !setts.FilteringEnabled {
return nil, nil
return nil
}
for _, a := range pctx.Res.Answer {
var res *filtering.Result
pctx := dctx.proxyCtx
for i, a := range pctx.Res.Answer {
host := ""
var rrtype uint16
var rrtype rules.RRType
switch a := a.(type) {
case *dns.CNAME:
host = strings.TrimSuffix(a.Target, ".")
@@ -195,18 +197,19 @@ func (s *Server) filterDNSResponse(
log.Debug("dnsforward: checked %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
if err != nil {
return nil, err
} else if res == nil {
continue
} else if res.IsFiltered {
return fmt.Errorf("filtering answer at index %d: %w", i, err)
} else if res != nil && res.IsFiltered {
dctx.result = res
dctx.origResp = pctx.Res
pctx.Res = s.genDNSFilterMessage(pctx, res)
log.Debug("dnsforward: matched %q by response: %q", pctx.Req.Question[0].Name, host)
return res, nil
break
}
}
return nil, nil
return nil
}
// removeIPv6Hints deletes IPv6 hints from RR values.

View File

@@ -328,26 +328,34 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
}
res, rErr := s.filterDNSResponse(pctx, &filtering.Settings{
ProtectionEnabled: true,
FilteringEnabled: true,
})
require.NoError(t, rErr)
dctx := &dnsContext{
proxyCtx: pctx,
setts: &filtering.Settings{
ProtectionEnabled: true,
FilteringEnabled: true,
},
}
fltErr := s.filterDNSResponse(dctx)
require.NoError(t, fltErr)
res := dctx.result
if tc.wantRule == "" {
assert.Nil(t, res)
return
}
want := &filtering.Result{
wantResult := &filtering.Result{
IsFiltered: true,
Reason: filtering.FilteredBlockList,
Rules: []*filtering.ResultRule{{
Text: tc.wantRule,
}},
}
assert.Equal(t, want, res)
assert.Equal(t, wantResult, res)
assert.Equal(t, resp, dctx.origResp)
})
}
}

View File

@@ -749,6 +749,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
}
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
upsNum := len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams)

View File

@@ -577,6 +577,14 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
badUps + ` over tcp: dns: id mismatch`,
},
name: "fallback_broken",
}, {
body: map[string]any{
"fallback_dns": []string{goodUps, "#this.is.comment"},
},
wantResp: map[string]any{
goodUps: "OK",
},
name: "fallback_comment_mix",
}}
for _, tc := range testCases {

View File

@@ -671,11 +671,11 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
}
// Apply filtering logic
func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing filtering before req")
defer log.Debug("dnsforward: finished processing filtering before req")
if ctx.proxyCtx.Res != nil {
if dctx.proxyCtx.Res != nil {
// Go on since the response is already set.
return resultCodeSuccess
}
@@ -684,8 +684,8 @@ func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode)
defer s.serverLock.RUnlock()
var err error
if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
ctx.err = err
if dctx.result, err = s.filterDNSRequest(dctx); err != nil {
dctx.err = err
return resultCodeError
}
@@ -857,7 +857,6 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
log.Debug("dnsforward: started processing filtering after resp")
defer log.Debug("dnsforward: finished processing filtering after resp")
pctx := dctx.proxyCtx
switch res := dctx.result; res.Reason {
case filtering.NotFilteredAllowList:
return resultCodeSuccess
@@ -871,6 +870,7 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
return resultCodeSuccess
}
pctx := dctx.proxyCtx
pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion
if len(pctx.Res.Answer) > 0 {
rr := s.genAnswerCNAME(pctx.Req, res.CanonName)
@@ -880,13 +880,13 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
return resultCodeSuccess
default:
return s.filterAfterResponse(dctx, pctx)
return s.filterAfterResponse(dctx)
}
}
// filterAfterResponse returns the result of filtering the response that wasn't
// explicitly allowed or rewritten.
func (s *Server) filterAfterResponse(dctx *dnsContext, pctx *proxy.DNSContext) (res resultCode) {
func (s *Server) filterAfterResponse(dctx *dnsContext) (res resultCode) {
// Check the response only if it's from an upstream. Don't check the
// response if the protection is disabled since dnsrewrite rules aren't
// applied to it anyway.
@@ -894,17 +894,12 @@ func (s *Server) filterAfterResponse(dctx *dnsContext, pctx *proxy.DNSContext) (
return resultCodeSuccess
}
result, err := s.filterDNSResponse(pctx, dctx.setts)
err := s.filterDNSResponse(dctx)
if err != nil {
dctx.err = err
return resultCodeError
}
if result != nil {
dctx.result = result
dctx.origResp = pctx.Res
}
return resultCodeSuccess
}

View File

@@ -282,6 +282,12 @@ type statsConfig struct {
Enabled bool `yaml:"enabled"`
}
// Default block host constants.
const (
defaultSafeBrowsingBlockHost = "standard-block.dns.adguard.com"
defaultParentalBlockHost = "family-block.dns.adguard.com"
)
// config is the global configuration structure.
//
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
@@ -389,6 +395,9 @@ var config = &configuration{
Schedule: schedule.EmptyWeekly(),
IDs: []string{},
},
ParentalBlockHost: defaultParentalBlockHost,
SafeBrowsingBlockHost: defaultSafeBrowsingBlockHost,
},
DHCP: &dhcpd.ServerConfig{
LocalDomainName: "lan",

View File

@@ -359,9 +359,6 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
pcService = "parental control"
defaultParentalServer = `https://family.adguard-dns.com/dns-query`
pcTXTSuffix = `pc.dns.adguard.com.`
defaultSafeBrowsingBlockHost = "standard-block.dns.adguard.com"
defaultParentalBlockHost = "family-block.dns.adguard.com"
)
conf.EtcHosts = Context.etcHosts
@@ -398,8 +395,15 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
CacheSize: conf.SafeBrowsingCacheSize,
})
if conf.SafeBrowsingBlockHost != "" {
conf.SafeBrowsingBlockHost = defaultSafeBrowsingBlockHost
// Protect against invalid configuration, see #6181.
//
// TODO(a.garipov): Validate against an empty host instead of setting it to
// default.
if conf.SafeBrowsingBlockHost == "" {
host := defaultSafeBrowsingBlockHost
log.Info("%s: warning: empty blocking host; using default: %q", sbService, host)
conf.SafeBrowsingBlockHost = host
}
parUps, err := upstream.AddressToUpstream(defaultParentalServer, upsOpts)
@@ -415,8 +419,15 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
CacheSize: conf.ParentalCacheSize,
})
if conf.ParentalBlockHost != "" {
conf.ParentalBlockHost = defaultParentalBlockHost
// Protect against invalid configuration, see #6181.
//
// TODO(a.garipov): Validate against an empty host instead of setting it to
// default.
if conf.ParentalBlockHost == "" {
host := defaultParentalBlockHost
log.Info("%s: warning: empty blocking host; using default: %q", pcService, host)
conf.ParentalBlockHost = host
}
conf.SafeSearchConf.CustomResolver = safeSearchResolver{}

View File

@@ -72,7 +72,7 @@ type Entry struct {
Time time.Duration
}
// validate returs an error if entry is not valid.
// validate returns an error if entry is not valid.
func (e *Entry) validate() (err error) {
switch {
case e.Result == 0:
@@ -295,7 +295,7 @@ func loadUnitFromDB(tx *bbolt.Tx, id uint32) (udb *unitDB) {
return udb
}
// deserealize assigns the appropriate values from udb to u. u must not be nil.
// deserialize assigns the appropriate values from udb to u. u must not be nil.
// It's safe for concurrent use.
func (u *unit) deserialize(udb *unitDB) {
if udb == nil {

View File

@@ -113,9 +113,9 @@ func (u *Updater) Update(firstRun bool) (err error) {
log.Info("updater: updating")
defer func() {
if err != nil {
log.Error("updater: failed: %v", err)
log.Info("updater: failed")
} else {
log.Info("updater: finished")
log.Info("updater: finished successfully")
}
}()
@@ -240,18 +240,24 @@ func (u *Updater) unpack() error {
// check returns an error if the configuration file couldn't be used with the
// version of AdGuard Home just downloaded.
func (u *Updater) check() error {
func (u *Updater) check() (err error) {
log.Debug("updater: checking configuration")
err := copyFile(u.confName, filepath.Join(u.updateDir, "AdGuardHome.yaml"))
err = copyFile(u.confName, filepath.Join(u.updateDir, "AdGuardHome.yaml"))
if err != nil {
return fmt.Errorf("copyFile() failed: %w", err)
}
const format = "executing configuration check command: %w %d:\n" +
"below is the output of configuration check:\n" +
"%s" +
"end of the output"
cmd := exec.Command(u.updateExeName, "--check-config")
err = cmd.Run()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode())
out, err := cmd.CombinedOutput()
code := cmd.ProcessState.ExitCode()
if err != nil || code != 0 {
return fmt.Errorf(format, err, code, out)
}
return nil