Pull request: all: fix client upstreams, imp code
Updates #3186. Squashed commit of the following: commit a8dd0e2cda3039839d069fe71a5bd0f9635ec064 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri May 28 12:54:07 2021 +0300 all: imp code, names commit 98f86c21ae23b665095075feb4a59dcfcc622bc7 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu May 27 21:11:37 2021 +0300 all: fix client upstreams, imp code
This commit is contained in:
@@ -8,7 +8,9 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -27,11 +29,10 @@ type FilteringConfig struct {
|
||||
// FilterHandler is an optional additional filtering callback.
|
||||
FilterHandler func(clientAddr net.IP, clientID string, settings *filtering.Settings) `yaml:"-"`
|
||||
|
||||
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
|
||||
// based on the client IP address. Returns nil if there are no custom upstreams for the client
|
||||
//
|
||||
// TODO(e.burkov): Replace argument type with net.IP.
|
||||
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
|
||||
// GetCustomUpstreamByClient is a callback that returns upstreams
|
||||
// configuration based on the client IP address or ClientID. It returns
|
||||
// nil if there are no custom upstreams for the client.
|
||||
GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"`
|
||||
|
||||
// Protection configuration
|
||||
// --
|
||||
@@ -384,10 +385,51 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// isInSorted returns true if s is in the sorted slice strs.
|
||||
func isInSorted(strs []string, s string) (ok bool) {
|
||||
i := sort.SearchStrings(strs, s)
|
||||
if i == len(strs) || strs[i] != s {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isWildcard returns true if host is a wildcard hostname.
|
||||
func isWildcard(host string) (ok bool) {
|
||||
return len(host) >= 2 && host[0] == '*' && host[1] == '.'
|
||||
}
|
||||
|
||||
// matchesDomainWildcard returns true if host matches the domain wildcard
|
||||
// pattern pat.
|
||||
func matchesDomainWildcard(host, pat string) (ok bool) {
|
||||
return isWildcard(pat) && strings.HasSuffix(host, pat[1:])
|
||||
}
|
||||
|
||||
// anyNameMatches returns true if sni, the client's SNI value, matches any of
|
||||
// the DNS names and patterns from certificate. dnsNames must be sorted.
|
||||
func anyNameMatches(dnsNames []string, sni string) (ok bool) {
|
||||
if aghnet.ValidateDomainName(sni) != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if isInSorted(dnsNames, sni) {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, dn := range dnsNames {
|
||||
if matchesDomainWildcard(sni, dn) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Called by 'tls' package when Client Hello is received
|
||||
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
|
||||
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) {
|
||||
if s.conf.StrictSNICheck && !anyNameMatches(s.conf.dnsNames, ch.ServerName) {
|
||||
log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName)
|
||||
return nil, fmt.Errorf("invalid SNI")
|
||||
}
|
||||
|
||||
53
internal/dnsforward/config_test.go
Normal file
53
internal/dnsforward/config_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAnyNameMatches(t *testing.T) {
|
||||
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
|
||||
sort.Strings(dnsNames)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
dnsName string
|
||||
want bool
|
||||
}{{
|
||||
name: "match",
|
||||
dnsName: "host1",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "a.host2",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "b.a.host2",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "1.2.3.4",
|
||||
want: true,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "host2",
|
||||
want: false,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "",
|
||||
want: false,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "*.host2",
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, anyNameMatches(dnsNames, tc.dnsName))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -229,7 +230,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
rc = resultCodeSuccess
|
||||
|
||||
var ip net.IP
|
||||
if ip = IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
|
||||
if ip = aghnet.IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
|
||||
return rc
|
||||
}
|
||||
|
||||
@@ -489,6 +490,15 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// ipStringFromAddr extracts an IP address string from net.Addr.
|
||||
func ipStringFromAddr(addr net.Addr) (ipStr string) {
|
||||
if ip := aghnet.IPFromAddr(addr); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// processUpstream passes request to upstream servers and handles the response.
|
||||
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
@@ -497,9 +507,13 @@ func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
||||
clientIP := IPStringFromAddr(d.Addr)
|
||||
if upsConf := s.conf.GetCustomUpstreamByClient(clientIP); upsConf != nil {
|
||||
log.Debug("dns: using custom upstreams for client %s", clientIP)
|
||||
// Use the clientID first, since it has a higher priority.
|
||||
id := aghstrings.Coalesce(ctx.clientID, ipStringFromAddr(d.Addr))
|
||||
upsConf, err := s.conf.GetCustomUpstreamByClient(id)
|
||||
if err != nil {
|
||||
log.Error("dns: getting custom upstreams for client %s: %s", id, err)
|
||||
} else if upsConf != nil {
|
||||
log.Debug("dns: using custom upstreams for client %s", id)
|
||||
d.CustomUpstreamConfig = upsConf
|
||||
}
|
||||
}
|
||||
|
||||
@@ -379,3 +379,18 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||
require.Empty(t, proxyCtx.Res.Answer)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIPStringFromAddr(t *testing.T) {
|
||||
t.Run("not_nil", func(t *testing.T) {
|
||||
addr := net.UDPAddr{
|
||||
IP: net.ParseIP("1:2:3::4"),
|
||||
Port: 12345,
|
||||
Zone: "eth0",
|
||||
}
|
||||
assert.Equal(t, ipStringFromAddr(&addr), addr.IP.String())
|
||||
})
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
assert.Empty(t, ipStringFromAddr(nil))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -521,16 +520,16 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
},
|
||||
}
|
||||
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
|
||||
s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig {
|
||||
return &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{
|
||||
&aghtest.TestUpstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"host.": {{192, 168, 0, 1}},
|
||||
},
|
||||
},
|
||||
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
|
||||
ups := &aghtest.TestUpstream{
|
||||
IPv4: map[string][]net.IP{
|
||||
"host.": {{192, 168, 0, 1}},
|
||||
},
|
||||
}
|
||||
|
||||
return &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{ups},
|
||||
}, nil
|
||||
}
|
||||
startDeferStop(t, s)
|
||||
|
||||
@@ -962,65 +961,6 @@ func publicKey(priv interface{}) interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPStringFromAddr(t *testing.T) {
|
||||
t.Run("not_nil", func(t *testing.T) {
|
||||
addr := net.UDPAddr{
|
||||
IP: net.ParseIP("1:2:3::4"),
|
||||
Port: 12345,
|
||||
Zone: "eth0",
|
||||
}
|
||||
assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String())
|
||||
})
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
assert.Empty(t, IPStringFromAddr(nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatchDNSName(t *testing.T) {
|
||||
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
|
||||
sort.Strings(dnsNames)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
dnsName string
|
||||
want bool
|
||||
}{{
|
||||
name: "match",
|
||||
dnsName: "host1",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "a.host2",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "b.a.host2",
|
||||
want: true,
|
||||
}, {
|
||||
name: "match",
|
||||
dnsName: "1.2.3.4",
|
||||
want: true,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "host2",
|
||||
want: false,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "",
|
||||
want: false,
|
||||
}, {
|
||||
name: "mismatch",
|
||||
dnsName: "*.host2",
|
||||
want: false,
|
||||
}}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, matchDNSName(dnsNames, tc.dnsName))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testDHCP struct{}
|
||||
|
||||
func (d *testDHCP) Enabled() (ok bool) { return true }
|
||||
|
||||
@@ -4,15 +4,15 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
||||
ip := IPFromAddr(d.Addr)
|
||||
ip := aghnet.IPFromAddr(d.Addr)
|
||||
disallowed, _ := s.access.IsBlockedIP(ip)
|
||||
if disallowed {
|
||||
log.Tracef("Client IP %s is blocked by settings", ip)
|
||||
@@ -39,7 +39,7 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
|
||||
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
|
||||
setts := s.dnsFilter.GetConfig()
|
||||
if s.conf.FilterHandler != nil {
|
||||
s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
|
||||
s.conf.FilterHandler(aghnet.IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
|
||||
}
|
||||
|
||||
return &setts
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
@@ -37,7 +38,7 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
|
||||
OrigAnswer: ctx.origResp,
|
||||
Result: ctx.result,
|
||||
Elapsed: elapsed,
|
||||
ClientIP: IPFromAddr(pctx.Addr),
|
||||
ClientIP: aghnet.IPFromAddr(pctx.Addr),
|
||||
ClientID: ctx.clientID,
|
||||
}
|
||||
|
||||
@@ -79,7 +80,7 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri
|
||||
|
||||
if clientID := ctx.clientID; clientID != "" {
|
||||
e.Client = clientID
|
||||
} else if ip := IPFromAddr(pctx.Addr); ip != nil {
|
||||
} else if ip := aghnet.IPFromAddr(pctx.Addr); ip != nil {
|
||||
e.Client = ip.String()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
)
|
||||
|
||||
// IPFromAddr gets IP address from addr.
|
||||
func IPFromAddr(addr net.Addr) (ip net.IP) {
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
return addr.IP
|
||||
case *net.TCPAddr:
|
||||
return addr.IP
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IPStringFromAddr extracts IP address from net.Addr.
|
||||
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
|
||||
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
|
||||
func IPStringFromAddr(addr net.Addr) (ipStr string) {
|
||||
if ip := IPFromAddr(addr); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find value in a sorted array
|
||||
func findSorted(ar []string, val string) int {
|
||||
i := sort.SearchStrings(ar, val)
|
||||
if i == len(ar) || ar[i] != val {
|
||||
return -1
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func isWildcard(host string) bool {
|
||||
return len(host) >= 2 &&
|
||||
host[0] == '*' && host[1] == '.'
|
||||
}
|
||||
|
||||
// Return TRUE if host name matches a wildcard pattern
|
||||
func matchDomainWildcard(host, wildcard string) bool {
|
||||
return isWildcard(wildcard) &&
|
||||
strings.HasSuffix(host, wildcard[1:])
|
||||
}
|
||||
|
||||
// Return TRUE if client's SNI value matches DNS names from certificate
|
||||
func matchDNSName(dnsNames []string, sni string) bool {
|
||||
if aghnet.ValidateDomainName(sni) != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if findSorted(dnsNames, sni) != -1 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, dn := range dnsNames {
|
||||
if matchDomainWildcard(sni, dn) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// fakeAddr is a mock implementation of net.Addr interface to simplify testing.
|
||||
type fakeAddr struct {
|
||||
// Addr is embedded here simply to make fakeAddr a net.Addr without
|
||||
// actually implementing all methods.
|
||||
net.Addr
|
||||
}
|
||||
|
||||
func TestIPFromAddr(t *testing.T) {
|
||||
supIPv4 := net.IP{1, 2, 3, 4}
|
||||
supIPv6 := net.ParseIP("2a00:1450:400c:c06::93")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
want net.IP
|
||||
}{{
|
||||
name: "ipv4_tcp",
|
||||
addr: &net.TCPAddr{
|
||||
IP: supIPv4,
|
||||
},
|
||||
want: supIPv4,
|
||||
}, {
|
||||
name: "ipv6_tcp",
|
||||
addr: &net.TCPAddr{
|
||||
IP: supIPv6,
|
||||
},
|
||||
want: supIPv6,
|
||||
}, {
|
||||
name: "ipv4_udp",
|
||||
addr: &net.UDPAddr{
|
||||
IP: supIPv4,
|
||||
},
|
||||
want: supIPv4,
|
||||
}, {
|
||||
name: "ipv6_udp",
|
||||
addr: &net.UDPAddr{
|
||||
IP: supIPv6,
|
||||
},
|
||||
want: supIPv6,
|
||||
}, {
|
||||
name: "non-ip_addr",
|
||||
addr: &fakeAddr{},
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, IPFromAddr(tc.addr))
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user