* dnsfilter: major refactoring
* dnsfilter is controlled by package home, not dnsforward * move HTTP handlers to dnsfilter/ * apply filtering settings without DNS server restart * use only 1 goroutine for filters update * apply new filters quickly (after they are ready to be used)
This commit is contained in:
@@ -3,7 +3,6 @@ package dnsforward
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -44,12 +43,6 @@ type Server struct {
|
||||
queryLog querylog.QueryLog // Query log instance
|
||||
stats stats.Stats
|
||||
|
||||
// How many times the server was started
|
||||
// While creating a dnsfilter object,
|
||||
// we use this value to set s.dnsFilter property only with the most recent settings.
|
||||
startCounter uint32
|
||||
dnsfilterCreatorChan chan dnsfilterCreatorParams
|
||||
|
||||
AllowedClients map[string]bool // IP addresses of whitelist clients
|
||||
DisallowedClients map[string]bool // IP addresses of clients that should be blocked
|
||||
AllowedClientsIPNet []net.IPNet // CIDRs of whitelist clients
|
||||
@@ -60,15 +53,11 @@ type Server struct {
|
||||
conf ServerConfig
|
||||
}
|
||||
|
||||
type dnsfilterCreatorParams struct {
|
||||
conf dnsfilter.Config
|
||||
filters map[int]string
|
||||
}
|
||||
|
||||
// NewServer creates a new instance of the dnsforward.Server
|
||||
// Note: this function must be called only once
|
||||
func NewServer(stats stats.Stats, queryLog querylog.QueryLog) *Server {
|
||||
func NewServer(dnsFilter *dnsfilter.Dnsfilter, stats stats.Stats, queryLog querylog.QueryLog) *Server {
|
||||
s := &Server{}
|
||||
s.dnsFilter = dnsFilter
|
||||
s.stats = stats
|
||||
s.queryLog = queryLog
|
||||
return s
|
||||
@@ -76,6 +65,7 @@ func NewServer(stats stats.Stats, queryLog querylog.QueryLog) *Server {
|
||||
|
||||
func (s *Server) Close() {
|
||||
s.Lock()
|
||||
s.dnsFilter = nil
|
||||
s.stats = nil
|
||||
s.queryLog = nil
|
||||
s.Unlock()
|
||||
@@ -84,11 +74,8 @@ func (s *Server) Close() {
|
||||
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
|
||||
// The zero FilteringConfig is empty and ready for use.
|
||||
type FilteringConfig struct {
|
||||
// Create dnsfilter asynchronously.
|
||||
// Requests won't be filtered until dnsfilter is created.
|
||||
// If "restart" command is received while we're creating an old dnsfilter object,
|
||||
// we delay creation of the new object until the old one is created.
|
||||
AsyncStartup bool `yaml:"-"`
|
||||
// Filtering callback function
|
||||
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
|
||||
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||
@@ -116,8 +103,9 @@ type FilteringConfig struct {
|
||||
// Per-client settings can override this configuration.
|
||||
BlockedServices []string `yaml:"blocked_services"`
|
||||
|
||||
CacheSize uint `yaml:"cache_size"` // DNS cache size (in bytes)
|
||||
dnsfilter.Config `yaml:",inline"`
|
||||
CacheSize uint `yaml:"cache_size"` // DNS cache size (in bytes)
|
||||
|
||||
DnsfilterConf dnsfilter.Config `yaml:",inline"`
|
||||
}
|
||||
|
||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||
@@ -140,7 +128,6 @@ type ServerConfig struct {
|
||||
TCPListenAddr *net.TCPAddr // TCP listen address
|
||||
Upstreams []upstream.Upstream // Configured upstreams
|
||||
DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams
|
||||
Filters []dnsfilter.Filter // A list of filters to use
|
||||
OnDNSRequest func(d *proxy.DNSContext)
|
||||
|
||||
FilteringConfig
|
||||
@@ -204,13 +191,18 @@ func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []strin
|
||||
|
||||
// startInternal starts without locking
|
||||
func (s *Server) startInternal(config *ServerConfig) error {
|
||||
if s.dnsFilter != nil || s.dnsProxy != nil {
|
||||
if s.dnsProxy != nil {
|
||||
return errors.New("DNS server is already started")
|
||||
}
|
||||
|
||||
err := s.initDNSFilter(config)
|
||||
if err != nil {
|
||||
return err
|
||||
if config != nil {
|
||||
s.conf = *config
|
||||
}
|
||||
if len(s.conf.ParentalBlockHost) == 0 {
|
||||
s.conf.ParentalBlockHost = parentalBlockHost
|
||||
}
|
||||
if len(s.conf.SafeBrowsingBlockHost) == 0 {
|
||||
s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost
|
||||
}
|
||||
|
||||
proxyConfig := proxy.Config{
|
||||
@@ -228,7 +220,7 @@ func (s *Server) startInternal(config *ServerConfig) error {
|
||||
AllServers: s.conf.AllServers,
|
||||
}
|
||||
|
||||
err = processIPCIDRArray(&s.AllowedClients, &s.AllowedClientsIPNet, s.conf.AllowedClients)
|
||||
err := processIPCIDRArray(&s.AllowedClients, &s.AllowedClientsIPNet, s.conf.AllowedClients)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -269,97 +261,6 @@ func (s *Server) startInternal(config *ServerConfig) error {
|
||||
return s.dnsProxy.Start()
|
||||
}
|
||||
|
||||
// Initializes the DNS filter
|
||||
func (s *Server) initDNSFilter(config *ServerConfig) error {
|
||||
if config != nil {
|
||||
s.conf = *config
|
||||
}
|
||||
|
||||
var filters map[int]string
|
||||
filters = nil
|
||||
if s.conf.FilteringEnabled {
|
||||
filters = make(map[int]string)
|
||||
for _, f := range s.conf.Filters {
|
||||
if f.ID == 0 {
|
||||
filters[int(f.ID)] = string(f.Data)
|
||||
} else {
|
||||
filters[int(f.ID)] = f.FilePath
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.conf.ParentalBlockHost) == 0 {
|
||||
s.conf.ParentalBlockHost = parentalBlockHost
|
||||
}
|
||||
if len(s.conf.SafeBrowsingBlockHost) == 0 {
|
||||
s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost
|
||||
}
|
||||
|
||||
if s.conf.AsyncStartup {
|
||||
params := dnsfilterCreatorParams{
|
||||
conf: s.conf.Config,
|
||||
filters: filters,
|
||||
}
|
||||
s.startCounter++
|
||||
if s.startCounter == 1 {
|
||||
s.dnsfilterCreatorChan = make(chan dnsfilterCreatorParams, 1)
|
||||
go s.dnsfilterCreator()
|
||||
}
|
||||
|
||||
// remove all pending tasks
|
||||
stop := false
|
||||
for !stop {
|
||||
select {
|
||||
case <-s.dnsfilterCreatorChan:
|
||||
//
|
||||
default:
|
||||
stop = true
|
||||
}
|
||||
}
|
||||
|
||||
s.dnsfilterCreatorChan <- params
|
||||
} else {
|
||||
log.Debug("creating dnsfilter...")
|
||||
f := dnsfilter.New(&s.conf.Config, filters)
|
||||
if f == nil {
|
||||
return fmt.Errorf("could not initialize dnsfilter")
|
||||
}
|
||||
log.Debug("created dnsfilter")
|
||||
s.dnsFilter = f
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) dnsfilterCreator() {
|
||||
for {
|
||||
params := <-s.dnsfilterCreatorChan
|
||||
|
||||
s.Lock()
|
||||
counter := s.startCounter
|
||||
s.Unlock()
|
||||
|
||||
log.Debug("creating dnsfilter...")
|
||||
f := dnsfilter.New(¶ms.conf, params.filters)
|
||||
if f == nil {
|
||||
log.Error("could not initialize dnsfilter")
|
||||
continue
|
||||
}
|
||||
|
||||
set := false
|
||||
s.Lock()
|
||||
if counter == s.startCounter {
|
||||
s.dnsFilter = f
|
||||
set = true
|
||||
}
|
||||
s.Unlock()
|
||||
if set {
|
||||
log.Debug("created and activated dnsfilter")
|
||||
} else {
|
||||
log.Debug("created dnsfilter")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the DNS server
|
||||
func (s *Server) Stop() error {
|
||||
s.Lock()
|
||||
@@ -377,11 +278,6 @@ func (s *Server) stopInternal() error {
|
||||
}
|
||||
}
|
||||
|
||||
if s.dnsFilter != nil {
|
||||
s.dnsFilter.Destroy()
|
||||
s.dnsFilter = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -607,33 +503,24 @@ func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dns
|
||||
|
||||
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
|
||||
func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) {
|
||||
var res dnsfilter.Result
|
||||
req := d.Req
|
||||
host := strings.TrimSuffix(req.Question[0].Name, ".")
|
||||
|
||||
dnsFilter := s.dnsFilter
|
||||
|
||||
if !s.conf.ProtectionEnabled || s.dnsFilter == nil {
|
||||
return &dnsfilter.Result{}, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
clientAddr := ""
|
||||
if d.Addr != nil {
|
||||
clientAddr, _, _ = net.SplitHostPort(d.Addr.String())
|
||||
}
|
||||
|
||||
var setts dnsfilter.RequestFilteringSettings
|
||||
setts := s.dnsFilter.GetConfig()
|
||||
setts.FilteringEnabled = true
|
||||
setts.SafeSearchEnabled = s.conf.SafeSearchEnabled
|
||||
setts.SafeBrowsingEnabled = s.conf.SafeBrowsingEnabled
|
||||
setts.ParentalEnabled = s.conf.ParentalEnabled
|
||||
if s.conf.FilterHandler != nil {
|
||||
s.conf.FilterHandler(clientAddr, &setts)
|
||||
}
|
||||
|
||||
res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts)
|
||||
req := d.Req
|
||||
host := strings.TrimSuffix(req.Question[0].Name, ".")
|
||||
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts)
|
||||
if err != nil {
|
||||
// Return immediately if there's an error
|
||||
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
|
||||
|
||||
Reference in New Issue
Block a user