dnsforward -- implement ratelimit and refuseany
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/joomcode/errorx"
|
||||
"github.com/miekg/dns"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
// Server is the main way to start a DNS server.
|
||||
@@ -31,6 +32,8 @@ type Server struct {
|
||||
|
||||
cache cache
|
||||
|
||||
ratelimitBuckets *gocache.Cache // where the ratelimiters are stored, per IP
|
||||
|
||||
sync.RWMutex
|
||||
ServerConfig
|
||||
}
|
||||
@@ -76,9 +79,13 @@ func (s *Server) RUnlock() {
|
||||
*/
|
||||
|
||||
type FilteringConfig struct {
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"`
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"`
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
|
||||
QueryLogEnabled bool `yaml:"querylog_enabled"`
|
||||
Ratelimit int `yaml:"ratelimit"`
|
||||
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"`
|
||||
RefuseAny bool `yaml:"refuse_any"`
|
||||
|
||||
dnsfilter.Config `yaml:",inline"`
|
||||
}
|
||||
@@ -92,6 +99,7 @@ type ServerConfig struct {
|
||||
FilteringConfig
|
||||
}
|
||||
|
||||
// if any of ServerConfig values are zero, then default values from below are used
|
||||
var defaultValues = ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{Port: 53},
|
||||
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
|
||||
@@ -413,6 +421,10 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP
|
||||
return s.genServerFailure(msg), nil, nil, nil
|
||||
}
|
||||
|
||||
if msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny {
|
||||
return s.genNotImpl(msg), nil, nil, nil
|
||||
}
|
||||
|
||||
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
|
||||
host := strings.TrimSuffix(msg.Question[0].Name, ".")
|
||||
res, err := s.dnsFilter.CheckHost(host)
|
||||
@@ -450,16 +462,36 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP
|
||||
|
||||
func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
|
||||
start := time.Now()
|
||||
ip, _, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
log.Printf("Failed to split %v into host/port: %s", addr, err)
|
||||
// not a fatal error, move on
|
||||
}
|
||||
|
||||
// ratelimit based on IP only, protects CPU cycles and outbound connections
|
||||
if s.isRatelimited(ip) {
|
||||
// log.Printf("Ratelimiting %s based on IP only", ip)
|
||||
return // do nothing, don't reply, we got ratelimited
|
||||
}
|
||||
|
||||
msg := &dns.Msg{}
|
||||
err := msg.Unpack(p)
|
||||
err = msg.Unpack(p)
|
||||
if err != nil {
|
||||
log.Printf("got invalid DNS packet: %s", err)
|
||||
return // do nothing
|
||||
}
|
||||
|
||||
reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn)
|
||||
|
||||
if reply != nil {
|
||||
// ratelimit based on reply size now
|
||||
replysize := reply.Len()
|
||||
if s.isRatelimitedForReply(ip, replysize) {
|
||||
log.Printf("Ratelimiting %s based on IP and size %d", ip, replysize)
|
||||
return // do nothing, don't reply, we got ratelimited
|
||||
}
|
||||
|
||||
// we're good to respond
|
||||
rerr := s.respond(reply, addr, conn)
|
||||
if rerr != nil {
|
||||
log.Printf("Couldn't respond to UDP packet: %s", err)
|
||||
@@ -467,16 +499,14 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
|
||||
}
|
||||
|
||||
// query logging and stats counters
|
||||
elapsed := time.Since(start)
|
||||
upstreamAddr := ""
|
||||
if upstream != nil {
|
||||
upstreamAddr = upstream.Address()
|
||||
if s.QueryLogEnabled {
|
||||
elapsed := time.Since(start)
|
||||
upstreamAddr := ""
|
||||
if upstream != nil {
|
||||
upstreamAddr = upstream.Address()
|
||||
}
|
||||
logRequest(msg, reply, result, elapsed, ip, upstreamAddr)
|
||||
}
|
||||
host, _, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
log.Printf("Failed to split %v into host/port: %s", addr, err)
|
||||
}
|
||||
logRequest(msg, reply, result, elapsed, host, upstreamAddr)
|
||||
}
|
||||
|
||||
//
|
||||
@@ -506,12 +536,22 @@ func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error
|
||||
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeServerFailure)
|
||||
resp.RecursionAvailable = true
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genNotImpl(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeNotImplemented)
|
||||
resp.RecursionAvailable = true
|
||||
resp.SetEdns0(1452, false) // NOTIMPL without EDNS is treated as 'we don't support EDNS', so explicitly set it
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeNameError)
|
||||
resp.RecursionAvailable = true
|
||||
resp.Ns = s.genSOA(request)
|
||||
return &resp
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user