Use urlfilter format in rebinding allow list

This commit is contained in:
Reinaldo de Souza Jr
2020-12-05 16:49:32 +01:00
parent fcb582679e
commit bad1c6acdc
6 changed files with 66 additions and 27 deletions

View File

@@ -53,6 +53,7 @@ type Server struct {
queryLog querylog.QueryLog // Query log instance
stats stats.Stats
access *accessCtx
rebinding *dnsRebindChecker
ipset ipsetCtx
@@ -222,6 +223,13 @@ func (s *Server) Prepare(config *ServerConfig) error {
return err
}
// Initialize DNS rebinding module
// --
s.rebinding, err = newRebindChecker(s.conf.RebindingAllowedHosts)
if err != nil {
return err
}
// Register web handlers if necessary
// --
if !webRegistered && s.conf.HTTPRegister != nil {

View File

@@ -796,9 +796,9 @@ func TestBlockedDNSRebinding(t *testing.T) {
}
s.conf.RebindingProtectionEnabled = true
s.conf.RebindingAllowedHosts = []string{
"nip.io.",
}
s.rebinding, _ = newRebindChecker([]string{
"||nip.io^",
})
reply, err = dns.Exchange(&req, addr.String())
if err != nil {
t.Fatalf("Couldn't talk to server %s: %s", addr, err)

View File

@@ -315,6 +315,7 @@ func (s *Server) setConfig(dc dnsConfig) (restart bool) {
if dc.RebindingAllowedHosts != nil {
s.conf.RebindingAllowedHosts = *dc.RebindingAllowedHosts
restart = true
}
s.Unlock()
s.conf.ConfigModified()

View File

@@ -9,10 +9,41 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist"
"github.com/miekg/dns"
)
type dnsRebindChecker struct {
allowDomainEngine *urlfilter.DNSEngine
}
func newRebindChecker(allowedHosts []string) (*dnsRebindChecker, error) {
buf := strings.Builder{}
for _, s := range allowedHosts {
buf.WriteString(s)
buf.WriteString("\n")
}
rulesStorage, err := filterlist.NewRuleStorage([]filterlist.RuleList{
&filterlist.StringRuleList{
ID: int(0),
RulesText: buf.String(),
IgnoreCosmetic: true,
},
})
if err != nil {
return nil, err
}
return &dnsRebindChecker{
allowDomainEngine: urlfilter.NewDNSEngine(rulesStorage),
}, nil
}
func (c *dnsRebindChecker) isAllowedDomain(domain string) bool {
_, ok := c.allowDomainEngine.Match(domain)
return ok
}
// IsPrivate reports whether ip is a private address, according to
@@ -87,14 +118,11 @@ func (s *Server) isResponseRebind(domain, host string) bool {
defer timer.LogElapsed("DNS Rebinding check for %s -> %s", domain, host)
}
for _, h := range s.conf.RebindingAllowedHosts {
if strings.HasSuffix(domain, h) {
return false
}
if s.rebinding.isAllowedDomain(domain) {
return false
}
c := dnsRebindChecker{}
return c.isRebindHost(host)
return s.rebinding.isRebindHost(host)
}
func processRebindingFilteringAfterResponse(ctx *dnsContext) int {
@@ -157,7 +185,7 @@ func (s *Server) preventRebindResponse(ctx *dnsContext) (*dnsfilter.Result, erro
}
log.Debug(m)
blocked := s.isResponseRebind(domainName, host)
blocked := s.isResponseRebind(strings.TrimSuffix(domainName, "."), host)
s.RUnlock()
if blocked {

View File

@@ -9,7 +9,7 @@ import (
)
func TestRebindingPrivateAddresses(t *testing.T) {
c := &dnsRebindChecker{}
c, _ := newRebindChecker(nil)
r1 := byte(rand.Int31() & 0xFE)
r2 := byte(rand.Int31() & 0xFE)
@@ -53,9 +53,11 @@ func TestRebindLocalhost(t *testing.T) {
}
func TestIsResponseRebind(t *testing.T) {
s := &Server{}
s.conf.RebindingAllowedHosts = []string{
"totally-safe.com",
c, _ := newRebindChecker([]string{
"||totally-safe.com^",
})
s := &Server{
rebinding: c,
}
for _, host := range []string{
@@ -84,14 +86,14 @@ func TestIsResponseRebind(t *testing.T) {
"localhost",
} {
s.conf.RebindingProtectionEnabled = true
assert.True(t, s.isResponseRebind("example.com", host))
assert.False(t, s.isResponseRebind("totally-safe.com", host))
assert.False(t, s.isResponseRebind("absolutely.totally-safe.com", host))
assert.Truef(t, s.isResponseRebind("example.com", host), "host: %s", host)
assert.Falsef(t, s.isResponseRebind("totally-safe.com", host), "host: %s", host)
assert.Falsef(t, s.isResponseRebind("absolutely.totally-safe.com", host), "host: %s", host)
s.conf.RebindingProtectionEnabled = false
assert.False(t, s.isResponseRebind("example.com", host))
assert.False(t, s.isResponseRebind("totally-safe.com", host))
assert.False(t, s.isResponseRebind("absolutely.totally-safe.com", host))
assert.Falsef(t, s.isResponseRebind("example.com", host), "host: %s", host)
assert.Falsef(t, s.isResponseRebind("totally-safe.com", host), "host: %s", host)
assert.Falsef(t, s.isResponseRebind("absolutely.totally-safe.com", host), "host: %s", host)
}
for _, host := range []string{
@@ -99,13 +101,13 @@ func TestIsResponseRebind(t *testing.T) {
"another-example.com",
} {
s.conf.RebindingProtectionEnabled = true
assert.False(t, s.isResponseRebind("example.com", host))
assert.False(t, s.isResponseRebind("totally-safe.com", host))
assert.False(t, s.isResponseRebind("absolutely.totally-legit.com", host))
assert.Falsef(t, s.isResponseRebind("example.com", host), "host: %s", host)
assert.Falsef(t, s.isResponseRebind("totally-safe.com", host), "host: %s", host)
assert.Falsef(t, s.isResponseRebind("absolutely.totally-legit.com", host), "host: %s", host)
s.conf.RebindingProtectionEnabled = false
assert.False(t, s.isResponseRebind("example.com", host))
assert.False(t, s.isResponseRebind("totally-safe.com", host))
assert.False(t, s.isResponseRebind("absolutely.totally-legit.com", host))
assert.Falsef(t, s.isResponseRebind("example.com", host), "host: %s", host)
assert.Falsef(t, s.isResponseRebind("totally-safe.com", host), "host: %s", host)
assert.Falsef(t, s.isResponseRebind("absolutely.totally-legit.com", host), "host: %s", host)
}
}