Fix many lint warnings found by gometalinter
This commit is contained in:
@@ -45,7 +45,7 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
type Plugin struct {
|
||||
type plug struct {
|
||||
d *dnsfilter.Dnsfilter
|
||||
Next plugin.Handler
|
||||
upstream upstream.Upstream
|
||||
@@ -56,12 +56,12 @@ type Plugin struct {
|
||||
QueryLogEnabled bool
|
||||
}
|
||||
|
||||
var defaultPlugin = Plugin{
|
||||
var defaultPlugin = plug{
|
||||
SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com",
|
||||
ParentalBlockHost: "family.block.dns.adguard.com",
|
||||
}
|
||||
|
||||
func newDnsCounter(name string, help string) prometheus.Counter {
|
||||
func newDNSCounter(name string, help string) prometheus.Counter {
|
||||
return prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "dnsfilter",
|
||||
@@ -71,26 +71,26 @@ func newDnsCounter(name string, help string) prometheus.Counter {
|
||||
}
|
||||
|
||||
var (
|
||||
requests = newDnsCounter("requests_total", "Count of requests seen by dnsfilter.")
|
||||
filtered = newDnsCounter("filtered_total", "Count of requests filtered by dnsfilter.")
|
||||
filteredLists = newDnsCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.")
|
||||
filteredSafebrowsing = newDnsCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.")
|
||||
filteredParental = newDnsCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.")
|
||||
filteredInvalid = newDnsCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.")
|
||||
whitelisted = newDnsCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.")
|
||||
safesearch = newDnsCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.")
|
||||
errorsTotal = newDnsCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.")
|
||||
requests = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.")
|
||||
filtered = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.")
|
||||
filteredLists = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.")
|
||||
filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.")
|
||||
filteredParental = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.")
|
||||
filteredInvalid = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.")
|
||||
whitelisted = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.")
|
||||
safesearch = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.")
|
||||
errorsTotal = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.")
|
||||
)
|
||||
|
||||
//
|
||||
// coredns handling functions
|
||||
//
|
||||
func setupPlugin(c *caddy.Controller) (*Plugin, error) {
|
||||
func setupPlugin(c *caddy.Controller) (*plug, error) {
|
||||
// create new Plugin and copy default values
|
||||
var d = new(Plugin)
|
||||
*d = defaultPlugin
|
||||
d.d = dnsfilter.New()
|
||||
d.hosts = make(map[string]net.IP)
|
||||
var p = new(plug)
|
||||
*p = defaultPlugin
|
||||
p.d = dnsfilter.New()
|
||||
p.hosts = make(map[string]net.IP)
|
||||
|
||||
var filterFileName string
|
||||
for c.Next() {
|
||||
@@ -103,15 +103,15 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) {
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "safebrowsing":
|
||||
d.d.EnableSafeBrowsing()
|
||||
p.d.EnableSafeBrowsing()
|
||||
if c.NextArg() {
|
||||
if len(c.Val()) == 0 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
d.d.SetSafeBrowsingServer(c.Val())
|
||||
p.d.SetSafeBrowsingServer(c.Val())
|
||||
}
|
||||
case "safesearch":
|
||||
d.d.EnableSafeSearch()
|
||||
p.d.EnableSafeSearch()
|
||||
case "parental":
|
||||
if !c.NextArg() {
|
||||
return nil, c.ArgErr()
|
||||
@@ -120,7 +120,7 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) {
|
||||
if err != nil {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
err = d.d.EnableParental(sensitivity)
|
||||
err = p.d.EnableParental(sensitivity)
|
||||
if err != nil {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
@@ -128,10 +128,10 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) {
|
||||
if len(c.Val()) == 0 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
d.ParentalBlockHost = c.Val()
|
||||
p.ParentalBlockHost = c.Val()
|
||||
}
|
||||
case "querylog":
|
||||
d.QueryLogEnabled = true
|
||||
p.QueryLogEnabled = true
|
||||
onceQueryLog.Do(func() {
|
||||
go startQueryLogServer() // TODO: how to handle errors?
|
||||
})
|
||||
@@ -149,10 +149,10 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) {
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
if d.parseEtcHosts(text) {
|
||||
if p.parseEtcHosts(text) {
|
||||
continue
|
||||
}
|
||||
err = d.d.AddRule(text, 0)
|
||||
err = p.d.AddRule(text, 0)
|
||||
if err == dnsfilter.ErrInvalidSyntax {
|
||||
continue
|
||||
}
|
||||
@@ -167,23 +167,23 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d.upstream, err = upstream.New(nil)
|
||||
p.upstream, err = upstream.New(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d, nil
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
d, err := setupPlugin(c)
|
||||
p, err := setupPlugin(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config := dnsserver.GetConfig(c)
|
||||
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
d.Next = next
|
||||
return d
|
||||
p.Next = next
|
||||
return p
|
||||
})
|
||||
|
||||
c.OnStartup(func() error {
|
||||
@@ -200,16 +200,16 @@ func setup(c *caddy.Controller) error {
|
||||
x.MustRegister(whitelisted)
|
||||
x.MustRegister(safesearch)
|
||||
x.MustRegister(errorsTotal)
|
||||
x.MustRegister(d)
|
||||
x.MustRegister(p)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
c.OnShutdown(d.OnShutdown)
|
||||
c.OnShutdown(p.onShutdown)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Plugin) parseEtcHosts(text string) bool {
|
||||
func (p *plug) parseEtcHosts(text string) bool {
|
||||
if pos := strings.IndexByte(text, '#'); pos != -1 {
|
||||
text = text[0:pos]
|
||||
}
|
||||
@@ -222,17 +222,17 @@ func (d *Plugin) parseEtcHosts(text string) bool {
|
||||
return false
|
||||
}
|
||||
for _, host := range fields[1:] {
|
||||
if val, ok := d.hosts[host]; ok {
|
||||
if val, ok := p.hosts[host]; ok {
|
||||
log.Printf("warning: host %s already has value %s, will overwrite it with %s", host, val, addr)
|
||||
}
|
||||
d.hosts[host] = addr
|
||||
p.hosts[host] = addr
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *Plugin) OnShutdown() error {
|
||||
d.d.Destroy()
|
||||
d.d = nil
|
||||
func (p *plug) onShutdown() error {
|
||||
p.d.Destroy()
|
||||
p.d = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -240,7 +240,7 @@ type statsFunc func(ch interface{}, name string, text string, value float64, val
|
||||
|
||||
func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) {
|
||||
realch, ok := ch.(chan<- *prometheus.Desc)
|
||||
if ok == false {
|
||||
if !ok {
|
||||
log.Printf("Couldn't convert ch to chan<- *prometheus.Desc\n")
|
||||
return
|
||||
}
|
||||
@@ -249,7 +249,7 @@ func doDesc(ch interface{}, name string, text string, value float64, valueType p
|
||||
|
||||
func doMetric(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) {
|
||||
realch, ok := ch.(chan<- prometheus.Metric)
|
||||
if ok == false {
|
||||
if !ok {
|
||||
log.Printf("Couldn't convert ch to chan<- prometheus.Metric\n")
|
||||
return
|
||||
}
|
||||
@@ -268,21 +268,23 @@ func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *d
|
||||
gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending_max", name), fmt.Sprintf("Maximum number of pending %s HTTP requests", name), float64(lookupstats.PendingMax), prometheus.GaugeValue)
|
||||
}
|
||||
|
||||
func (d *Plugin) doStats(ch interface{}, doFunc statsFunc) {
|
||||
stats := d.d.GetStats()
|
||||
func (p *plug) doStats(ch interface{}, doFunc statsFunc) {
|
||||
stats := p.d.GetStats()
|
||||
doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing)
|
||||
doStatsLookup(ch, doFunc, "parental", &stats.Parental)
|
||||
}
|
||||
|
||||
func (d *Plugin) Describe(ch chan<- *prometheus.Desc) {
|
||||
d.doStats(ch, doDesc)
|
||||
// Describe is called by prometheus handler to know stat types
|
||||
func (p *plug) Describe(ch chan<- *prometheus.Desc) {
|
||||
p.doStats(ch, doDesc)
|
||||
}
|
||||
|
||||
func (d *Plugin) Collect(ch chan<- prometheus.Metric) {
|
||||
d.doStats(ch, doMetric)
|
||||
// Collect is called by prometheus handler to collect stats
|
||||
func (p *plug) Collect(ch chan<- prometheus.Metric) {
|
||||
p.doStats(ch, doMetric)
|
||||
}
|
||||
|
||||
func (d *Plugin) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, host string, val string, question dns.Question) (int, error) {
|
||||
func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, host string, val string, question dns.Question) (int, error) {
|
||||
// check if it's a domain name or IP address
|
||||
addr := net.ParseIP(val)
|
||||
var records []dns.RR
|
||||
@@ -301,7 +303,7 @@ func (d *Plugin) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseW
|
||||
req.SetQuestion(dns.Fqdn(val), question.Qtype)
|
||||
req.RecursionDesired = true
|
||||
reqstate := request.Request{W: w, Req: req, Context: ctx}
|
||||
result, err := d.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType())
|
||||
result, err := p.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType())
|
||||
if err != nil {
|
||||
log.Printf("Got error %s\n", err)
|
||||
return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err)
|
||||
@@ -363,80 +365,80 @@ func writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int,
|
||||
return dns.RcodeNameError, nil
|
||||
}
|
||||
|
||||
func (d *Plugin) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error, dnsfilter.Result) {
|
||||
func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) {
|
||||
if len(r.Question) != 1 {
|
||||
// google DNS, bind and others do the same
|
||||
return dns.RcodeFormatError, fmt.Errorf("Got DNS request with != 1 questions"), dnsfilter.Result{}
|
||||
return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("Got DNS request with != 1 questions")
|
||||
}
|
||||
for _, question := range r.Question {
|
||||
host := strings.ToLower(strings.TrimSuffix(question.Name, "."))
|
||||
// is it a safesearch domain?
|
||||
if val, ok := d.d.SafeSearchDomain(host); ok {
|
||||
rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
||||
if val, ok := p.d.SafeSearchDomain(host); ok {
|
||||
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
||||
if err != nil {
|
||||
return rcode, err, dnsfilter.Result{}
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
return rcode, err, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}
|
||||
return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err
|
||||
}
|
||||
|
||||
// is it in hosts?
|
||||
if val, ok := d.hosts[host]; ok {
|
||||
if val, ok := p.hosts[host]; ok {
|
||||
// it is, if it's a loopback host, reply with NXDOMAIN
|
||||
if val.IsLoopback() {
|
||||
rcode, err := writeNXdomain(ctx, w, r)
|
||||
if err != nil {
|
||||
return rcode, err, dnsfilter.Result{}
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
return rcode, err, dnsfilter.Result{Reason: dnsfilter.FilteredInvalid}
|
||||
return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredInvalid}, err
|
||||
}
|
||||
// it's not a loopback host, replace it with value specified
|
||||
rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val.String(), question)
|
||||
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val.String(), question)
|
||||
if err != nil {
|
||||
return rcode, err, dnsfilter.Result{}
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
return rcode, err, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}
|
||||
return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err
|
||||
}
|
||||
|
||||
// needs to be filtered instead
|
||||
result, err := d.d.CheckHost(host)
|
||||
result, err := p.d.CheckHost(host)
|
||||
if err != nil {
|
||||
log.Printf("plugin/dnsfilter: %s\n", err)
|
||||
return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err), dnsfilter.Result{}
|
||||
return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err)
|
||||
}
|
||||
|
||||
if result.IsFiltered {
|
||||
switch result.Reason {
|
||||
case dnsfilter.FilteredSafeBrowsing:
|
||||
// return cname safebrowsing.block.dns.adguard.com
|
||||
val := d.SafeBrowsingBlockHost
|
||||
rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
||||
val := p.SafeBrowsingBlockHost
|
||||
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
||||
if err != nil {
|
||||
return rcode, err, dnsfilter.Result{}
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
return rcode, err, result
|
||||
return rcode, result, err
|
||||
case dnsfilter.FilteredParental:
|
||||
// return cname family.block.dns.adguard.com
|
||||
val := d.ParentalBlockHost
|
||||
rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
||||
val := p.ParentalBlockHost
|
||||
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
||||
if err != nil {
|
||||
return rcode, err, dnsfilter.Result{}
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
return rcode, err, result
|
||||
return rcode, result, err
|
||||
case dnsfilter.FilteredBlackList:
|
||||
// return NXdomain
|
||||
rcode, err := writeNXdomain(ctx, w, r)
|
||||
if err != nil {
|
||||
return rcode, err, dnsfilter.Result{}
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
return rcode, err, result
|
||||
return rcode, result, err
|
||||
default:
|
||||
log.Printf("SHOULD NOT HAPPEN -- got unknown reason for filtering: %T %v %s", result.Reason, result.Reason, result.Reason.String())
|
||||
}
|
||||
} else {
|
||||
switch result.Reason {
|
||||
case dnsfilter.NotFilteredWhiteList:
|
||||
rcode, err := plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r)
|
||||
return rcode, err, result
|
||||
rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
return rcode, result, err
|
||||
case dnsfilter.NotFilteredNotFound:
|
||||
// do nothing, pass through to lower code
|
||||
default:
|
||||
@@ -444,11 +446,12 @@ func (d *Plugin) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *
|
||||
}
|
||||
}
|
||||
}
|
||||
rcode, err := plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r)
|
||||
return rcode, err, dnsfilter.Result{}
|
||||
rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
return rcode, dnsfilter.Result{}, err
|
||||
}
|
||||
|
||||
func (d *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
// ServeDNS handles the DNS request and refuses if it's in filterlists
|
||||
func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
start := time.Now()
|
||||
requests.Inc()
|
||||
state := request.Request{W: w, Req: r}
|
||||
@@ -456,13 +459,16 @@ func (d *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
|
||||
|
||||
// capture the written answer
|
||||
rrw := dnstest.NewRecorder(w)
|
||||
rcode, err, result := d.serveDNSInternal(ctx, rrw, r)
|
||||
rcode, result, err := p.serveDNSInternal(ctx, rrw, r)
|
||||
if rcode > 0 {
|
||||
// actually send the answer if we have one
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(r, rcode)
|
||||
state.SizeAndDo(answer)
|
||||
w.WriteMsg(answer)
|
||||
err = w.WriteMsg(answer)
|
||||
if err != nil {
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
}
|
||||
|
||||
// increment counters
|
||||
@@ -496,12 +502,13 @@ func (d *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
|
||||
}
|
||||
|
||||
// log
|
||||
if d.QueryLogEnabled {
|
||||
if p.QueryLogEnabled {
|
||||
logRequest(r, rrw.Msg, result, time.Since(start), ip)
|
||||
}
|
||||
return rcode, err
|
||||
}
|
||||
|
||||
func (d *Plugin) Name() string { return "dnsfilter" }
|
||||
// Name returns name of the plugin as seen in Corefile and plugin.cfg
|
||||
func (p *plug) Name() string { return "dnsfilter" }
|
||||
|
||||
var onceQueryLog sync.Once
|
||||
|
||||
@@ -46,10 +46,10 @@ func TestEtcHostsParse(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := tmpfile.Write(text); err != nil {
|
||||
if _, err = tmpfile.Write(text); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
if err = tmpfile.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -80,10 +80,10 @@ func TestEtcHostsFilter(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := tmpfile.Write(text); err != nil {
|
||||
if _, err = tmpfile.Write(text); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
if err = tmpfile.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -127,10 +127,10 @@ func TestEtcHostsFilter(t *testing.T) {
|
||||
t.Fatalf("ServeDNS return value for host %s has rcode %d that does not match captured rcode %d", testcase.host, rcode, rrw.Rcode)
|
||||
}
|
||||
filtered := rcode == dns.RcodeNameError
|
||||
if testcase.filtered == true && testcase.filtered != filtered {
|
||||
if testcase.filtered && testcase.filtered != filtered {
|
||||
t.Fatalf("Host %s expected to be filtered, instead it is not filtered", testcase.host)
|
||||
}
|
||||
if testcase.filtered == false && testcase.filtered != filtered {
|
||||
if !testcase.filtered && testcase.filtered != filtered {
|
||||
t.Fatalf("Host %s expected to be not filtered, instead it is filtered", testcase.host)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
// ratelimiting and per-ip buckets
|
||||
@@ -29,8 +28,8 @@ var (
|
||||
tokenBuckets = cache.New(time.Hour, time.Hour)
|
||||
)
|
||||
|
||||
// main function
|
||||
func (p *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
// ServeDNS handles the DNS request and refuses if it's an beyind specified ratelimit
|
||||
func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
state := request.Request{W: w, Req: r}
|
||||
ip := state.IP()
|
||||
allow, err := p.allowRequest(ip)
|
||||
@@ -44,7 +43,7 @@ func (p *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
|
||||
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
func (p *Plugin) allowRequest(ip string) (bool, error) {
|
||||
func (p *plug) allowRequest(ip string) (bool, error) {
|
||||
if _, found := tokenBuckets.Get(ip); !found {
|
||||
tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour)
|
||||
}
|
||||
@@ -59,7 +58,7 @@ func (p *Plugin) allowRequest(ip string) (bool, error) {
|
||||
}
|
||||
|
||||
rl, ok := value.(*rate.RateLimiter)
|
||||
if ok == false {
|
||||
if !ok {
|
||||
text := "SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache"
|
||||
log.Println(text)
|
||||
err := errors.New(text)
|
||||
@@ -80,7 +79,7 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
type Plugin struct {
|
||||
type plug struct {
|
||||
Next plugin.Handler
|
||||
|
||||
// configuration for creating above
|
||||
@@ -88,7 +87,7 @@ type Plugin struct {
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
p := &Plugin{ratelimit: defaultRatelimit}
|
||||
p := &plug{ratelimit: defaultRatelimit}
|
||||
config := dnsserver.GetConfig(c)
|
||||
|
||||
for c.Next() {
|
||||
@@ -109,22 +108,20 @@ func setup(c *caddy.Controller) error {
|
||||
})
|
||||
|
||||
c.OnStartup(func() error {
|
||||
once.Do(func() {
|
||||
m := dnsserver.GetConfig(c).Handler("prometheus")
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if x, ok := m.(*metrics.Metrics); ok {
|
||||
x.MustRegister(ratelimited)
|
||||
}
|
||||
})
|
||||
m := dnsserver.GetConfig(c).Handler("prometheus")
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
if x, ok := m.(*metrics.Metrics); ok {
|
||||
x.MustRegister(ratelimited)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newDnsCounter(name string, help string) prometheus.Counter {
|
||||
func newDNSCounter(name string, help string) prometheus.Counter {
|
||||
return prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "ratelimit",
|
||||
@@ -134,9 +131,8 @@ func newDnsCounter(name string, help string) prometheus.Counter {
|
||||
}
|
||||
|
||||
var (
|
||||
ratelimited = newDnsCounter("dropped_total", "Count of requests that have been dropped because of rate limit")
|
||||
ratelimited = newDNSCounter("dropped_total", "Count of requests that have been dropped because of rate limit")
|
||||
)
|
||||
|
||||
func (d *Plugin) Name() string { return "ratelimit" }
|
||||
|
||||
var once sync.Once
|
||||
// Name returns name of the plugin as seen in Corefile and plugin.cfg
|
||||
func (p *plug) Name() string { return "ratelimit" }
|
||||
|
||||
@@ -3,7 +3,6 @@ package refuseany
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
@@ -15,11 +14,12 @@ import (
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
type Plugin struct {
|
||||
type plug struct {
|
||||
Next plugin.Handler
|
||||
}
|
||||
|
||||
func (p *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
// ServeDNS handles the DNS request and refuses if it's an ANY request
|
||||
func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
if len(r.Question) != 1 {
|
||||
// google DNS, bind and others do the same
|
||||
return dns.RcodeFormatError, fmt.Errorf("Got DNS request with != 1 questions")
|
||||
@@ -41,9 +41,9 @@ func (p *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
return rcode, nil
|
||||
} else {
|
||||
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -54,7 +54,7 @@ func init() {
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
p := &Plugin{}
|
||||
p := &plug{}
|
||||
config := dnsserver.GetConfig(c)
|
||||
|
||||
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
@@ -63,22 +63,20 @@ func setup(c *caddy.Controller) error {
|
||||
})
|
||||
|
||||
c.OnStartup(func() error {
|
||||
once.Do(func() {
|
||||
m := dnsserver.GetConfig(c).Handler("prometheus")
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if x, ok := m.(*metrics.Metrics); ok {
|
||||
x.MustRegister(ratelimited)
|
||||
}
|
||||
})
|
||||
m := dnsserver.GetConfig(c).Handler("prometheus")
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
if x, ok := m.(*metrics.Metrics); ok {
|
||||
x.MustRegister(ratelimited)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newDnsCounter(name string, help string) prometheus.Counter {
|
||||
func newDNSCounter(name string, help string) prometheus.Counter {
|
||||
return prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "refuseany",
|
||||
@@ -88,9 +86,8 @@ func newDnsCounter(name string, help string) prometheus.Counter {
|
||||
}
|
||||
|
||||
var (
|
||||
ratelimited = newDnsCounter("refusedany_total", "Count of ANY requests that have been dropped")
|
||||
ratelimited = newDNSCounter("refusedany_total", "Count of ANY requests that have been dropped")
|
||||
)
|
||||
|
||||
func (d *Plugin) Name() string { return "refuseany" }
|
||||
|
||||
var once sync.Once
|
||||
// Name returns name of the plugin as seen in Corefile and plugin.cfg
|
||||
func (p *plug) Name() string { return "refuseany" }
|
||||
|
||||
Reference in New Issue
Block a user