+ control, dns, client: add ability to set DNS upstream per domain
This commit is contained in:
100
control.go
100
control.go
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
"github.com/miekg/dns"
|
||||
govalidator "gopkg.in/asaskevich/govalidator.v4"
|
||||
)
|
||||
@@ -317,11 +318,10 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
for _, u := range newconfig.Upstreams {
|
||||
if err = validateUpstream(u); err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s can not be used as upstream cause: %s", u, err)
|
||||
return
|
||||
}
|
||||
err = validateUpstreams(newconfig.Upstreams)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "wrong upstreams specification: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
config.DNS.UpstreamDNS = defaultDNS
|
||||
@@ -346,18 +346,81 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func validateUpstream(upstream string) error {
|
||||
for _, proto := range protocols {
|
||||
if strings.HasPrefix(upstream, proto) {
|
||||
return nil
|
||||
// validateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified
|
||||
func validateUpstreams(upstreams []string) error {
|
||||
var defaultUpstreamFound bool
|
||||
for _, u := range upstreams {
|
||||
d, err := validateUpstream(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check this flag until default upstream will not be found
|
||||
if !defaultUpstreamFound {
|
||||
defaultUpstreamFound = d
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(upstream, "://") {
|
||||
return fmt.Errorf("wrong protocol")
|
||||
// Return error if there are no default upstreams
|
||||
if !defaultUpstreamFound {
|
||||
return fmt.Errorf("no default upstreams specified")
|
||||
}
|
||||
|
||||
return checkPlainDNS(upstream)
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateUpstream(u string) (defaultUpstream bool, err error) {
|
||||
// Check if user tries to specify upstream for domain
|
||||
defaultUpstream = true
|
||||
u, defaultUpstream, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// The special server address '#' means "use the default servers"
|
||||
if u == "#" && !defaultUpstream {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the upstream has a valid protocol prefix
|
||||
for _, proto := range protocols {
|
||||
if strings.HasPrefix(u, proto) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return error if the upstream contains '://' without any valid protocol
|
||||
if strings.Contains(u, "://") {
|
||||
return defaultUpstream, fmt.Errorf("wrong protocol")
|
||||
}
|
||||
|
||||
// Check if upstream is valid plain DNS
|
||||
return defaultUpstream, checkPlainDNS(u)
|
||||
}
|
||||
|
||||
// separateUpstream returns upstream without specified domains and a bool flag that indicates if no domains were specified
|
||||
// error will be returned if upstream per domain specification is invalid
|
||||
func separateUpstream(upstream string) (string, bool, error) {
|
||||
defaultUpstream := true
|
||||
if strings.HasPrefix(upstream, "[/") {
|
||||
defaultUpstream = false
|
||||
// split domains and upstream string
|
||||
domainsAndUpstream := strings.Split(strings.TrimPrefix(upstream, "[/"), "/]")
|
||||
if len(domainsAndUpstream) != 2 {
|
||||
return "", defaultUpstream, fmt.Errorf("wrong DNS upstream per domain specification: %s", upstream)
|
||||
}
|
||||
|
||||
// split domains list and validate each one
|
||||
for _, host := range strings.Split(domainsAndUpstream[0], "/") {
|
||||
if host != "" {
|
||||
if err := utils.IsValidHostname(host); err != nil {
|
||||
return "", defaultUpstream, err
|
||||
}
|
||||
}
|
||||
}
|
||||
upstream = domainsAndUpstream[1]
|
||||
}
|
||||
return upstream, defaultUpstream, nil
|
||||
}
|
||||
|
||||
// checkPlainDNS checks if host is plain DNS
|
||||
@@ -425,7 +488,18 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func checkDNS(input string, bootstrap []string) error {
|
||||
if err := validateUpstream(input); err != nil {
|
||||
// separate upstream from domains list
|
||||
input, defaultUpstream, err := separateUpstream(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %s", err)
|
||||
}
|
||||
|
||||
// No need to check this entrance
|
||||
if input == "#" && !defaultUpstream {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := validateUpstream(input); err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %s", err)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user