* dnsfilter: use fastcache instead of gcache

This commit is contained in:
Simon Zolin
2019-07-23 17:14:13 +03:00
parent 81303b5db7
commit 6f51df7d2e
3 changed files with 54 additions and 55 deletions

View File

@@ -5,8 +5,8 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
@@ -16,6 +16,7 @@ import (
"sync/atomic"
"time"
"github.com/VictoriaMetrics/fastcache"
"github.com/joomcode/errorx"
"github.com/AdguardTeam/dnsproxy/upstream"
@@ -135,9 +136,9 @@ const (
type dnsFilterContext struct {
stats Stats
dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers
safebrowsingCache gcache.Cache
parentalCache gcache.Cache
safeSearchCache gcache.Cache
safebrowsingCache *fastcache.Cache
parentalCache *fastcache.Cache
safeSearchCache *fastcache.Cache
}
var gctx dnsFilterContext // global dnsfilter context
@@ -233,32 +234,33 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, clientAddr string) (Res
return Result{}, nil
}
func getCachedReason(cache gcache.Cache, host string) (result Result, isFound bool, err error) {
func setCacheResult(cache *fastcache.Cache, host string, res Result) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
_ = enc.Encode(res)
cache.Set([]byte(host), buf.Bytes())
log.Debug("Stored in cache %p: %s => [%d]", cache, host, buf.Len())
}
func getCachedResult(cache *fastcache.Cache, host string) (result Result, isFound bool) {
isFound = false // not found yet
// get raw value
rawValue, err := cache.Get(host)
if err == gcache.KeyNotFoundError {
// not a real error, just not found
err = nil
return
}
if err != nil {
// real error
return
rawValue := cache.Get(nil, []byte(host))
if len(rawValue) == 0 {
return Result{}, false
}
// since it can be something else, validate that it belongs to proper type
cachedValue, ok := rawValue.(Result)
if !ok {
// this is not our type -- error
text := "SHOULD NOT HAPPEN: entry with invalid type was found in lookup cache"
log.Println(text)
err = errors.New(text)
return
var buf bytes.Buffer
buf.Write(rawValue)
dec := gob.NewDecoder(&buf)
cachedValue := Result{}
err := dec.Decode(&cachedValue)
if err != nil {
log.Debug("gob.Decode(): %s", err)
return Result{}, false
}
isFound = ok
return cachedValue, isFound, err
return cachedValue, true
}
// for each dot, hash it and add it to string
@@ -304,17 +306,13 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
}
// Check cache. Return cached result if it was found
cachedValue, isFound, err := getCachedReason(gctx.safeSearchCache, host)
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host)
if isFound {
atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
log.Tracef("%s: found in SafeSearch cache", host)
return cachedValue, nil
}
if err != nil {
return Result{}, err
}
safeHost, ok := d.SafeSearchDomain(host)
if !ok {
return Result{}, nil
@@ -323,11 +321,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
if ip := net.ParseIP(safeHost); ip != nil {
res.IP = ip
err = gctx.safeSearchCache.Set(host, res)
if err != nil {
return Result{}, nil
}
setCacheResult(gctx.safeSearchCache, host, res)
return res, nil
}
@@ -350,10 +344,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
}
// Cache result
err = gctx.safeSearchCache.Set(host, res)
if err != nil {
return Result{}, nil
}
setCacheResult(gctx.safeSearchCache, host, res)
return res, nil
}
@@ -456,20 +447,17 @@ type formatHandler func(hashparam string) string
type bodyHandler func(body []byte, hashes map[string]bool) (Result, error)
// real implementation of lookup/check
func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gcache.Cache, hashparamNeedSlash bool, format formatHandler, handleBody bodyHandler) (Result, error) {
func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache *fastcache.Cache, hashparamNeedSlash bool, format formatHandler, handleBody bodyHandler) (Result, error) {
// if host ends with a dot, trim it
host = strings.ToLower(strings.Trim(host, "."))
// check cache
cachedValue, isFound, err := getCachedReason(cache, host)
cachedValue, isFound := getCachedResult(cache, host)
if isFound {
atomic.AddUint64(&lookupstats.CacheHits, 1)
log.Tracef("%s: found in the lookup cache", host)
log.Tracef("%s: found in the lookup cache %p", host, cache)
return cachedValue, nil
}
if err != nil {
return Result{}, err
}
// convert hostname to hash parameters
hashparam, hashes := hostnameToHashParam(host, hashparamNeedSlash)
@@ -502,10 +490,7 @@ func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gc
switch {
case resp.StatusCode == 204:
// empty result, save cache
err = cache.Set(host, Result{})
if err != nil {
return Result{}, err
}
setCacheResult(cache, host, Result{})
return Result{}, nil
case resp.StatusCode != 200:
// error, don't save cache
@@ -518,10 +503,7 @@ func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gc
return Result{}, err
}
err = cache.Set(host, result)
if err != nil {
return Result{}, err
}
setCacheResult(cache, host, result)
return result, nil
}
@@ -735,13 +717,13 @@ func New(c *Config, filters map[int]string) *Dnsfilter {
if c != nil {
// initialize objects only once
if gctx.safebrowsingCache == nil {
gctx.safebrowsingCache = gcache.New(c.SafeBrowsingCacheSize).LRU().Expiration(defaultCacheTime).Build()
gctx.safebrowsingCache = fastcache.New(c.SafeBrowsingCacheSize)
}
if gctx.safeSearchCache == nil {
gctx.safeSearchCache = gcache.New(c.SafeSearchCacheSize).LRU().Expiration(defaultCacheTime).Build()
gctx.safeSearchCache = fastcache.New(c.SafeSearchCacheSize)
}
if gctx.parentalCache == nil {
gctx.parentalCache = gcache.New(c.ParentalCacheSize).LRU().Expiration(defaultCacheTime).Build()
gctx.parentalCache = fastcache.New(c.ParentalCacheSize)
}
if len(c.ResolverAddress) != 0 && gctx.dialCache == nil {
gctx.dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()