Fix race conditions found by go's race detector
This commit is contained in:
@@ -55,6 +55,8 @@ type plug struct {
|
||||
ParentalBlockHost string
|
||||
QueryLogEnabled bool
|
||||
BlockedTTL uint32 // in seconds, default 3600
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
var defaultPlugin = plug{
|
||||
@@ -246,17 +248,21 @@ func (p *plug) parseEtcHosts(text string) bool {
|
||||
}
|
||||
|
||||
func (p *plug) onShutdown() error {
|
||||
p.Lock()
|
||||
p.d.Destroy()
|
||||
p.d = nil
|
||||
p.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *plug) onFinalShutdown() error {
|
||||
logBufferLock.Lock()
|
||||
err := flushToFile(logBuffer)
|
||||
if err != nil {
|
||||
log.Printf("failed to flush to file: %s", err)
|
||||
return err
|
||||
}
|
||||
logBufferLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -293,9 +299,11 @@ func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *d
|
||||
}
|
||||
|
||||
func (p *plug) doStats(ch interface{}, doFunc statsFunc) {
|
||||
p.RLock()
|
||||
stats := p.d.GetStats()
|
||||
doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing)
|
||||
doStatsLookup(ch, doFunc, "parental", &stats.Parental)
|
||||
p.RUnlock()
|
||||
}
|
||||
|
||||
// Describe is called by prometheus handler to know stat types
|
||||
@@ -365,12 +373,12 @@ func (p *plug) genSOA(r *dns.Msg) []dns.RR {
|
||||
}
|
||||
Ns := "fake-for-negative-caching.adguard.com."
|
||||
|
||||
soa := defaultSOA
|
||||
soa := *defaultSOA
|
||||
soa.Hdr = header
|
||||
soa.Mbox = Mbox
|
||||
soa.Ns = Ns
|
||||
soa.Serial = uint32(time.Now().Unix())
|
||||
return []dns.RR{soa}
|
||||
soa.Serial = 100500 // faster than uint32(time.Now().Unix())
|
||||
return []dns.RR{&soa}
|
||||
}
|
||||
|
||||
func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
@@ -397,13 +405,17 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
|
||||
for _, question := range r.Question {
|
||||
host := strings.ToLower(strings.TrimSuffix(question.Name, "."))
|
||||
// is it a safesearch domain?
|
||||
p.RLock()
|
||||
if val, ok := p.d.SafeSearchDomain(host); ok {
|
||||
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
||||
if err != nil {
|
||||
p.RUnlock()
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
p.RUnlock()
|
||||
return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err
|
||||
}
|
||||
p.RUnlock()
|
||||
|
||||
// is it in hosts?
|
||||
if val, ok := p.hosts[host]; ok {
|
||||
@@ -425,11 +437,14 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
|
||||
}
|
||||
|
||||
// needs to be filtered instead
|
||||
p.RLock()
|
||||
result, err := p.d.CheckHost(host)
|
||||
if err != nil {
|
||||
log.Printf("plugin/dnsfilter: %s\n", err)
|
||||
p.RUnlock()
|
||||
return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err)
|
||||
}
|
||||
p.RUnlock()
|
||||
|
||||
if result.IsFiltered {
|
||||
switch result.Reason {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdguardDNS/dnsfilter"
|
||||
@@ -23,6 +24,7 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
logBufferLock sync.RWMutex
|
||||
logBuffer []logEntry
|
||||
)
|
||||
|
||||
@@ -65,11 +67,13 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
|
||||
}
|
||||
var flushBuffer []logEntry
|
||||
|
||||
logBufferLock.Lock()
|
||||
logBuffer = append(logBuffer, entry)
|
||||
if len(logBuffer) >= logBufferCap {
|
||||
flushBuffer = logBuffer
|
||||
logBuffer = nil
|
||||
}
|
||||
logBufferLock.Unlock()
|
||||
if len(flushBuffer) > 0 {
|
||||
// write to file
|
||||
// do it in separate goroutine -- we are stalling DNS response this whole time
|
||||
@@ -81,7 +85,9 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
|
||||
func handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: fetch values from disk if len(logBuffer) < queryLogSize
|
||||
// TODO: cache output
|
||||
logBufferLock.RLock()
|
||||
values := logBuffer
|
||||
logBufferLock.RUnlock()
|
||||
var data = []map[string]interface{}{}
|
||||
for _, entry := range values {
|
||||
var q *dns.Msg
|
||||
|
||||
Reference in New Issue
Block a user