Pull request: 3142 custom private subnets
Merge in DNS/adguard-home from 3142-custom-subnets to master Updates #3142. Squashed commit of the following: commit 11469ade75b9dc32ee6d93e3aa35cf79dbaa28b2 Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Thu Mar 17 19:56:02 2022 +0300 all: upd golibs, use subnet set
This commit is contained in:
@@ -252,7 +252,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
return rc
|
||||
}
|
||||
|
||||
dctx.isLocalClient = s.subnetDetector.IsLocallyServedNetwork(ip)
|
||||
dctx.isLocalClient = s.privateNets.Contains(ip)
|
||||
|
||||
return rc
|
||||
}
|
||||
@@ -374,7 +374,7 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
|
||||
// Restrict an access to local addresses for external clients. We also
|
||||
// assume that all the DHCP leases we give are locally-served or at least
|
||||
// don't need to be inaccessible externally.
|
||||
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
if !s.privateNets.Contains(ip) {
|
||||
log.Debug("dns: addr %s is not from locally-served network", ip)
|
||||
|
||||
return resultCodeSuccess
|
||||
@@ -481,7 +481,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if !s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
if !s.privateNets.Contains(ip) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
|
||||
@@ -4,35 +4,41 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
s := &Server{
|
||||
subnetDetector: snd,
|
||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
want assert.BoolAssertionFunc
|
||||
name string
|
||||
cliIP net.IP
|
||||
want bool
|
||||
}{{
|
||||
want: assert.True,
|
||||
name: "local",
|
||||
cliIP: net.IP{192, 168, 0, 1},
|
||||
want: true,
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "external",
|
||||
cliIP: net.IP{250, 249, 0, 1},
|
||||
want: false,
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "invalid",
|
||||
cliIP: net.IP{1, 2, 3, 4, 5},
|
||||
}, {
|
||||
want: assert.False,
|
||||
name: "nil",
|
||||
cliIP: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -47,7 +53,7 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
}
|
||||
s.processDetermineLocal(dctx)
|
||||
|
||||
assert.Equal(t, tc.want, dctx.isLocalClient)
|
||||
tc.want(t, dctx.isLocalClient)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ type Server struct {
|
||||
localDomainSuffix string
|
||||
|
||||
ipset ipsetCtx
|
||||
subnetDetector *aghnet.SubnetDetector
|
||||
privateNets netutil.SubnetSet
|
||||
localResolvers *proxy.Proxy
|
||||
sysResolvers aghnet.SystemResolvers
|
||||
recDetector *recursionDetector
|
||||
@@ -111,13 +111,13 @@ const defaultLocalDomainSuffix = ".lan."
|
||||
|
||||
// DNSCreateParams are parameters to create a new server.
|
||||
type DNSCreateParams struct {
|
||||
DNSFilter *filtering.DNSFilter
|
||||
Stats stats.Stats
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.ServerInterface
|
||||
SubnetDetector *aghnet.SubnetDetector
|
||||
Anonymizer *aghnet.IPMut
|
||||
LocalDomain string
|
||||
DNSFilter *filtering.DNSFilter
|
||||
Stats stats.Stats
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.ServerInterface
|
||||
PrivateNets netutil.SubnetSet
|
||||
Anonymizer *aghnet.IPMut
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
// domainNameToSuffix converts a domain name into a local domain suffix.
|
||||
@@ -161,7 +161,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
dnsFilter: p.DNSFilter,
|
||||
stats: p.Stats,
|
||||
queryLog: p.QueryLog,
|
||||
subnetDetector: p.SubnetDetector,
|
||||
privateNets: p.PrivateNets,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||
clientIDCache: cache.New(cache.Config{
|
||||
@@ -315,7 +315,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
|
||||
}
|
||||
|
||||
resolver := s.internalProxy
|
||||
if s.subnetDetector.IsLocallyServedNetwork(ip) {
|
||||
if s.privateNets.Contains(ip) {
|
||||
if !s.conf.UsePrivateRDNS {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
@@ -69,14 +70,11 @@ func createTestServer(
|
||||
f := filtering.New(filterConf, filters)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var err error
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -770,16 +768,11 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
Data: []byte(rules),
|
||||
}}
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -913,15 +906,10 @@ func TestRewrite(t *testing.T) {
|
||||
f := filtering.New(c, nil)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1028,15 +1016,10 @@ func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
||||
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
|
||||
|
||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
||||
DHCPServer: &testDHCP{},
|
||||
SubnetDetector: snd,
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
||||
DHCPServer: &testDHCP{},
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1105,16 +1088,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
}, nil)
|
||||
flt.SetEnabled(true)
|
||||
|
||||
var snd *aghnet.SubnetDetector
|
||||
snd, err = aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: flt,
|
||||
SubnetDetector: snd,
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: flt,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1227,9 +1205,7 @@ func TestServer_Exchange(t *testing.T) {
|
||||
srv.conf.ResolveClients = true
|
||||
srv.conf.UsePrivateRDNS = true
|
||||
|
||||
var err error
|
||||
srv.subnetDetector, err = aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
localIP := net.IP{192, 168, 1, 1}
|
||||
testCases := []struct {
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -39,14 +38,10 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
f.SetEnabled(true)
|
||||
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, snd)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
SubnetDetector: snd,
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -167,7 +166,7 @@ func (req *dnsConfig) checkBootstrap() (err error) {
|
||||
}
|
||||
|
||||
// validate returns an error if any field of req is invalid.
|
||||
func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
|
||||
func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||
if req.Upstreams != nil {
|
||||
err = ValidateUpstreams(*req.Upstreams)
|
||||
if err != nil {
|
||||
@@ -176,7 +175,7 @@ func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) {
|
||||
}
|
||||
|
||||
if req.LocalPTRUpstreams != nil {
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, snd)
|
||||
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating private upstream servers: %w", err)
|
||||
}
|
||||
@@ -224,7 +223,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = req.validate(s.subnetDetector)
|
||||
err = req.validate(s.privateNets)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -350,17 +349,6 @@ func IsCommentOrEmpty(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#'
|
||||
}
|
||||
|
||||
// LocalNetChecker is used to check if the IP address belongs to a local
|
||||
// network.
|
||||
type LocalNetChecker interface {
|
||||
// IsLocallyServedNetwork returns true if ip is contained in any of address
|
||||
// registries defined by RFC 6303.
|
||||
IsLocallyServedNetwork(ip net.IP) (ok bool)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ LocalNetChecker = (*aghnet.SubnetDetector)(nil)
|
||||
|
||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
||||
// configuration or nil if it can't be built.
|
||||
//
|
||||
@@ -422,8 +410,8 @@ func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
// a locally-served network. lnc must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err error) {
|
||||
// a locally-served network. privateNets must not be nil.
|
||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||
conf, err := newUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -444,7 +432,7 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro
|
||||
continue
|
||||
}
|
||||
|
||||
if !lnc.IsLocallyServedNetwork(subnet.IP) {
|
||||
if !privateNets.Contains(subnet.IP) {
|
||||
errs = append(
|
||||
errs,
|
||||
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -410,8 +411,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
snd, err := aghnet.NewSubnetDetector()
|
||||
require.NoError(t, err)
|
||||
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -452,7 +452,7 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
set := []string{"192.168.0.1", tc.u}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = ValidateUpstreamsPrivate(set, snd)
|
||||
err := ValidateUpstreamsPrivate(set, ss)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user