Initial commit
This commit is contained in:
142
coredns_plugin/ratelimit/ratelimit.go
Normal file
142
coredns_plugin/ratelimit/ratelimit.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
// ratelimiting and per-ip buckets
|
||||
"github.com/beefsack/go-rate"
|
||||
"github.com/patrickmn/go-cache"
|
||||
|
||||
// coredns plugin
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics"
|
||||
"github.com/coredns/coredns/request"
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const defaultRatelimit = 100
|
||||
const defaultMaxRateLimitedIPs = 1024 * 1024
|
||||
|
||||
var (
|
||||
tokenBuckets = cache.New(time.Hour, time.Hour)
|
||||
)
|
||||
|
||||
// main function
|
||||
func (p *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
state := request.Request{W: w, Req: r}
|
||||
ip := state.IP()
|
||||
allow, err := p.allowRequest(ip)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !allow {
|
||||
ratelimited.Inc()
|
||||
return 0, nil
|
||||
}
|
||||
return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
func (p *Plugin) allowRequest(ip string) (bool, error) {
|
||||
if _, found := tokenBuckets.Get(ip); !found {
|
||||
tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour)
|
||||
}
|
||||
|
||||
value, found := tokenBuckets.Get(ip)
|
||||
if !found {
|
||||
// should not happen since we've just inserted it
|
||||
text := "SHOULD NOT HAPPEN: just-inserted ratelimiter disappeared"
|
||||
log.Println(text)
|
||||
err := errors.New(text)
|
||||
return true, err
|
||||
}
|
||||
|
||||
rl, ok := value.(*rate.RateLimiter)
|
||||
if ok == false {
|
||||
text := "SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache"
|
||||
log.Println(text)
|
||||
err := errors.New(text)
|
||||
return true, err
|
||||
}
|
||||
|
||||
allow, _ := rl.Try()
|
||||
return allow, nil
|
||||
}
|
||||
|
||||
//
|
||||
// helper functions
|
||||
//
|
||||
func init() {
|
||||
caddy.RegisterPlugin("ratelimit", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
type Plugin struct {
|
||||
Next plugin.Handler
|
||||
|
||||
// configuration for creating above
|
||||
ratelimit int // in requests per second per IP
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
p := &Plugin{ratelimit: defaultRatelimit}
|
||||
config := dnsserver.GetConfig(c)
|
||||
|
||||
for c.Next() {
|
||||
args := c.RemainingArgs()
|
||||
if len(args) <= 0 {
|
||||
continue
|
||||
}
|
||||
ratelimit, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return c.ArgErr()
|
||||
}
|
||||
p.ratelimit = ratelimit
|
||||
}
|
||||
|
||||
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
p.Next = next
|
||||
return p
|
||||
})
|
||||
|
||||
c.OnStartup(func() error {
|
||||
once.Do(func() {
|
||||
m := dnsserver.GetConfig(c).Handler("prometheus")
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if x, ok := m.(*metrics.Metrics); ok {
|
||||
x.MustRegister(ratelimited)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newDnsCounter(name string, help string) prometheus.Counter {
|
||||
return prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "ratelimit",
|
||||
Name: name,
|
||||
Help: help,
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
ratelimited = newDnsCounter("dropped_total", "Count of requests that have been dropped because of rate limit")
|
||||
)
|
||||
|
||||
func (d *Plugin) Name() string { return "ratelimit" }
|
||||
|
||||
var once sync.Once
|
||||
Reference in New Issue
Block a user