+ control, dns, client: add ability to set DNS upstream per domain

This commit is contained in:
Aleksey Dmitrevskiy
2019-03-20 14:24:33 +03:00
parent 6f56eb4c12
commit 9ea5c1abe1
10 changed files with 200 additions and 49 deletions

View File

@@ -18,6 +18,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils"
"github.com/miekg/dns"
govalidator "gopkg.in/asaskevich/govalidator.v4"
)
@@ -317,11 +318,10 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
return
}
for _, u := range newconfig.Upstreams {
if err = validateUpstream(u); err != nil {
httpError(w, http.StatusBadRequest, "%s can not be used as upstream cause: %s", u, err)
return
}
err = validateUpstreams(newconfig.Upstreams)
if err != nil {
httpError(w, http.StatusBadRequest, "wrong upstreams specification: %s", err)
return
}
config.DNS.UpstreamDNS = defaultDNS
@@ -346,18 +346,81 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func validateUpstream(upstream string) error {
for _, proto := range protocols {
if strings.HasPrefix(upstream, proto) {
return nil
// validateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified
func validateUpstreams(upstreams []string) error {
var defaultUpstreamFound bool
for _, u := range upstreams {
d, err := validateUpstream(u)
if err != nil {
return err
}
// Check this flag until default upstream will not be found
if !defaultUpstreamFound {
defaultUpstreamFound = d
}
}
if strings.Contains(upstream, "://") {
return fmt.Errorf("wrong protocol")
// Return error if there are no default upstreams
if !defaultUpstreamFound {
return fmt.Errorf("no default upstreams specified")
}
return checkPlainDNS(upstream)
return nil
}
func validateUpstream(u string) (defaultUpstream bool, err error) {
// Check if user tries to specify upstream for domain
defaultUpstream = true
u, defaultUpstream, err = separateUpstream(u)
if err != nil {
return
}
// The special server address '#' means "use the default servers"
if u == "#" && !defaultUpstream {
return
}
// Check if the upstream has a valid protocol prefix
for _, proto := range protocols {
if strings.HasPrefix(u, proto) {
return
}
}
// Return error if the upstream contains '://' without any valid protocol
if strings.Contains(u, "://") {
return defaultUpstream, fmt.Errorf("wrong protocol")
}
// Check if upstream is valid plain DNS
return defaultUpstream, checkPlainDNS(u)
}
// separateUpstream returns upstream without specified domains and a bool flag that indicates if no domains were specified
// error will be returned if upstream per domain specification is invalid
func separateUpstream(upstream string) (string, bool, error) {
defaultUpstream := true
if strings.HasPrefix(upstream, "[/") {
defaultUpstream = false
// split domains and upstream string
domainsAndUpstream := strings.Split(strings.TrimPrefix(upstream, "[/"), "/]")
if len(domainsAndUpstream) != 2 {
return "", defaultUpstream, fmt.Errorf("wrong DNS upstream per domain specification: %s", upstream)
}
// split domains list and validate each one
for _, host := range strings.Split(domainsAndUpstream[0], "/") {
if host != "" {
if err := utils.IsValidHostname(host); err != nil {
return "", defaultUpstream, err
}
}
}
upstream = domainsAndUpstream[1]
}
return upstream, defaultUpstream, nil
}
// checkPlainDNS checks if host is plain DNS
@@ -425,7 +488,18 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
}
func checkDNS(input string, bootstrap []string) error {
if err := validateUpstream(input); err != nil {
// separate upstream from domains list
input, defaultUpstream, err := separateUpstream(input)
if err != nil {
return fmt.Errorf("wrong upstream format: %s", err)
}
// No need to check this entrance
if input == "#" && !defaultUpstream {
return nil
}
if _, err := validateUpstream(input); err != nil {
return fmt.Errorf("wrong upstream format: %s", err)
}