Pull request: all: fix client upstreams, imp code

Updates #3186.

Squashed commit of the following:

commit a8dd0e2cda3039839d069fe71a5bd0f9635ec064
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri May 28 12:54:07 2021 +0300

    all: imp code, names

commit 98f86c21ae23b665095075feb4a59dcfcc622bc7
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu May 27 21:11:37 2021 +0300

    all: fix client upstreams, imp code
This commit is contained in:
Ainar Garipov
2021-05-28 13:02:59 +03:00
parent 48b8579703
commit 3be783bd34
18 changed files with 249 additions and 270 deletions

View File

@@ -8,7 +8,9 @@ import (
"net/http"
"os"
"sort"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
@@ -27,11 +29,10 @@ type FilteringConfig struct {
// FilterHandler is an optional additional filtering callback.
FilterHandler func(clientAddr net.IP, clientID string, settings *filtering.Settings) `yaml:"-"`
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
// based on the client IP address. Returns nil if there are no custom upstreams for the client
//
// TODO(e.burkov): Replace argument type with net.IP.
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
// GetCustomUpstreamByClient is a callback that returns upstreams
// configuration based on the client IP address or ClientID. It returns
// nil if there are no custom upstreams for the client.
GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"`
// Protection configuration
// --
@@ -384,10 +385,51 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
return nil
}
// isInSorted returns true if s is in the sorted slice strs.
func isInSorted(strs []string, s string) (ok bool) {
i := sort.SearchStrings(strs, s)
if i == len(strs) || strs[i] != s {
return false
}
return true
}
// isWildcard returns true if host is a wildcard hostname.
func isWildcard(host string) (ok bool) {
return len(host) >= 2 && host[0] == '*' && host[1] == '.'
}
// matchesDomainWildcard returns true if host matches the domain wildcard
// pattern pat.
func matchesDomainWildcard(host, pat string) (ok bool) {
return isWildcard(pat) && strings.HasSuffix(host, pat[1:])
}
// anyNameMatches returns true if sni, the client's SNI value, matches any of
// the DNS names and patterns from certificate. dnsNames must be sorted.
func anyNameMatches(dnsNames []string, sni string) (ok bool) {
if aghnet.ValidateDomainName(sni) != nil {
return false
}
if isInSorted(dnsNames, sni) {
return true
}
for _, dn := range dnsNames {
if matchesDomainWildcard(sni, dn) {
return true
}
}
return false
}
// Called by 'tls' package when Client Hello is received
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) {
if s.conf.StrictSNICheck && !anyNameMatches(s.conf.dnsNames, ch.ServerName) {
log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName)
return nil, fmt.Errorf("invalid SNI")
}

View File

@@ -0,0 +1,53 @@
package dnsforward
import (
"sort"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAnyNameMatches(t *testing.T) {
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
sort.Strings(dnsNames)
testCases := []struct {
name string
dnsName string
want bool
}{{
name: "match",
dnsName: "host1",
want: true,
}, {
name: "match",
dnsName: "a.host2",
want: true,
}, {
name: "match",
dnsName: "b.a.host2",
want: true,
}, {
name: "match",
dnsName: "1.2.3.4",
want: true,
}, {
name: "mismatch",
dnsName: "host2",
want: false,
}, {
name: "mismatch",
dnsName: "",
want: false,
}, {
name: "mismatch",
dnsName: "*.host2",
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, anyNameMatches(dnsNames, tc.dnsName))
})
}
}

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
@@ -229,7 +230,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
rc = resultCodeSuccess
var ip net.IP
if ip = IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
if ip = aghnet.IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
return rc
}
@@ -489,6 +490,15 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
// ipStringFromAddr extracts an IP address string from net.Addr.
func ipStringFromAddr(addr net.Addr) (ipStr string) {
if ip := aghnet.IPFromAddr(addr); ip != nil {
return ip.String()
}
return ""
}
// processUpstream passes request to upstream servers and handles the response.
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
@@ -497,9 +507,13 @@ func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
}
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
clientIP := IPStringFromAddr(d.Addr)
if upsConf := s.conf.GetCustomUpstreamByClient(clientIP); upsConf != nil {
log.Debug("dns: using custom upstreams for client %s", clientIP)
// Use the clientID first, since it has a higher priority.
id := aghstrings.Coalesce(ctx.clientID, ipStringFromAddr(d.Addr))
upsConf, err := s.conf.GetCustomUpstreamByClient(id)
if err != nil {
log.Error("dns: getting custom upstreams for client %s: %s", id, err)
} else if upsConf != nil {
log.Debug("dns: using custom upstreams for client %s", id)
d.CustomUpstreamConfig = upsConf
}
}

View File

@@ -379,3 +379,18 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
require.Empty(t, proxyCtx.Res.Answer)
})
}
func TestIPStringFromAddr(t *testing.T) {
t.Run("not_nil", func(t *testing.T) {
addr := net.UDPAddr{
IP: net.ParseIP("1:2:3::4"),
Port: 12345,
Zone: "eth0",
}
assert.Equal(t, ipStringFromAddr(&addr), addr.IP.String())
})
t.Run("nil", func(t *testing.T) {
assert.Empty(t, ipStringFromAddr(nil))
})
}

View File

