-(dnsforward): custom client per-domain upstreams

Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1539
This commit is contained in:
Andrey Meshkov
2020-05-13 20:31:43 +03:00
parent 1f954ab673
commit 67a39045fc
10 changed files with 106 additions and 88 deletions

View File

@@ -11,11 +11,12 @@ import (
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils"
)
@@ -41,11 +42,12 @@ type Client struct {
BlockedServices []string
Upstreams []string // list of upstream servers to be used for the client's requests
// Upstream objects:
// Custom upstream config for this client
// nil: not yet initialized
// not nil, but empty: initialized, no good upstreams
// not nil, not empty: Upstreams ready to be used
upstreamObjects []upstream.Upstream
upstreamConfig *proxy.UpstreamConfig
}
type clientSource uint
@@ -273,16 +275,10 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
return c, true
}
func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream {
a2 := make([]upstream.Upstream, len(a))
copy(a2, a)
return a2
}
// FindUpstreams looks for upstreams configured for the client
// If no client found for this IP, or if no custom upstreams are configured,
// this method returns nil
func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -291,22 +287,18 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
return nil
}
if c.upstreamObjects == nil {
c.upstreamObjects = make([]upstream.Upstream, 0)
for _, us := range c.Upstreams {
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
if err != nil {
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
continue
}
c.upstreamObjects = append(c.upstreamObjects, u)
if len(c.Upstreams) == 0 {
return nil
}
if c.upstreamConfig == nil {
config, err := proxy.ParseUpstreamsConfig(c.Upstreams, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
if err == nil {
c.upstreamConfig = &config
}
}
if len(c.upstreamObjects) == 0 {
return nil
}
return upstreamArrayCopy(c.upstreamObjects)
return c.upstreamConfig
}
// Find searches for a client by IP (and does not lock anything)
@@ -537,7 +529,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
}
// update upstreams cache
c.upstreamObjects = nil
c.upstreamConfig = nil
*old = c
return nil