Pull request: all: add idna handling, imp domain validation

Updates #2915.

Squashed commit of the following:

commit b907324426c87ee7334edbd61e43c44444ad27a9
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 7 16:26:41 2021 +0300

    all: imp docs, upd

commit c022f75cac006e077095cad283fea0a91d3a0eea
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 7 15:51:30 2021 +0300

    all: add idna handling, imp domain validation
This commit is contained in:
Ainar Garipov
2021-04-07 16:36:38 +03:00
parent 00a61fdea0
commit c133b01ef7
13 changed files with 375 additions and 215 deletions

View File

@@ -8,10 +8,11 @@ import (
"strconv"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils"
"github.com/miekg/dns"
)
@@ -302,7 +303,7 @@ type upstreamJSON struct {
}
// ValidateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified
func ValidateUpstreams(upstreams []string) error {
func ValidateUpstreams(upstreams []string) (err error) {
// No need to validate comments
upstreams = filterOutComments(upstreams)
@@ -311,7 +312,7 @@ func ValidateUpstreams(upstreams []string) error {
return nil
}
_, err := proxy.ParseUpstreamsConfig(
_, err = proxy.ParseUpstreamsConfig(
upstreams,
upstream.Options{
Bootstrap: []string{},
@@ -345,56 +346,61 @@ func ValidateUpstreams(upstreams []string) error {
var protocols = []string{"tls://", "https://", "tcp://", "sdns://", "quic://"}
func validateUpstream(u string) (bool, error) {
// Check if user tries to specify upstream for domain
u, defaultUpstream, err := separateUpstream(u)
// Check if the user tries to specify upstream for domain.
u, useDefault, err := separateUpstream(u)
if err != nil {
return defaultUpstream, err
return useDefault, err
}
// The special server address '#' means "use the default servers"
if u == "#" && !defaultUpstream {
return defaultUpstream, nil
if u == "#" && !useDefault {
return useDefault, nil
}
// Check if the upstream has a valid protocol prefix
for _, proto := range protocols {
if strings.HasPrefix(u, proto) {
return defaultUpstream, nil
return useDefault, nil
}
}
// Return error if the upstream contains '://' without any valid protocol
if strings.Contains(u, "://") {
return defaultUpstream, fmt.Errorf("wrong protocol")
return useDefault, fmt.Errorf("wrong protocol")
}
// Check if upstream is valid plain DNS
return defaultUpstream, checkPlainDNS(u)
return useDefault, checkPlainDNS(u)
}
// separateUpstream returns upstream without specified domains and a bool flag that indicates if no domains were specified
// error will be returned if upstream per domain specification is invalid
func separateUpstream(upstream string) (string, bool, error) {
defaultUpstream := true
if strings.HasPrefix(upstream, "[/") {
defaultUpstream = false
// split domains and upstream string
domainsAndUpstream := strings.Split(strings.TrimPrefix(upstream, "[/"), "/]")
if len(domainsAndUpstream) != 2 {
return "", defaultUpstream, fmt.Errorf("wrong dns upstream per domain specification: %s", upstream)
// separateUpstream returns the upstream without the specified domains.
// useDefault is true when a default upstream must be used.
func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) {
defer agherr.Annotate("bad upstream for domain spec %q: %w", &err, upstreamStr)
if !strings.HasPrefix(upstreamStr, "[/") {
return upstreamStr, true, nil
}
parts := strings.Split(upstreamStr[2:], "/]")
if len(parts) != 2 {
return "", false, agherr.Error("duplicated separator")
}
domains := parts[0]
upstream = parts[1]
for i, host := range strings.Split(domains, "/") {
if host == "" {
continue
}
// split domains list and validate each one
for _, host := range strings.Split(domainsAndUpstream[0], "/") {
if host != "" {
if err := utils.IsValidHostname(host); err != nil {
return "", defaultUpstream, err
}
}
err = aghnet.ValidateDomainName(host)
if err != nil {
return "", false, fmt.Errorf("domain at index %d: %w", i, err)
}
upstream = domainsAndUpstream[1]
}
return upstream, defaultUpstream, nil
return upstream, false, nil
}
// checkPlainDNS checks if host is plain DNS
@@ -462,13 +468,13 @@ func checkDNS(input string, bootstrap []string) error {
}
// separate upstream from domains list
input, defaultUpstream, err := separateUpstream(input)
input, useDefault, err := separateUpstream(input)
if err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
}
// No need to check this DNS server
if !defaultUpstream {
if !useDefault {
return nil
}