Start using dnsproxy
This commit is contained in:
@@ -2,17 +2,24 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/joomcode/errorx"
|
||||
"github.com/miekg/dns"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
parentalBlockHost = "family-block.dns.adguard.com"
|
||||
)
|
||||
|
||||
// Server is the main way to start a DNS server.
|
||||
@@ -26,66 +33,18 @@ import (
|
||||
//
|
||||
// The zero Server is empty and ready for use.
|
||||
type Server struct {
|
||||
udpListen *net.UDPConn
|
||||
dnsProxy *proxy.Proxy // DNS proxy instance
|
||||
|
||||
dnsFilter *dnsfilter.Dnsfilter
|
||||
|
||||
cache cache
|
||||
|
||||
ratelimitBuckets *gocache.Cache // where the ratelimiters are stored, per IP
|
||||
dnsFilter *dnsfilter.Dnsfilter // DNS filter instance
|
||||
|
||||
sync.RWMutex
|
||||
ServerConfig
|
||||
}
|
||||
|
||||
const (
|
||||
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
parentalBlockHost = "family-block.dns.adguard.com"
|
||||
)
|
||||
|
||||
// uncomment this block to have tracing of locks
|
||||
/*
|
||||
func (s *Server) Lock() {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
file, line := f.FileLine(pc[0])
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Lock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
|
||||
s.RWMutex.Lock()
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Lock() -> done\n", path.Base(file), line, path.Base(f.Name()))
|
||||
}
|
||||
func (s *Server) RLock() {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
file, line := f.FileLine(pc[0])
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RLock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
|
||||
s.RWMutex.RLock()
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RLock() -> done\n", path.Base(file), line, path.Base(f.Name()))
|
||||
}
|
||||
func (s *Server) Unlock() {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
file, line := f.FileLine(pc[0])
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Unlock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
|
||||
s.RWMutex.Unlock()
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> Unlock() -> done\n", path.Base(file), line, path.Base(f.Name()))
|
||||
}
|
||||
func (s *Server) RUnlock() {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
file, line := f.FileLine(pc[0])
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RUnlock() -> in progress\n", path.Base(file), line, path.Base(f.Name()))
|
||||
s.RWMutex.RUnlock()
|
||||
fmt.Fprintf(os.Stderr, "%s:%d %s() -> RUnlock() -> done\n", path.Base(file), line, path.Base(f.Name()))
|
||||
}
|
||||
*/
|
||||
|
||||
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
|
||||
type FilteringConfig struct {
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"`
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
|
||||
QueryLogEnabled bool `yaml:"querylog_enabled"`
|
||||
Ratelimit int `yaml:"ratelimit"`
|
||||
@@ -96,11 +55,12 @@ type FilteringConfig struct {
|
||||
dnsfilter.Config `yaml:",inline"`
|
||||
}
|
||||
|
||||
// ServerConfig represents server configuration.
|
||||
// The zero ServerConfig is empty and ready for use.
|
||||
type ServerConfig struct {
|
||||
UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *)
|
||||
Upstreams []Upstream
|
||||
Filters []dnsfilter.Filter
|
||||
UDPListenAddr *net.UDPAddr // UDP listen address
|
||||
Upstreams []upstream.Upstream // Configured upstreams
|
||||
Filters []dnsfilter.Filter // A list of filters to use
|
||||
|
||||
FilteringConfig
|
||||
}
|
||||
@@ -109,94 +69,40 @@ type ServerConfig struct {
|
||||
var defaultValues = ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{Port: 53},
|
||||
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
|
||||
Upstreams: []Upstream{
|
||||
//// dns over HTTPS
|
||||
// &dnsOverHTTPS{boot: toBoot("https://1.1.1.1/dns-query", "")},
|
||||
// &dnsOverHTTPS{boot: toBoot("https://dns.google.com/experimental", "")},
|
||||
// &dnsOverHTTPS{boot: toBoot("https://doh.cleanbrowsing.org/doh/security-filter/", "")},
|
||||
// &dnsOverHTTPS{boot: toBoot("https://dns10.quad9.net/dns-query", "")},
|
||||
// &dnsOverHTTPS{boot: toBoot("https://doh.powerdns.org", "")},
|
||||
// &dnsOverHTTPS{boot: toBoot("https://doh.securedns.eu/dns-query", "")},
|
||||
|
||||
//// dns over TLS
|
||||
// &dnsOverTLS{boot: toBoot("tls://8.8.8.8:853", "")},
|
||||
// &dnsOverTLS{boot: toBoot("tls://8.8.4.4:853", "")},
|
||||
// &dnsOverTLS{boot: toBoot("tls://1.1.1.1:853", "")},
|
||||
// &dnsOverTLS{boot: toBoot("tls://1.0.0.1:853", "")},
|
||||
|
||||
//// plainDNS
|
||||
&plainDNS{boot: toBoot("8.8.8.8:53", "")},
|
||||
&plainDNS{boot: toBoot("8.8.4.4:53", "")},
|
||||
&plainDNS{boot: toBoot("1.1.1.1:53", "")},
|
||||
&plainDNS{boot: toBoot("1.0.0.1:53", "")},
|
||||
},
|
||||
}
|
||||
|
||||
//
|
||||
// packet loop
|
||||
//
|
||||
func (s *Server) packetLoop() {
|
||||
log.Printf("Entering packet handle loop")
|
||||
b := make([]byte, dns.MaxMsgSize)
|
||||
for {
|
||||
s.RLock()
|
||||
conn := s.udpListen
|
||||
s.RUnlock()
|
||||
if conn == nil {
|
||||
log.Printf("udp socket has disappeared, exiting loop")
|
||||
break
|
||||
}
|
||||
n, addr, err := conn.ReadFrom(b)
|
||||
// documentation says to handle the packet even if err occurs, so do that first
|
||||
if n > 0 {
|
||||
// make a copy of all bytes because ReadFrom() will overwrite contents of b on next call
|
||||
// we need the contents to survive the call because we're handling them in goroutine
|
||||
p := make([]byte, n)
|
||||
copy(p, b)
|
||||
go s.handlePacket(p, addr, conn) // ignore errors
|
||||
}
|
||||
if err != nil {
|
||||
if isConnClosed(err) {
|
||||
log.Printf("ReadFrom() returned because we're reading from a closed connection, exiting loop")
|
||||
// don't try to nullify s.udpListen here, because s.udpListen could be already re-bound to listen
|
||||
break
|
||||
}
|
||||
log.Printf("Got error when reading from udp listen: %s", err)
|
||||
func init() {
|
||||
defaultDNS := []string{"8.8.8.8:53", "8.8.4.4:53"}
|
||||
|
||||
defaultUpstreams := make([]upstream.Upstream, 0)
|
||||
for _, addr := range defaultDNS {
|
||||
u, err := upstream.AddressToUpstream(addr, "")
|
||||
if err == nil {
|
||||
defaultUpstreams = append(defaultUpstreams, u)
|
||||
}
|
||||
}
|
||||
defaultValues.Upstreams = defaultUpstreams
|
||||
}
|
||||
|
||||
//
|
||||
// Control functions
|
||||
//
|
||||
|
||||
// Start starts the DNS server
|
||||
func (s *Server) Start(config *ServerConfig) error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
return s.startInternal(config)
|
||||
}
|
||||
|
||||
// startInternal starts without locking
|
||||
func (s *Server) startInternal(config *ServerConfig) error {
|
||||
if config != nil {
|
||||
s.ServerConfig = *config
|
||||
}
|
||||
// TODO: handle being called Start() second time after Stop()
|
||||
if s.udpListen == nil {
|
||||
log.Printf("Creating UDP socket")
|
||||
var err error
|
||||
addr := s.UDPListenAddr
|
||||
if addr == nil {
|
||||
addr = defaultValues.UDPListenAddr
|
||||
}
|
||||
s.udpListen, err = net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
s.udpListen = nil
|
||||
return errorx.Decorate(err, "Couldn't listen to UDP socket")
|
||||
}
|
||||
log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr)
|
||||
}
|
||||
|
||||
if s.dnsFilter == nil {
|
||||
log.Printf("Creating dnsfilter")
|
||||
s.dnsFilter = dnsfilter.New(&s.Config)
|
||||
// add rules only if they are enabled
|
||||
if s.FilteringEnabled {
|
||||
// TODO: Handle error
|
||||
s.dnsFilter.AddRules(s.Filters)
|
||||
}
|
||||
}
|
||||
@@ -214,22 +120,55 @@ func (s *Server) Start(config *ServerConfig) error {
|
||||
go statsRotator()
|
||||
})
|
||||
|
||||
go s.packetLoop()
|
||||
// TODO: Add TCPListenAddr
|
||||
proxyConfig := proxy.Config{
|
||||
UDPListenAddr: s.UDPListenAddr,
|
||||
Ratelimit: s.Ratelimit,
|
||||
RatelimitWhitelist: s.RatelimitWhitelist,
|
||||
RefuseAny: s.RefuseAny,
|
||||
CacheEnabled: true,
|
||||
Upstreams: s.Upstreams,
|
||||
Handler: s,
|
||||
}
|
||||
|
||||
return nil
|
||||
if proxyConfig.UDPListenAddr == nil {
|
||||
proxyConfig.UDPListenAddr = defaultValues.UDPListenAddr
|
||||
}
|
||||
|
||||
if len(proxyConfig.Upstreams) == 0 {
|
||||
proxyConfig.Upstreams = defaultValues.Upstreams
|
||||
}
|
||||
|
||||
// TODO: Don't let call Start the second time
|
||||
// Initialize the DNS proxy
|
||||
s.dnsProxy = &proxy.Proxy{Config: proxyConfig}
|
||||
|
||||
err = s.dnsProxy.Start()
|
||||
return err
|
||||
}
|
||||
|
||||
// Stop stops the DNS server
|
||||
func (s *Server) Stop() error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
if s.udpListen != nil {
|
||||
err := s.udpListen.Close()
|
||||
s.udpListen = nil
|
||||
return s.stopInternal()
|
||||
}
|
||||
|
||||
// stopInternal stops without locking
|
||||
func (s *Server) stopInternal() error {
|
||||
if s.dnsProxy != nil {
|
||||
err := s.dnsProxy.Stop()
|
||||
s.dnsProxy = nil
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't close UDP listening socket")
|
||||
return errorx.Decorate(err, "could not stop the DNS server properly")
|
||||
}
|
||||
}
|
||||
|
||||
if s.dnsFilter != nil {
|
||||
s.dnsFilter.Destroy()
|
||||
s.dnsFilter = nil
|
||||
}
|
||||
|
||||
// flush remainder to file
|
||||
logBufferLock.Lock()
|
||||
flushBuffer := logBuffer
|
||||
@@ -244,283 +183,55 @@ func (s *Server) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsRunning returns true if the DNS server is running
|
||||
func (s *Server) IsRunning() bool {
|
||||
s.RLock()
|
||||
isRunning := true
|
||||
if s.udpListen == nil {
|
||||
if s.dnsProxy == nil {
|
||||
isRunning = false
|
||||
}
|
||||
s.RUnlock()
|
||||
return isRunning
|
||||
}
|
||||
|
||||
//
|
||||
// Server reconfigure
|
||||
//
|
||||
|
||||
func (s *Server) reconfigureListenAddr(new ServerConfig) error {
|
||||
oldAddr := s.UDPListenAddr
|
||||
if oldAddr == nil {
|
||||
oldAddr = defaultValues.UDPListenAddr
|
||||
}
|
||||
newAddr := new.UDPListenAddr
|
||||
if newAddr == nil {
|
||||
newAddr = defaultValues.UDPListenAddr
|
||||
}
|
||||
if newAddr.Port == 0 {
|
||||
return errorx.IllegalArgument.New("new port cannot be 0")
|
||||
}
|
||||
if reflect.DeepEqual(oldAddr, newAddr) {
|
||||
// do nothing, the addresses are exactly the same
|
||||
log.Printf("Not going to rebind because addresses are same: %v -> %v", oldAddr, newAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// rebind, using a strategy:
|
||||
// * if ports are different, bind new first, then close old
|
||||
// * if ports are same, close old first, then bind new
|
||||
var newListen *net.UDPConn
|
||||
var err error
|
||||
if oldAddr.Port != newAddr.Port {
|
||||
log.Printf("Rebinding -- ports are different so bind first then close")
|
||||
newListen, err = net.ListenUDP("udp", newAddr)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't bind to %v", newAddr)
|
||||
}
|
||||
s.Lock()
|
||||
if s.udpListen != nil {
|
||||
err = s.udpListen.Close()
|
||||
s.udpListen = nil
|
||||
}
|
||||
s.Unlock()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't close UDP listening socket")
|
||||
}
|
||||
} else {
|
||||
log.Printf("Rebinding -- ports are same so close first then bind")
|
||||
s.Lock()
|
||||
if s.udpListen != nil {
|
||||
err = s.udpListen.Close()
|
||||
s.udpListen = nil
|
||||
}
|
||||
s.Unlock()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't close UDP listening socket")
|
||||
}
|
||||
newListen, err = net.ListenUDP("udp", newAddr)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't bind to %v", newAddr)
|
||||
}
|
||||
}
|
||||
// Reconfigure applies the new configuration to the DNS server
|
||||
func (s *Server) Reconfigure(config *ServerConfig) error {
|
||||
s.Lock()
|
||||
s.udpListen = newListen
|
||||
s.UDPListenAddr = new.UDPListenAddr
|
||||
s.Unlock()
|
||||
log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr)
|
||||
defer s.Unlock()
|
||||
|
||||
go s.packetLoop() // the old one has quit, use new one
|
||||
log.Print("Start reconfiguring the server")
|
||||
err := s.stopInternal()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "could not reconfigure the server")
|
||||
}
|
||||
err = s.startInternal(config)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "could not reconfigure the server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) reconfigureBlockedResponseTTL(new ServerConfig) {
|
||||
newVal := new.BlockedResponseTTL
|
||||
if newVal == 0 {
|
||||
newVal = defaultValues.BlockedResponseTTL
|
||||
}
|
||||
oldVal := s.BlockedResponseTTL
|
||||
if oldVal == 0 {
|
||||
oldVal = defaultValues.BlockedResponseTTL
|
||||
}
|
||||
if newVal != oldVal {
|
||||
s.BlockedResponseTTL = new.BlockedResponseTTL
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) reconfigureUpstreams(new ServerConfig) {
|
||||
newVal := new.Upstreams
|
||||
if len(newVal) == 0 {
|
||||
newVal = defaultValues.Upstreams
|
||||
}
|
||||
oldVal := s.Upstreams
|
||||
if len(oldVal) == 0 {
|
||||
oldVal = defaultValues.Upstreams
|
||||
}
|
||||
if reflect.DeepEqual(newVal, oldVal) {
|
||||
// they're exactly the same, do nothing
|
||||
return
|
||||
}
|
||||
s.Upstreams = new.Upstreams
|
||||
}
|
||||
|
||||
func (s *Server) reconfigureFiltering(new ServerConfig) {
|
||||
newFilters := new.Filters
|
||||
if len(newFilters) == 0 {
|
||||
newFilters = defaultValues.Filters
|
||||
}
|
||||
oldFilters := s.Filters
|
||||
if len(oldFilters) == 0 {
|
||||
oldFilters = defaultValues.Filters
|
||||
}
|
||||
|
||||
needUpdate := false
|
||||
if !reflect.DeepEqual(newFilters, oldFilters) {
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(new.FilteringConfig, s.FilteringConfig) {
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
if !needUpdate {
|
||||
// nothing to do, everything is same
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: instead of creating new dnsfilter, change existing one's settings and filters
|
||||
dnsFilter := dnsfilter.New(&new.Config) // sets safebrowsing, safesearch and parental
|
||||
|
||||
// add rules only if they are enabled
|
||||
if new.FilteringEnabled {
|
||||
dnsFilter.AddRules(newFilters)
|
||||
}
|
||||
|
||||
s.Lock()
|
||||
oldDNSFilter := s.dnsFilter
|
||||
s.dnsFilter = dnsFilter
|
||||
s.FilteringConfig = new.FilteringConfig
|
||||
s.Unlock()
|
||||
|
||||
oldDNSFilter.Destroy()
|
||||
}
|
||||
|
||||
func (s *Server) Reconfigure(new ServerConfig) error {
|
||||
s.reconfigureBlockedResponseTTL(new)
|
||||
s.reconfigureUpstreams(new)
|
||||
s.reconfigureFiltering(new)
|
||||
|
||||
err := s.reconfigureListenAddr(new)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't reconfigure to new listening address %+v", new.UDPListenAddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//
|
||||
// packet handling functions
|
||||
//
|
||||
|
||||
// handlePacketInternal processes the incoming packet bytes and returns with an optional response packet.
|
||||
//
|
||||
// If an empty dns.Msg is returned, do not try to send anything back to client, otherwise send contents of dns.Msg.
|
||||
//
|
||||
// If an error is returned, log it, don't try to generate data based on that error.
|
||||
func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDPConn) (*dns.Msg, *dnsfilter.Result, Upstream, error) {
|
||||
// log.Printf("Got packet %d bytes from %s: %v", len(p), addr, p)
|
||||
//
|
||||
// DNS packet byte format is valid
|
||||
//
|
||||
// any errors below here require a response to client
|
||||
// log.Printf("Unpacked: %v", msg.String())
|
||||
if len(msg.Question) != 1 {
|
||||
log.Printf("Got invalid number of questions: %v", len(msg.Question))
|
||||
return s.genServerFailure(msg), nil, nil, nil
|
||||
}
|
||||
|
||||
if msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny {
|
||||
return s.genNotImpl(msg), nil, nil, nil
|
||||
}
|
||||
|
||||
// we need upstream to resolve A records
|
||||
upstream := s.chooseUpstream()
|
||||
|
||||
host := strings.TrimSuffix(msg.Question[0].Name, ".")
|
||||
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
|
||||
var res dnsfilter.Result
|
||||
var err error
|
||||
if s.ProtectionEnabled {
|
||||
res, err = s.dnsFilter.CheckHost(host)
|
||||
if err != nil {
|
||||
log.Printf("dnsfilter failed to check host '%s': %s", host, err)
|
||||
return s.genServerFailure(msg), &res, nil, err
|
||||
} else if res.IsFiltered {
|
||||
log.Printf("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule)
|
||||
switch res.Reason {
|
||||
case dnsfilter.FilteredSafeBrowsing:
|
||||
return s.genArecord(msg, safeBrowsingBlockHost, upstream), &res, nil, nil
|
||||
case dnsfilter.FilteredParental:
|
||||
return s.genArecord(msg, parentalBlockHost, upstream), &res, nil, nil
|
||||
}
|
||||
return s.genNXDomain(msg), &res, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
val, ok := s.cache.Get(msg)
|
||||
if ok && val != nil {
|
||||
return val, &res, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: replace with single-socket implementation
|
||||
reply, err := upstream.Exchange(msg)
|
||||
if err != nil {
|
||||
log.Printf("talking to upstream failed for host '%s': %s", host, err)
|
||||
return s.genServerFailure(msg), &res, upstream, err
|
||||
}
|
||||
if reply == nil {
|
||||
log.Printf("SHOULD NOT HAPPEN upstream returned empty message for host '%s'. Request is %v", host, msg.String())
|
||||
return s.genServerFailure(msg), &res, upstream, nil
|
||||
}
|
||||
|
||||
s.cache.Set(reply)
|
||||
|
||||
return reply, &res, upstream, nil
|
||||
}
|
||||
|
||||
func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
|
||||
// ServeDNS filters the incoming DNS requests and writes them to the query log
|
||||
func (s *Server) ServeDNS(d *proxy.DNSContext, next proxy.Handler) error {
|
||||
start := time.Now()
|
||||
ip, _, err := net.SplitHostPort(addr.String())
|
||||
|
||||
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
|
||||
res, err := s.filterDNSRequest(d)
|
||||
if err != nil {
|
||||
log.Printf("Failed to split %v into host/port: %s", addr, err)
|
||||
// not a fatal error, move on
|
||||
return err
|
||||
}
|
||||
|
||||
// ratelimit based on IP only, protects CPU cycles and outbound connections
|
||||
if s.isRatelimited(ip) {
|
||||
// log.Printf("Ratelimiting %s based on IP only", ip)
|
||||
return // do nothing, don't reply, we got ratelimited
|
||||
}
|
||||
|
||||
msg := &dns.Msg{}
|
||||
err = msg.Unpack(p)
|
||||
if err != nil {
|
||||
log.Printf("got invalid DNS packet: %s", err)
|
||||
return // do nothing
|
||||
}
|
||||
|
||||
reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn)
|
||||
|
||||
if reply != nil {
|
||||
// ratelimit based on reply size now
|
||||
replysize := reply.Len()
|
||||
if s.isRatelimitedForReply(ip, replysize) {
|
||||
log.Printf("Ratelimiting %s based on IP and size %d", ip, replysize)
|
||||
return // do nothing, don't reply, we got ratelimited
|
||||
}
|
||||
|
||||
// we're good to respond
|
||||
rerr := s.respond(reply, addr, conn)
|
||||
if rerr != nil {
|
||||
log.Printf("Couldn't respond to UDP packet: %s", err)
|
||||
if d.Res == nil {
|
||||
// request was not filtered so let it be processed further
|
||||
err = next.ServeDNS(d, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// query logging and stats counters
|
||||
//
|
||||
|
||||
shouldLog := true
|
||||
msg := d.Req
|
||||
|
||||
// don't log ANY request if refuseAny is enabled
|
||||
if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny {
|
||||
@@ -530,35 +241,64 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
|
||||
if s.QueryLogEnabled && shouldLog {
|
||||
elapsed := time.Since(start)
|
||||
upstreamAddr := ""
|
||||
if upstream != nil {
|
||||
upstreamAddr = upstream.Address()
|
||||
if d.Upstream != nil {
|
||||
upstreamAddr = d.Upstream.Address()
|
||||
}
|
||||
logRequest(msg, reply, result, elapsed, ip, upstreamAddr)
|
||||
logRequest(msg, d.Res, res, elapsed, d.Addr.String(), upstreamAddr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//
|
||||
// packet sending functions
|
||||
//
|
||||
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
|
||||
func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) {
|
||||
msg := d.Req
|
||||
host := strings.TrimSuffix(msg.Question[0].Name, ".")
|
||||
|
||||
func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error {
|
||||
// log.Printf("Replying to %s with %s", addr, resp)
|
||||
resp.Compress = true
|
||||
bytes, err := resp.Pack()
|
||||
s.RLock()
|
||||
protectionEnabled := s.ProtectionEnabled
|
||||
dnsFilter := s.dnsFilter
|
||||
s.RUnlock()
|
||||
|
||||
if !protectionEnabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var res dnsfilter.Result
|
||||
var err error
|
||||
|
||||
res, err = dnsFilter.CheckHost(host)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't convert message into wire format")
|
||||
// Return immediately if there's an error
|
||||
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
|
||||
} else if res.IsFiltered {
|
||||
log.Debugf("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule)
|
||||
d.Res = s.genDNSFilterMessage(d, &res)
|
||||
}
|
||||
n, err := conn.WriteTo(bytes, addr)
|
||||
if n == 0 && isConnClosed(err) {
|
||||
return err
|
||||
|
||||
return &res, err
|
||||
}
|
||||
|
||||
// genDNSFilterMessage generates a DNS message corresponding to the filtering result
|
||||
func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Result) *dns.Msg {
|
||||
m := d.Req
|
||||
|
||||
if m.Question[0].Qtype != dns.TypeA {
|
||||
return s.genNXDomain(m)
|
||||
}
|
||||
if n != len(bytes) {
|
||||
return fmt.Errorf("WriteTo() returned with %d != %d", n, len(bytes))
|
||||
|
||||
switch result.Reason {
|
||||
case dnsfilter.FilteredSafeBrowsing:
|
||||
return s.genBlockedHost(m, safeBrowsingBlockHost, d.Upstream)
|
||||
case dnsfilter.FilteredParental:
|
||||
return s.genBlockedHost(m, parentalBlockHost, d.Upstream)
|
||||
default:
|
||||
if result.Ip != nil {
|
||||
return s.genARecord(m, result.Ip)
|
||||
}
|
||||
|
||||
return s.genNXDomain(m)
|
||||
}
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "WriteTo() returned error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
|
||||
@@ -568,29 +308,19 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genNotImpl(request *dns.Msg) *dns.Msg {
|
||||
func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
resp.SetRcode(request, dns.RcodeNotImplemented)
|
||||
resp.RecursionAvailable = true
|
||||
resp.SetEdns0(1452, false) // NOTIMPL without EDNS is treated as 'we don't support EDNS', so explicitly set it
|
||||
resp.SetReply(request)
|
||||
answer, err := dns.NewRR(fmt.Sprintf("%s %d A %s", request.Question[0].Name, s.BlockedResponseTTL, ip.String()))
|
||||
if err != nil {
|
||||
log.Warnf("Couldn't generate A record for up replacement host '%s': %s", ip.String(), err)
|
||||
return s.genServerFailure(request)
|
||||
}
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genArecord(request *dns.Msg, newAddr string, upstream Upstream) *dns.Msg {
|
||||
addr := net.ParseIP(newAddr)
|
||||
if addr != nil {
|
||||
// this is an IP address, return it
|
||||
resp := dns.Msg{}
|
||||
resp.SetReply(request)
|
||||
answer, err := dns.NewRR(fmt.Sprintf("%s %d A %s", request.Question[0].Name, s.BlockedResponseTTL, newAddr))
|
||||
if err != nil {
|
||||
log.Printf("Couldn't generate A record for up replacement host '%s': %s", newAddr, err)
|
||||
return s.genServerFailure(request)
|
||||
}
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, upstream upstream.Upstream) *dns.Msg {
|
||||
// look up the hostname, TODO: cache
|
||||
replReq := dns.Msg{}
|
||||
replReq.SetQuestion(dns.Fqdn(newAddr), request.Question[0].Qtype)
|
||||
|
||||
Reference in New Issue
Block a user