Pull request 2087: AG-27616-upd-proxy-ratelimit-whitelist

Squashed commit of the following:

commit 099a2eb11609a07a1cb72d9e15da3e668042de1d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Dec 5 17:25:49 2023 +0300

    all: upd proxy

commit db07130df80ed06b867f6ce6878908b1eb93a934
Merge: 9e6e8e7cf 75cb9d412
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Dec 5 14:44:44 2023 +0300

    Merge branch 'master' into AG-27616-upd-proxy-ratelimit-whitelist

commit 9e6e8e7cfc80507cff81761dd3964cf7777ac58b
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Nov 29 19:46:17 2023 +0300

    all: imp tests

commit e753bb53880c2a0791d97079a12960e0b1d667ed
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Nov 29 13:35:21 2023 +0300

    all: upd proxy ratelimit whitelist
This commit is contained in:
Stanislav Chzhen
2023-12-05 17:43:50 +03:00
parent 75cb9d412a
commit 99af7f46de
13 changed files with 82 additions and 135 deletions

View File

@@ -67,7 +67,7 @@ type Config struct {
RatelimitSubnetLenIPv6 int `yaml:"ratelimit_subnet_len_ipv6"`
// RatelimitWhitelist is the list of whitelisted client IP addresses.
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"`
RatelimitWhitelist []netip.Addr `yaml:"ratelimit_whitelist"`
// RefuseAny, if true, refuse ANY requests.
RefuseAny bool `yaml:"refuse_any"`
@@ -298,24 +298,24 @@ type ServerConfig struct {
func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
srvConf := s.conf
conf = &proxy.Config{
HTTP3: srvConf.ServeHTTP3,
Ratelimit: int(srvConf.Ratelimit),
RatelimitSubnetMaskIPv4: net.CIDRMask(srvConf.RatelimitSubnetLenIPv4, netutil.IPv4BitLen),
RatelimitSubnetMaskIPv6: net.CIDRMask(srvConf.RatelimitSubnetLenIPv6, netutil.IPv6BitLen),
RatelimitWhitelist: srvConf.RatelimitWhitelist,
RefuseAny: srvConf.RefuseAny,
TrustedProxies: srvConf.TrustedProxies,
CacheMinTTL: srvConf.CacheMinTTL,
CacheMaxTTL: srvConf.CacheMaxTTL,
CacheOptimistic: srvConf.CacheOptimistic,
UpstreamConfig: srvConf.UpstreamConfig,
BeforeRequestHandler: s.beforeRequestHandler,
RequestHandler: s.handleDNSRequest,
HTTPSServerName: aghhttp.UserAgent(),
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
MaxGoroutines: int(srvConf.MaxGoroutines),
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
HTTP3: srvConf.ServeHTTP3,
Ratelimit: int(srvConf.Ratelimit),
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
RatelimitWhitelist: srvConf.RatelimitWhitelist,
RefuseAny: srvConf.RefuseAny,
TrustedProxies: srvConf.TrustedProxies,
CacheMinTTL: srvConf.CacheMinTTL,
CacheMaxTTL: srvConf.CacheMaxTTL,
CacheOptimistic: srvConf.CacheOptimistic,
UpstreamConfig: srvConf.UpstreamConfig,
BeforeRequestHandler: s.beforeRequestHandler,
RequestHandler: s.handleDNSRequest,
HTTPSServerName: aghhttp.UserAgent(),
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
MaxGoroutines: int(srvConf.MaxGoroutines),
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
}
if srvConf.EDNSClientSubnet.UseCustom {

View File

@@ -299,7 +299,7 @@ func (s *Server) WriteDiskConfig(c *Config) {
sc := s.conf.Config
*c = sc
c.RatelimitWhitelist = stringutil.CloneSlice(sc.RatelimitWhitelist)
c.RatelimitWhitelist = slices.Clone(sc.RatelimitWhitelist)
c.BootstrapDNS = stringutil.CloneSlice(sc.BootstrapDNS)
c.FallbackDNS = stringutil.CloneSlice(sc.FallbackDNS)
c.AllowedClients = stringutil.CloneSlice(sc.AllowedClients)

View File

@@ -54,13 +54,10 @@ const (
testMessagesCount = 10
)
// testClientAddr is the common net.Addr for tests.
// testClientAddrPort is the common net.Addr for tests.
//
// TODO(a.garipov): Use more.
var testClientAddr net.Addr = &net.TCPAddr{
IP: net.IP{1, 2, 3, 4},
Port: 12345,
}
var testClientAddrPort = netip.MustParseAddrPort("1.2.3.4:12345")
func startDeferStop(t *testing.T, s *Server) {
t.Helper()

View File

@@ -10,7 +10,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
@@ -28,8 +27,7 @@ func (s *Server) beforeRequestHandler(
return false, fmt.Errorf("getting clientid: %w", err)
}
addrPort := netutil.NetAddrToAddrPort(pctx.Addr)
blocked, _ := s.IsBlockedClient(addrPort.Addr(), clientID)
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
@@ -60,8 +58,7 @@ func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filter
setts = s.dnsFilter.Settings()
setts.ProtectionEnabled = dctx.protectionEnabled
if s.conf.FilterHandler != nil {
addrPort := netutil.NetAddrToAddrPort(dctx.proxyCtx.Addr)
s.conf.FilterHandler(addrPort.Addr(), dctx.clientID, setts)
s.conf.FilterHandler(dctx.proxyCtx.Addr.Addr(), dctx.clientID, setts)
}
return setts

View File

@@ -187,7 +187,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
dctx := &proxy.DNSContext{
Proto: proxy.ProtoUDP,
Req: tc.req,
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
Addr: testClientAddrPort,
}
t.Run(tc.name, func(t *testing.T) {
@@ -326,7 +326,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
Proto: proxy.ProtoUDP,
Req: tc.req,
Res: resp,
Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1},
Addr: testClientAddrPort,
}
dctx := &dnsContext{

View File

@@ -52,7 +52,7 @@ type jsonDNSConfig struct {
RatelimitSubnetLenIPv6 *int `json:"ratelimit_subnet_len_ipv6"`
// RatelimitWhitelist is a list of IP addresses excluded from rate limiting.
RatelimitWhitelist *[]string `json:"ratelimit_whitelist"`
RatelimitWhitelist *[]netip.Addr `json:"ratelimit_whitelist"`
// BlockingMode defines the way blocked responses are constructed.
BlockingMode *filtering.BlockingMode `json:"blocking_mode"`
@@ -129,7 +129,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
ratelimit := s.conf.Ratelimit
ratelimitSubnetLenIPv4 := s.conf.RatelimitSubnetLenIPv4
ratelimitSubnetLenIPv6 := s.conf.RatelimitSubnetLenIPv6
ratelimitWhitelist := stringutil.CloneSliceOrEmpty(s.conf.RatelimitWhitelist)
ratelimitWhitelist := append([]netip.Addr{}, s.conf.RatelimitWhitelist...)
customIP := s.conf.EDNSClientSubnet.CustomIP
enableEDNSClientSubnet := s.conf.EDNSClientSubnet.Enabled
@@ -291,12 +291,6 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
return err
}
err = req.checkRatelimitWhitelist()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
err = req.checkBlockingMode()
if err != nil {
// Don't wrap the error since it's informative enough as is.
@@ -405,21 +399,6 @@ func checkInclusion(ptr *int, minN, maxN int) (err error) {
return nil
}
// checkRatelimitWhitelist returns an error if any of IP addresses is invalid.
func (req *jsonDNSConfig) checkRatelimitWhitelist() (err error) {
if req.RatelimitWhitelist == nil {
return nil
}
for i, ipStr := range *req.RatelimitWhitelist {
if _, err = netip.ParseAddr(ipStr); err != nil {
return fmt.Errorf("ratelimit whitelist: at index %d: %w", i, err)
}
}
return nil
}
// handleSetConfig handles requests to the POST /control/dns_config endpoint.
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
req := &jsonDNSConfig{}

View File

@@ -195,9 +195,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
name: "ratelimit_subnet_len",
wantSet: "",
}, {
name: "ratelimit_whitelist_not_ip",
wantSet: `validating dns config: ratelimit whitelist: at index 1: ParseAddr("not.ip"): ` +
`unexpected character (at "not.ip")`,
name: "ratelimit_whitelist_not_ip",
wantSet: `decoding request: ParseAddr("not.ip"): unexpected character (at "not.ip")`,
}, {
name: "edns_cs_enabled",
wantSet: "",

View File

@@ -191,7 +191,7 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
defer log.Debug("dnsforward: finished processing initial")
pctx := dctx.proxyCtx
s.processClientIP(pctx.Addr)
s.processClientIP(pctx.Addr.Addr())
q := pctx.Req.Question[0]
qt := q.Qtype
@@ -228,9 +228,8 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
}
// processClientIP sends the client IP address to s.addrProc, if needed.
func (s *Server) processClientIP(addr net.Addr) {
clientIP := netutil.NetAddrToAddrPort(addr).Addr()
if clientIP == (netip.Addr{}) {
func (s *Server) processClientIP(addr netip.Addr) {
if !addr.IsValid() {
log.Info("dnsforward: warning: bad client addr %q", addr)
return
@@ -241,7 +240,7 @@ func (s *Server) processClientIP(addr net.Addr) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
s.addrProc.Process(clientIP)
s.addrProc.Process(addr)
}
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB
@@ -351,12 +350,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
rc = resultCodeSuccess
var ip net.IP
if ip, _ = netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr); ip == nil {
return rc
}
dctx.isLocalClient = s.privateNets.Contains(ip)
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr().AsSlice())
return rc
}
@@ -831,12 +825,12 @@ func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) {
// setCustomUpstream sets custom upstream settings in pctx, if necessary.
func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
if pctx.Addr == nil || s.conf.ClientsContainer == nil {
if !pctx.Addr.IsValid() || s.conf.ClientsContainer == nil {
return
}
// Use the ClientID first, since it has a higher priority.
id := stringutil.Coalesce(clientID, ipStringFromAddr(pctx.Addr))
id := stringutil.Coalesce(clientID, pctx.Addr.Addr().String())
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
if err != nil {
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)

View File

@@ -97,14 +97,14 @@ func TestServer_ProcessInitial(t *testing.T) {
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: createTestMessageWithType(tc.target, tc.qType),
Addr: testClientAddr,
Addr: testClientAddrPort,
RequestID: 1234,
},
}
gotRC := s.processInitial(dctx)
assert.Equal(t, tc.wantRC, gotRC)
assert.Equal(t, netutil.NetAddrToAddrPort(testClientAddr).Addr(), gotAddr)
assert.Equal(t, testClientAddrPort.Addr(), gotAddr)
if tc.wantRCode > 0 {
gotResp := dctx.proxyCtx.Res
@@ -201,7 +201,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
Proto: proxy.ProtoUDP,
Req: tc.req,
Res: resp,
Addr: testClientAddr,
Addr: testClientAddrPort,
},
}
@@ -397,33 +397,27 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
}
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliIP net.IP
want assert.BoolAssertionFunc
name string
cliAddr netip.AddrPort
}{{
want: assert.True,
name: "local",
cliIP: net.IP{192, 168, 0, 1},
want: assert.True,
name: "local",
cliAddr: netip.MustParseAddrPort("192.168.0.1:1"),
}, {
want: assert.False,
name: "external",
cliIP: net.IP{250, 249, 0, 1},
want: assert.False,
name: "external",
cliAddr: netip.MustParseAddrPort("250.249.0.1:1"),
}, {
want: assert.False,
name: "invalid",
cliIP: net.IP{1, 2, 3, 4, 5},
}, {
want: assert.False,
name: "nil",
cliIP: nil,
want: assert.False,
name: "invalid",
cliAddr: netip.AddrPort{},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
proxyCtx := &proxy.DNSContext{
Addr: &net.TCPAddr{
IP: tc.cliIP,
},
Addr: tc.cliAddr,
}
dctx := &dnsContext{
proxyCtx: proxyCtx,
@@ -711,31 +705,31 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
name string
want string
question net.IP
cliIP net.IP
cliAddr netip.AddrPort
wantLen int
}{{
name: "from_local_to_external",
want: "host1.example.net.",
question: net.IP{254, 253, 252, 251},
cliIP: net.IP{192, 168, 10, 10},
cliAddr: netip.MustParseAddrPort("192.168.10.10:1"),
wantLen: 1,
}, {
name: "from_external_for_local",
want: "",
question: net.IP{192, 168, 1, 1},
cliIP: net.IP{254, 253, 252, 251},
cliAddr: netip.MustParseAddrPort("254.253.252.251:1"),
wantLen: 0,
}, {
name: "from_local_for_local",
want: "some.local-client.",
question: net.IP{192, 168, 1, 1},
cliIP: net.IP{192, 168, 1, 2},
cliAddr: netip.MustParseAddrPort("192.168.1.2:1"),
wantLen: 1,
}, {
name: "from_external_for_external",
want: "host1.example.net.",
question: net.IP{254, 253, 252, 251},
cliIP: net.IP{254, 253, 252, 255},
cliAddr: netip.MustParseAddrPort("254.253.252.255:1"),
wantLen: 1,
}}
@@ -747,9 +741,7 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
pctx := &proxy.DNSContext{
Proto: proxy.ProtoTCP,
Req: req,
Addr: &net.TCPAddr{
IP: tc.cliIP,
},
Addr: tc.cliAddr,
}
t.Run(tc.name, func(t *testing.T) {
@@ -794,7 +786,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
var dnsCtx *dnsContext
setup := func(use bool) {
proxyCtx = &proxy.DNSContext{
Addr: testClientAddr,
Addr: testClientAddrPort,
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
}
dnsCtx = &dnsContext{

View File

@@ -10,9 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// Write Stats data and logs
@@ -25,13 +23,12 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
host := aghnet.NormalizeDomain(q.Name)
processingTime := time.Since(dctx.startTime)
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
ip = slices.Clone(ip)
ip := pctx.Addr.Addr().AsSlice()
s.anonymizer.Load()(ip)
log.Debug("dnsforward: client ip for stats and querylog: %s", ip)
ipStr := ip.String()
ipStr := pctx.Addr.Addr().String()
ids := []string{ipStr, dctx.clientID}
qt, cl := q.Qtype, q.Qclass

View File

@@ -1,7 +1,7 @@
package dnsforward
import (
"net"
"net/netip"
"testing"
"time"
@@ -65,7 +65,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name string
domain string
proto proxy.Proto
addr net.Addr
addr netip.AddrPort
clientID string
wantLogProto querylog.ClientProto
wantStatClient string
@@ -76,7 +76,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
@@ -87,7 +87,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_tls_clientid",
domain: domain,
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "cli42",
wantLogProto: querylog.ClientProtoDoT,
wantStatClient: "cli42",
@@ -98,7 +98,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_tls",
domain: domain,
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: querylog.ClientProtoDoT,
wantStatClient: "1.2.3.4",
@@ -109,7 +109,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_quic",
domain: domain,
proto: proxy.ProtoQUIC,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: querylog.ClientProtoDoQ,
wantStatClient: "1.2.3.4",
@@ -120,7 +120,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_https",
domain: domain,
proto: proxy.ProtoHTTPS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: querylog.ClientProtoDoH,
wantStatClient: "1.2.3.4",
@@ -131,7 +131,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_dnscrypt",
domain: domain,
proto: proxy.ProtoDNSCrypt,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: querylog.ClientProtoDNSCrypt,
wantStatClient: "1.2.3.4",
@@ -142,7 +142,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp_filtered",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
@@ -153,7 +153,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp_sb",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
@@ -164,7 +164,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp_ss",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
@@ -175,7 +175,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp_pc",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
addr: testClientAddrPort,
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
@@ -186,10 +186,10 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
name: "success_udp_pc_empty_fqdn",
domain: ".",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 5}, Port: 1234},
addr: netip.MustParseAddrPort("4.3.2.1:1234"),
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.5",
wantStatClient: "4.3.2.1",
wantCode: resultCodeSuccess,
reason: filtering.FilteredParental,
wantStatResult: stats.RParental,