-(dnsforward): custom client per-domain upstreams
Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1539
This commit is contained in:
@@ -11,11 +11,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
)
|
||||
@@ -41,11 +42,12 @@ type Client struct {
|
||||
BlockedServices []string
|
||||
|
||||
Upstreams []string // list of upstream servers to be used for the client's requests
|
||||
// Upstream objects:
|
||||
|
||||
// Custom upstream config for this client
|
||||
// nil: not yet initialized
|
||||
// not nil, but empty: initialized, no good upstreams
|
||||
// not nil, not empty: Upstreams ready to be used
|
||||
upstreamObjects []upstream.Upstream
|
||||
upstreamConfig *proxy.UpstreamConfig
|
||||
}
|
||||
|
||||
type clientSource uint
|
||||
@@ -273,16 +275,10 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
|
||||
return c, true
|
||||
}
|
||||
|
||||
func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream {
|
||||
a2 := make([]upstream.Upstream, len(a))
|
||||
copy(a2, a)
|
||||
return a2
|
||||
}
|
||||
|
||||
// FindUpstreams looks for upstreams configured for the client
|
||||
// If no client found for this IP, or if no custom upstreams are configured,
|
||||
// this method returns nil
|
||||
func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
|
||||
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
@@ -291,22 +287,18 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.upstreamObjects == nil {
|
||||
c.upstreamObjects = make([]upstream.Upstream, 0)
|
||||
for _, us := range c.Upstreams {
|
||||
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
|
||||
if err != nil {
|
||||
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
|
||||
continue
|
||||
}
|
||||
c.upstreamObjects = append(c.upstreamObjects, u)
|
||||
if len(c.Upstreams) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.upstreamConfig == nil {
|
||||
config, err := proxy.ParseUpstreamsConfig(c.Upstreams, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
|
||||
if err == nil {
|
||||
c.upstreamConfig = &config
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.upstreamObjects) == 0 {
|
||||
return nil
|
||||
}
|
||||
return upstreamArrayCopy(c.upstreamObjects)
|
||||
return c.upstreamConfig
|
||||
}
|
||||
|
||||
// Find searches for a client by IP (and does not lock anything)
|
||||
@@ -537,7 +529,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
|
||||
}
|
||||
|
||||
// update upstreams cache
|
||||
c.upstreamObjects = nil
|
||||
c.upstreamConfig = nil
|
||||
|
||||
*old = c
|
||||
return nil
|
||||
|
||||
@@ -236,3 +236,31 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
|
||||
clients.Init(nil, nil, nil)
|
||||
|
||||
// add client with upstreams
|
||||
client := Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||
Name: "client1",
|
||||
Upstreams: []string{
|
||||
"1.1.1.1",
|
||||
"[/example.org/]8.8.8.8",
|
||||
},
|
||||
}
|
||||
ok, err := clients.Add(client)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
config := clients.FindUpstreams("1.2.3.4")
|
||||
assert.Nil(t, config)
|
||||
|
||||
config = clients.FindUpstreams("1.1.1.1")
|
||||
assert.NotNil(t, config)
|
||||
assert.Equal(t, 1, len(config.Upstreams))
|
||||
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
@@ -176,7 +175,7 @@ func generateServerConfig() dnsforward.ServerConfig {
|
||||
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||
|
||||
newconfig.FilterHandler = applyAdditionalFiltering
|
||||
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
||||
newconfig.GetCustomUpstreamByClient = Context.clients.FindUpstreams
|
||||
return newconfig
|
||||
}
|
||||
|
||||
@@ -222,10 +221,6 @@ func getDNSAddresses() []string {
|
||||
return dnsAddresses
|
||||
}
|
||||
|
||||
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
||||
return Context.clients.FindUpstreams(clientAddr)
|
||||
}
|
||||
|
||||
// If a client has his own settings, apply them
|
||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||
|
||||
Reference in New Issue
Block a user