-(dnsforward): custom client per-domain upstreams
Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1539
This commit is contained in:
@@ -26,8 +26,9 @@ type FilteringConfig struct {
|
||||
// Filtering callback function
|
||||
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
|
||||
|
||||
// This callback function returns the list of upstream servers for a client specified by IP address
|
||||
GetUpstreamsByClient func(clientAddr string) []upstream.Upstream `yaml:"-"`
|
||||
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
|
||||
// based on the client IP address. Returns nil if there are no custom upstreams for the client
|
||||
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
|
||||
|
||||
// Protection configuration
|
||||
// --
|
||||
@@ -102,11 +103,10 @@ type TLSConfig struct {
|
||||
// ServerConfig represents server configuration.
|
||||
// The zero ServerConfig is empty and ready for use.
|
||||
type ServerConfig struct {
|
||||
UDPListenAddr *net.UDPAddr // UDP listen address
|
||||
TCPListenAddr *net.TCPAddr // TCP listen address
|
||||
Upstreams []upstream.Upstream // Configured upstreams
|
||||
DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams
|
||||
OnDNSRequest func(d *proxy.DNSContext)
|
||||
UDPListenAddr *net.UDPAddr // UDP listen address
|
||||
TCPListenAddr *net.TCPAddr // TCP listen address
|
||||
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
|
||||
OnDNSRequest func(d *proxy.DNSContext)
|
||||
|
||||
FilteringConfig
|
||||
TLSConfig
|
||||
@@ -132,22 +132,21 @@ var defaultValues = ServerConfig{
|
||||
// createProxyConfig creates and validates configuration for the main proxy
|
||||
func (s *Server) createProxyConfig() (proxy.Config, error) {
|
||||
proxyConfig := proxy.Config{
|
||||
UDPListenAddr: s.conf.UDPListenAddr,
|
||||
TCPListenAddr: s.conf.TCPListenAddr,
|
||||
Ratelimit: int(s.conf.Ratelimit),
|
||||
RatelimitWhitelist: s.conf.RatelimitWhitelist,
|
||||
RefuseAny: s.conf.RefuseAny,
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: int(s.conf.CacheSize),
|
||||
CacheMinTTL: s.conf.CacheMinTTL,
|
||||
CacheMaxTTL: s.conf.CacheMaxTTL,
|
||||
Upstreams: s.conf.Upstreams,
|
||||
DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams,
|
||||
BeforeRequestHandler: s.beforeRequestHandler,
|
||||
RequestHandler: s.handleDNSRequest,
|
||||
AllServers: s.conf.AllServers,
|
||||
EnableEDNSClientSubnet: s.conf.EnableEDNSClientSubnet,
|
||||
FindFastestAddr: s.conf.FastestAddr,
|
||||
UDPListenAddr: s.conf.UDPListenAddr,
|
||||
TCPListenAddr: s.conf.TCPListenAddr,
|
||||
Ratelimit: int(s.conf.Ratelimit),
|
||||
RatelimitWhitelist: s.conf.RatelimitWhitelist,
|
||||
RefuseAny: s.conf.RefuseAny,
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: int(s.conf.CacheSize),
|
||||
CacheMinTTL: s.conf.CacheMinTTL,
|
||||
CacheMaxTTL: s.conf.CacheMaxTTL,
|
||||
UpstreamConfig: s.conf.UpstreamConfig,
|
||||
BeforeRequestHandler: s.beforeRequestHandler,
|
||||
RequestHandler: s.handleDNSRequest,
|
||||
AllServers: s.conf.AllServers,
|
||||
EnableEDNSClientSubnet: s.conf.EnableEDNSClientSubnet,
|
||||
FindFastestAddr: s.conf.FastestAddr,
|
||||
}
|
||||
|
||||
if len(s.conf.BogusNXDomain) > 0 {
|
||||
@@ -168,7 +167,7 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
|
||||
}
|
||||
|
||||
// Validate proxy config
|
||||
if len(proxyConfig.Upstreams) == 0 {
|
||||
if proxyConfig.UpstreamConfig == nil || len(proxyConfig.UpstreamConfig.Upstreams) == 0 {
|
||||
return proxyConfig, errors.New("no upstream servers configured")
|
||||
}
|
||||
|
||||
@@ -204,18 +203,16 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("DNS: proxy.ParseUpstreamsConfig: %s", err)
|
||||
}
|
||||
s.conf.Upstreams = upstreamConfig.Upstreams
|
||||
s.conf.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams
|
||||
s.conf.UpstreamConfig = &upstreamConfig
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareIntlProxy - initializes DNS proxy that we use for internal DNS queries
|
||||
func (s *Server) prepareIntlProxy() {
|
||||
intlProxyConfig := proxy.Config{
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: 4096,
|
||||
Upstreams: s.conf.Upstreams,
|
||||
DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams,
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: 4096,
|
||||
UpstreamConfig: s.conf.UpstreamConfig,
|
||||
}
|
||||
s.internalProxy = &proxy.Proxy{Config: intlProxyConfig}
|
||||
}
|
||||
|
||||
@@ -325,7 +325,9 @@ func (s *Server) startWithUpstream(u upstream.Upstream) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.dnsProxy.Upstreams = []upstream.Upstream{u}
|
||||
s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{u},
|
||||
}
|
||||
return s.dnsProxy.Start()
|
||||
}
|
||||
|
||||
@@ -353,8 +355,8 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
// but protection is disabled - response is NOT blocked
|
||||
req := createTestMessage("badhost.")
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.True(t, err == nil)
|
||||
assert.True(t, reply.Rcode == dns.RcodeSuccess)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
}
|
||||
|
||||
func TestBlockCNAME(t *testing.T) {
|
||||
@@ -368,23 +370,23 @@ func TestBlockCNAME(t *testing.T) {
|
||||
// response is blocked
|
||||
req := createTestMessage("badhost.")
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.True(t, err == nil)
|
||||
assert.True(t, reply.Rcode == dns.RcodeNameError)
|
||||
assert.Nil(t, err, nil)
|
||||
assert.Equal(t, dns.RcodeNameError, reply.Rcode)
|
||||
|
||||
// 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters
|
||||
// but 'whitelist.example.org' is in a whitelist:
|
||||
// response isn't blocked
|
||||
req = createTestMessage("whitelist.example.org.")
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
assert.True(t, err == nil)
|
||||
assert.True(t, reply.Rcode == dns.RcodeSuccess)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
|
||||
// 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters:
|
||||
// response is blocked
|
||||
req = createTestMessage("example.org.")
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
assert.True(t, err == nil)
|
||||
assert.True(t, reply.Rcode == dns.RcodeNameError)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeNameError, reply.Rcode)
|
||||
|
||||
_ = s.Stop()
|
||||
}
|
||||
@@ -455,7 +457,7 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||
|
||||
func TestBlockedCustomIP(t *testing.T) {
|
||||
rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n"
|
||||
filters := []dnsfilter.Filter{dnsfilter.Filter{
|
||||
filters := []dnsfilter.Filter{{
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
c := dnsfilter.Config{}
|
||||
@@ -475,27 +477,27 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
conf.BlockingIPv4 = "0.0.0.1"
|
||||
conf.BlockingIPv6 = "::1"
|
||||
err = s.Prepare(&conf)
|
||||
assert.True(t, err == nil)
|
||||
assert.Nil(t, err)
|
||||
err = s.Start()
|
||||
assert.True(t, err == nil, "%s", err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
req := createTestMessageWithType("null.example.org.", dns.TypeA)
|
||||
reply, err := dns.Exchange(req, addr.String())
|
||||
assert.True(t, err == nil)
|
||||
assert.True(t, len(reply.Answer) == 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(reply.Answer))
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
assert.True(t, ok)
|
||||
assert.True(t, a.A.String() == "0.0.0.1")
|
||||
assert.Equal(t, "0.0.0.1", a.A.String())
|
||||
|
||||
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
|
||||
reply, err = dns.Exchange(req, addr.String())
|
||||
assert.True(t, err == nil)
|
||||
assert.True(t, len(reply.Answer) == 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(reply.Answer))
|
||||
a6, ok := reply.Answer[0].(*dns.AAAA)
|
||||
assert.True(t, ok)
|
||||
assert.True(t, a6.AAAA.String() == "::1")
|
||||
assert.Equal(t, "::1", a6.AAAA.String())
|
||||
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
@@ -598,7 +600,7 @@ func createTestServer(t *testing.T) *Server {
|
||||
127.0.0.1 host.example.org
|
||||
@@||whitelist.example.org^
|
||||
||127.0.0.255`
|
||||
filters := []dnsfilter.Filter{dnsfilter.Filter{
|
||||
filters := []dnsfilter.Filter{{
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
c := dnsfilter.Config{}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
||||
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
||||
ip := ipFromAddr(d.Addr)
|
||||
if s.access.IsBlockedIP(ip) {
|
||||
log.Tracef("Client IP %s is blocked by settings", ip)
|
||||
|
||||
@@ -31,7 +31,7 @@ const (
|
||||
)
|
||||
|
||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
ctx := &dnsContext{srv: s, proxyCtx: d}
|
||||
ctx.result = &dnsfilter.Result{}
|
||||
ctx.startTime = time.Now()
|
||||
@@ -124,12 +124,12 @@ func processUpstream(ctx *dnsContext) int {
|
||||
return resultDone // response is already set - nothing to do
|
||||
}
|
||||
|
||||
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
|
||||
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
||||
clientIP := ipFromAddr(d.Addr)
|
||||
upstreams := s.conf.GetUpstreamsByClient(clientIP)
|
||||
if len(upstreams) > 0 {
|
||||
upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP)
|
||||
if upstreamsConf != nil {
|
||||
log.Debug("Using custom upstreams for %s", clientIP)
|
||||
d.Upstreams = upstreams
|
||||
d.CustomUpstreamConfig = upstreamsConf
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user