@@ -12,7 +12,6 @@ import (
"math/big"
"net"
"os"
"sort"
"sync"
"testing"
"time"
@@ -521,16 +520,16 @@ func TestServerCustomClientUpstream(t *testing.T) {
},
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig {
return &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{
&aghtest.TestUpstream{
IPv4: map[string][]net.IP{
"host.": {{192, 168, 0, 1}},
},
},
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
ups := &aghtest.TestUpstream{
IPv4: map[string][]net.IP{
"host.": {{192, 168, 0, 1}},
},
}
return &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
}, nil
}
startDeferStop(t, s)
@@ -962,65 +961,6 @@ func publicKey(priv interface{}) interface{} {
}
}
func TestIPStringFromAddr(t *testing.T) {
t.Run("not_nil", func(t *testing.T) {
addr := net.UDPAddr{
IP: net.ParseIP("1:2:3::4"),
Port: 12345,
Zone: "eth0",
}
assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String())
})
t.Run("nil", func(t *testing.T) {
assert.Empty(t, IPStringFromAddr(nil))
})
}
func TestMatchDNSName(t *testing.T) {
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
sort.Strings(dnsNames)
testCases := []struct {
name string
dnsName string
want bool
}{{
name: "match",
dnsName: "host1",
want: true,
}, {
name: "match",
dnsName: "a.host2",
want: true,
}, {
name: "match",
dnsName: "b.a.host2",
want: true,
}, {
name: "match",
dnsName: "1.2.3.4",
want: true,
}, {
name: "mismatch",
dnsName: "host2",
want: false,
}, {
name: "mismatch",
dnsName: "",
want: false,
}, {
name: "mismatch",
dnsName: "*.host2",
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, matchDNSName(dnsNames, tc.dnsName))
})
}
}
type testDHCP struct{}
func (d *testDHCP) Enabled() (ok bool) { return true }

View File

@@ -4,15 +4,15 @@ import (
"fmt"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := IPFromAddr(d.Addr)
ip := aghnet.IPFromAddr(d.Addr)
disallowed, _ := s.access.IsBlockedIP(ip)
if disallowed {
log.Tracef("Client IP %s is blocked by settings", ip)
@@ -39,7 +39,7 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
setts := s.dnsFilter.GetConfig()
if s.conf.FilterHandler != nil {
s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
s.conf.FilterHandler(aghnet.IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
}
return &setts

View File

@@ -4,6 +4,7 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
@@ -37,7 +38,7 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: IPFromAddr(pctx.Addr),
ClientIP: aghnet.IPFromAddr(pctx.Addr),
ClientID: ctx.clientID,
}
@@ -79,7 +80,7 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri
if clientID := ctx.clientID; clientID != "" {
e.Client = clientID
} else if ip := IPFromAddr(pctx.Addr); ip != nil {
} else if ip := aghnet.IPFromAddr(pctx.Addr); ip != nil {
e.Client = ip.String()
}

View File

@@ -1,69 +0,0 @@
package dnsforward
import (
"net"
"sort"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
)
// IPFromAddr gets IP address from addr.
func IPFromAddr(addr net.Addr) (ip net.IP) {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP
case *net.TCPAddr:
return addr.IP
}
return nil
}
// IPStringFromAddr extracts IP address from net.Addr.
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func IPStringFromAddr(addr net.Addr) (ipStr string) {
if ip := IPFromAddr(addr); ip != nil {
return ip.String()
}
return ""
}
// Find value in a sorted array
func findSorted(ar []string, val string) int {
i := sort.SearchStrings(ar, val)
if i == len(ar) || ar[i] != val {
return -1
}
return i
}
func isWildcard(host string) bool {
return len(host) >= 2 &&
host[0] == '*' && host[1] == '.'
}
// Return TRUE if host name matches a wildcard pattern
func matchDomainWildcard(host, wildcard string) bool {
return isWildcard(wildcard) &&
strings.HasSuffix(host, wildcard[1:])
}
// Return TRUE if client's SNI value matches DNS names from certificate
func matchDNSName(dnsNames []string, sni string) bool {
if aghnet.ValidateDomainName(sni) != nil {
return false
}
if findSorted(dnsNames, sni) != -1 {
return true
}
for _, dn := range dnsNames {
if matchDomainWildcard(sni, dn) {
return true
}
}
return false
}

View File

@@ -1,60 +0,0 @@
package dnsforward
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
// fakeAddr is a mock implementation of net.Addr interface to simplify testing.
type fakeAddr struct {
// Addr is embedded here simply to make fakeAddr a net.Addr without
// actually implementing all methods.
net.Addr
}
func TestIPFromAddr(t *testing.T) {
supIPv4 := net.IP{1, 2, 3, 4}
supIPv6 := net.ParseIP("2a00:1450:400c:c06::93")
testCases := []struct {
name string
addr net.Addr
want net.IP
}{{
name: "ipv4_tcp",
addr: &net.TCPAddr{
IP: supIPv4,
},
want: supIPv4,
}, {
name: "ipv6_tcp",
addr: &net.TCPAddr{
IP: supIPv6,
},
want: supIPv6,
}, {
name: "ipv4_udp",
addr: &net.UDPAddr{
IP: supIPv4,
},
want: supIPv4,
}, {
name: "ipv6_udp",
addr: &net.UDPAddr{
IP: supIPv6,
},
want: supIPv6,
}, {
name: "non-ip_addr",
addr: &fakeAddr{},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, IPFromAddr(tc.addr))
})
}
}