Pull request: 3381 check private domains
Merge in DNS/adguard-home from 3381-validate-privateness to master Closes #3381. Squashed commit of the following: commit 21cb12d10b07bb0bf0578db74ca9ac7b3ac5ae14 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Feb 14 16:29:59 2022 +0300 all: imp code, docs commit 39793551438cbea71e6ec78d0e05bee2d8dba3e5 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Feb 14 15:08:36 2022 +0300 all: imp code, docs commit 6b71848fd0980582b1bfe24a34f48608795e9b7d Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Mon Feb 14 14:22:00 2022 +0300 all: check private domains
This commit is contained in:
@@ -5,10 +5,12 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -41,7 +43,7 @@ type dnsConfig struct {
|
||||
LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"`
|
||||
}
|
||||
|
||||
func (s *Server) getDNSConfig() dnsConfig {
|
||||
func (s *Server) getDNSConfig() (c *dnsConfig) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
@@ -70,7 +72,7 @@ func (s *Server) getDNSConfig() dnsConfig {
|
||||
upstreamMode = "parallel"
|
||||
}
|
||||
|
||||
return dnsConfig{
|
||||
return &dnsConfig{
|
||||
Upstreams: &upstreams,
|
||||
UpstreamsFile: &upstreamFile,
|
||||
Bootstraps: &bootstraps,
|
||||
@@ -106,7 +108,7 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
// since there is no need to omit it while decoding from JSON.
|
||||
DefautLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"`
|
||||
}{
|
||||
dnsConfig: s.getDNSConfig(),
|
||||
dnsConfig: *s.getDNSConfig(),
|
||||
DefautLocalPTRUpstreams: defLocalPTRUps,
|
||||
}
|
||||
|
||||
@@ -138,39 +140,63 @@ func (req *dnsConfig) checkBlockingMode() bool {
|
||||
}
|
||||
|
||||
func (req *dnsConfig) checkUpstreamsMode() bool {
|
||||
if req.UpstreamMode == nil {
|
||||
return true
|
||||
}
|
||||
valid := []string{"", "fastest_addr", "parallel"}
|
||||
|
||||
for _, valid := range []string{
|
||||
"",
|
||||
"fastest_addr",
|
||||
"parallel",
|
||||
} {
|
||||
if *req.UpstreamMode == valid {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return req.UpstreamMode == nil || stringutil.InSlice(valid, *req.UpstreamMode)
|
||||
}
|
||||
|
||||
func (req *dnsConfig) checkBootstrap() (string, error) {
|
||||
func (req *dnsConfig) checkBootstrap() (err error) {
|
||||
if req.Bootstraps == nil {
|
||||
return "", nil
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, boot := range *req.Bootstraps {
|
||||
if boot == "" {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: empty")
|
||||
var b string
|
||||
defer func() { err = errors.Annotate(err, "checking bootstrap %s: invalid address: %w", b) }()
|
||||
|
||||
for _, b = range *req.Bootstraps {
|
||||
if b == "" {
|
||||
return errors.Error("empty")
|
||||
}
|
||||
|
||||
if _, err := upstream.NewResolver(boot, nil); err != nil {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
|
||||
if _, err = upstream.NewResolver(b, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate returns an error if any field of req is invalid.
|
||||
func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
|
||||
if req.Upstreams != nil {
|
||||
err = ValidateUpstreams(*req.Upstreams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if req.LocalPTRUpstreams != nil {
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, snd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating private upstream servers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = req.checkBootstrap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case !req.checkBlockingMode():
|
||||
return errors.Error("blocking_mode: incorrect value")
|
||||
case !req.checkUpstreamsMode():
|
||||
return errors.Error("upstream_mode: incorrect value")
|
||||
case !req.checkCacheTTL():
|
||||
return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (req *dnsConfig) checkCacheTTL() bool {
|
||||
@@ -190,69 +216,33 @@ func (req *dnsConfig) checkCacheTTL() bool {
|
||||
}
|
||||
|
||||
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
req := dnsConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
req := &dnsConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json Encode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if req.Upstreams != nil {
|
||||
if err = ValidateUpstreams(*req.Upstreams); err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var errBoot string
|
||||
if errBoot, err = req.checkBootstrap(); err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"%s can not be used as bootstrap dns cause: %s",
|
||||
errBoot,
|
||||
err,
|
||||
)
|
||||
err = req.validate(s.subnetDetector)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case !req.checkBlockingMode():
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "blocking_mode: incorrect value")
|
||||
|
||||
return
|
||||
case !req.checkUpstreamsMode():
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "upstream_mode: incorrect value")
|
||||
|
||||
return
|
||||
case !req.checkCacheTTL():
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"cache_ttl_min must be less or equal than cache_ttl_max",
|
||||
)
|
||||
|
||||
return
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
restart := s.setConfig(req)
|
||||
s.conf.ConfigModified()
|
||||
|
||||
if restart {
|
||||
if err = s.Reconfigure(nil); err != nil {
|
||||
err = s.Reconfigure(nil)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
|
||||
func (s *Server) setConfigRestartable(dc *dnsConfig) (restart bool) {
|
||||
if dc.Upstreams != nil {
|
||||
s.conf.UpstreamDNS = *dc.Upstreams
|
||||
restart = true
|
||||
@@ -273,9 +263,9 @@ func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
|
||||
restart = true
|
||||
}
|
||||
|
||||
if dc.RateLimit != nil {
|
||||
restart = restart || s.conf.Ratelimit != *dc.RateLimit
|
||||
if dc.RateLimit != nil && s.conf.Ratelimit != *dc.RateLimit {
|
||||
s.conf.Ratelimit = *dc.RateLimit
|
||||
restart = true
|
||||
}
|
||||
|
||||
if dc.EDNSCSEnabled != nil {
|
||||
@@ -306,7 +296,7 @@ func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) {
|
||||
return restart
|
||||
}
|
||||
|
||||
func (s *Server) setConfig(dc dnsConfig) (restart bool) {
|
||||
func (s *Server) setConfig(dc *dnsConfig) (restart bool) {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
@@ -353,52 +343,117 @@ type upstreamJSON struct {
|
||||
PrivateUpstreams []string `json:"private_upstream"`
|
||||
}
|
||||
|
||||
// IsCommentOrEmpty returns true of the string starts with a "#" character or is
|
||||
// an empty string. This function is useful for filtering out non-upstream
|
||||
// lines from upstream configs.
|
||||
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
|
||||
// This function is useful for filtering out non-upstream lines from upstream
|
||||
// configs.
|
||||
func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// LocalNetChecker is used to check if the IP address belongs to a local
|
||||
// network.
|
||||
type LocalNetChecker interface {
|
||||
// IsLocallyServedNetwork returns true if ip is contained in any of address
|
||||
// registries defined by RFC 6303.
|
||||
IsLocallyServedNetwork(ip net.IP) (ok bool)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ LocalNetChecker = (*aghnet.SubnetDetector)(nil)
|
||||
|
||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
||||
// configuration or nil if it can't be built.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
|
||||
// slice already so that this function may be considered useless.
|
||||
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
|
||||
// No need to validate comments and empty lines.
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
// Consider this case valid since it means the default server should be
|
||||
// used.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{Bootstrap: []string{}, Timeout: DefaultTimeout},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(conf.Upstreams) == 0 {
|
||||
return nil, errors.Error("no default upstreams specified")
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
_, err = validateUpstream(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// ValidateUpstreams validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified.
|
||||
//
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
func ValidateUpstreams(upstreams []string) (err error) {
|
||||
// No need to validate comments
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
_, err = newUpstreamConfig(upstreams)
|
||||
|
||||
// Consider this case valid because defaultDNS will be used
|
||||
if len(upstreams) == 0 {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// stringKeysSorted returns the sorted slice of string keys of m.
|
||||
//
|
||||
// TODO(e.burkov): Use generics in Go 1.18. Move into golibs.
|
||||
func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
|
||||
sorted = make([]string, 0, len(m))
|
||||
for s := range m {
|
||||
sorted = append(sorted, s)
|
||||
}
|
||||
|
||||
_, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: []string{},
|
||||
Timeout: DefaultTimeout,
|
||||
},
|
||||
)
|
||||
sort.Strings(sorted)
|
||||
|
||||
return sorted
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. lnc must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err error) {
|
||||
conf, err := newUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var defaultUpstreamFound bool
|
||||
for _, u := range upstreams {
|
||||
var useDefault bool
|
||||
useDefault, err = validateUpstream(u)
|
||||
if conf == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
|
||||
for _, domain := range stringKeysSorted(conf.DomainReservedUpstreams) {
|
||||
var subnet *net.IPNet
|
||||
subnet, err = netutil.SubnetFromReversedAddr(domain)
|
||||
if err != nil {
|
||||
return err
|
||||
errs = append(errs, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !defaultUpstreamFound {
|
||||
defaultUpstreamFound = useDefault
|
||||
if !lnc.IsLocallyServedNetwork(subnet.IP) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if !defaultUpstreamFound {
|
||||
return fmt.Errorf("no default upstreams specified")
|
||||
if len(errs) > 0 {
|
||||
return errors.List("checking domain-specific upstreams", errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user