Pull request: 3142 custom private subnets

Merge in DNS/adguard-home from 3142-custom-subnets to master

Updates #3142.

Squashed commit of the following:

commit 11469ade75b9dc32ee6d93e3aa35cf79dbaa28b2
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Mar 17 19:56:02 2022 +0300

    all: upd golibs, use subnet set
This commit is contained in:
Eugene Burkov
2022-03-18 13:37:27 +03:00
parent 573cbafe3f
commit 778585865e
15 changed files with 105 additions and 520 deletions

View File

@@ -252,7 +252,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
return rc
}
dctx.isLocalClient = s.subnetDetector.IsLocallyServedNetwork(ip)
dctx.isLocalClient = s.privateNets.Contains(ip)
return rc
}
@@ -374,7 +374,7 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
// Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally-served or at least
// don't need to be inaccessible externally.
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
if !s.privateNets.Contains(ip) {
log.Debug("dns: addr %s is not from locally-served network", ip)
return resultCodeSuccess
@@ -481,7 +481,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
if !s.privateNets.Contains(ip) {
return resultCodeSuccess
}

View File

@@ -4,35 +4,41 @@ import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServer_ProcessDetermineLocal(t *testing.T) {
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
s := &Server{
subnetDetector: snd,
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliIP net.IP
want bool
}{{
want: assert.True,
name: "local",
cliIP: net.IP{192, 168, 0, 1},
want: true,
}, {
want: assert.False,
name: "external",
cliIP: net.IP{250, 249, 0, 1},
want: false,
}, {
want: assert.False,
name: "invalid",
cliIP: net.IP{1, 2, 3, 4, 5},
}, {
want: assert.False,
name: "nil",
cliIP: nil,
}}
for _, tc := range testCases {
@@ -47,7 +53,7 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
}
s.processDetermineLocal(dctx)
assert.Equal(t, tc.want, dctx.isLocalClient)
tc.want(t, dctx.isLocalClient)
})
}
}

View File

@@ -74,7 +74,7 @@ type Server struct {
localDomainSuffix string
ipset ipsetCtx
subnetDetector *aghnet.SubnetDetector
privateNets netutil.SubnetSet
localResolvers *proxy.Proxy
sysResolvers aghnet.SystemResolvers
recDetector *recursionDetector
@@ -111,13 +111,13 @@ const defaultLocalDomainSuffix = ".lan."
// DNSCreateParams are parameters to create a new server.
type DNSCreateParams struct {
DNSFilter *filtering.DNSFilter
Stats stats.Stats
QueryLog querylog.QueryLog
DHCPServer dhcpd.ServerInterface
SubnetDetector *aghnet.SubnetDetector
Anonymizer *aghnet.IPMut
LocalDomain string
DNSFilter *filtering.DNSFilter
Stats stats.Stats
QueryLog querylog.QueryLog
DHCPServer dhcpd.ServerInterface
PrivateNets netutil.SubnetSet
Anonymizer *aghnet.IPMut
LocalDomain string
}
// domainNameToSuffix converts a domain name into a local domain suffix.
@@ -161,7 +161,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
dnsFilter: p.DNSFilter,
stats: p.Stats,
queryLog: p.QueryLog,
subnetDetector: p.SubnetDetector,
privateNets: p.PrivateNets,
localDomainSuffix: localDomainSuffix,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{
@@ -315,7 +315,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
}
resolver := s.internalProxy
if s.subnetDetector.IsLocallyServedNetwork(ip) {
if s.privateNets.Contains(ip) {
if !s.conf.UsePrivateRDNS {
return "", nil
}

View File

@@ -24,6 +24,7 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns"
@@ -69,14 +70,11 @@ func createTestServer(
f := filtering.New(filterConf, filters)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
var err error
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -770,16 +768,11 @@ func TestBlockedCustomIP(t *testing.T) {
Data: []byte(rules),
}}
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
f := filtering.New(&filtering.Config{}, filters)
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -913,15 +906,10 @@ func TestRewrite(t *testing.T) {
f := filtering.New(c, nil)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -1028,15 +1016,10 @@ func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
func TestPTRResponseFromDHCPLeases(t *testing.T) {
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
var s *Server
s, err = NewServer(DNSCreateParams{
DNSFilter: filtering.New(&filtering.Config{}, nil),
DHCPServer: &testDHCP{},
SubnetDetector: snd,
s, err := NewServer(DNSCreateParams{
DNSFilter: filtering.New(&filtering.Config{}, nil),
DHCPServer: &testDHCP{},
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -1105,16 +1088,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
}, nil)
flt.SetEnabled(true)
var snd *aghnet.SubnetDetector
snd, err = aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: flt,
SubnetDetector: snd,
DHCPServer: &testDHCP{},
DNSFilter: flt,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)
@@ -1227,9 +1205,7 @@ func TestServer_Exchange(t *testing.T) {
srv.conf.ResolveClients = true
srv.conf.UsePrivateRDNS = true
var err error
srv.subnetDetector, err = aghnet.NewSubnetDetector()
require.NoError(t, err)
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
localIP := net.IP{192, 168, 1, 1}
testCases := []struct {

View File

@@ -4,7 +4,6 @@ import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
@@ -39,14 +38,10 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
f := filtering.New(&filtering.Config{}, filters)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
require.NotNil(t, snd)
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DNSFilter: f,
SubnetDetector: snd,
DHCPServer: &testDHCP{},
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
require.NoError(t, err)

View File

@@ -10,7 +10,6 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
@@ -167,7 +166,7 @@ func (req *dnsConfig) checkBootstrap() (err error) {
}
// validate returns an error if any field of req is invalid.
func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) {
if req.Upstreams != nil {
err = ValidateUpstreams(*req.Upstreams)
if err != nil {
@@ -176,7 +175,7 @@ func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
}
if req.LocalPTRUpstreams != nil {
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, snd)
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
if err != nil {
return fmt.Errorf("validating private upstream servers: %w", err)
}
@@ -224,7 +223,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
return
}
err = req.validate(s.subnetDetector)
err = req.validate(s.privateNets)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -350,17 +349,6 @@ func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}
// LocalNetChecker is used to check if the IP address belongs to a local
// network.
type LocalNetChecker interface {
// IsLocallyServedNetwork returns true if ip is contained in any of address
// registries defined by RFC 6303.
IsLocallyServedNetwork(ip net.IP) (ok bool)
}
// type check
var _ LocalNetChecker = (*aghnet.SubnetDetector)(nil)
// newUpstreamConfig validates upstreams and returns an appropriate upstream
// configuration or nil if it can't be built.
//
@@ -422,8 +410,8 @@ func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. lnc must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err error) {
// a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := newUpstreamConfig(upstreams)
if err != nil {
return err
@@ -444,7 +432,7 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
continue
}
if !lnc.IsLocallyServedNetwork(subnet.IP) {
if !privateNets.Contains(subnet.IP) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),

View File

@@ -14,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -410,8 +411,7 @@ func TestValidateUpstreams(t *testing.T) {
}
func TestValidateUpstreamsPrivate(t *testing.T) {
snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err)
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
testCases := []struct {
name string
@@ -452,7 +452,7 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
set := []string{"192.168.0.1", tc.u}
t.Run(tc.name, func(t *testing.T) {
err = ValidateUpstreamsPrivate(set, snd)
err := ValidateUpstreamsPrivate(set, ss)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}