-(dnsforward): custom client per-domain upstreams

Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1539
This commit is contained in:
Andrey Meshkov
2020-05-13 20:31:43 +03:00
parent 1f954ab673
commit 67a39045fc
10 changed files with 106 additions and 88 deletions

View File

@@ -11,11 +11,12 @@ import (
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils"
)
@@ -41,11 +42,12 @@ type Client struct {
BlockedServices []string
Upstreams []string // list of upstream servers to be used for the client's requests
// Upstream objects:
// Custom upstream config for this client
// nil: not yet initialized
// not nil, but empty: initialized, no good upstreams
// not nil, not empty: Upstreams ready to be used
upstreamObjects []upstream.Upstream
upstreamConfig *proxy.UpstreamConfig
}
type clientSource uint
@@ -273,16 +275,10 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
return c, true
}
func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream {
a2 := make([]upstream.Upstream, len(a))
copy(a2, a)
return a2
}
// FindUpstreams looks for upstreams configured for the client
// If no client found for this IP, or if no custom upstreams are configured,
// this method returns nil
func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -291,22 +287,18 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
return nil
}
if c.upstreamObjects == nil {
c.upstreamObjects = make([]upstream.Upstream, 0)
for _, us := range c.Upstreams {
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
if err != nil {
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
continue
}
c.upstreamObjects = append(c.upstreamObjects, u)
if len(c.Upstreams) == 0 {
return nil
}
if c.upstreamConfig == nil {
config, err := proxy.ParseUpstreamsConfig(c.Upstreams, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
if err == nil {
c.upstreamConfig = &config
}
}
if len(c.upstreamObjects) == 0 {
return nil
}
return upstreamArrayCopy(c.upstreamObjects)
return c.upstreamConfig
}
// Find searches for a client by IP (and does not lock anything)
@@ -537,7 +529,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
}
// update upstreams cache
c.upstreamObjects = nil
c.upstreamConfig = nil
*old = c
return nil

View File

@@ -236,3 +236,31 @@ func TestClientsAddExisting(t *testing.T) {
assert.True(t, ok)
assert.Nil(t, err)
}
func TestClientsCustomUpstream(t *testing.T) {
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
// add client with upstreams
client := Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
Upstreams: []string{
"1.1.1.1",
"[/example.org/]8.8.8.8",
},
}
ok, err := clients.Add(client)
assert.Nil(t, err)
assert.True(t, ok)
config := clients.FindUpstreams("1.2.3.4")
assert.Nil(t, config)
config = clients.FindUpstreams("1.1.1.1")
assert.NotNil(t, config)
assert.Equal(t, 1, len(config.Upstreams))
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
@@ -176,7 +175,7 @@ func generateServerConfig() dnsforward.ServerConfig {
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
newconfig.FilterHandler = applyAdditionalFiltering
newconfig.GetUpstreamsByClient = getUpstreamsByClient
newconfig.GetCustomUpstreamByClient = Context.clients.FindUpstreams
return newconfig
}
@@ -222,10 +221,6 @@ func getDNSAddresses() []string {
return dnsAddresses
}
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
return Context.clients.FindUpstreams(clientAddr)
}
// If a client has his own settings, apply them
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)