Fix #579
1. Added --workdir command-line argument that lets configure the working dir. 2. Made "dnsforward" use this workdir parameter when saving/reading querylog. 3. Reworked "dnsforward" -- moved http handlers out of there to control.go
This commit is contained in:
@@ -2,18 +2,19 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
s := Server{}
|
||||
s.UDPListenAddr = &net.UDPAddr{Port: 0}
|
||||
s.TCPListenAddr = &net.TCPAddr{Port: 0}
|
||||
s := createTestServer(t)
|
||||
defer removeDataDir(t)
|
||||
err := s.Start(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
@@ -29,6 +30,14 @@ func TestServer(t *testing.T) {
|
||||
}
|
||||
assertResponse(t, reply)
|
||||
|
||||
// check query log and stats
|
||||
log := s.GetQueryLog()
|
||||
assert.Equal(t, 1, len(log), "Log size")
|
||||
stats := s.GetStatsTop()
|
||||
assert.Equal(t, 1, len(stats.Domains), "Top domains length")
|
||||
assert.Equal(t, 0, len(stats.Blocked), "Top blocked length")
|
||||
assert.Equal(t, 1, len(stats.Clients), "Top clients length")
|
||||
|
||||
// message over TCP
|
||||
req = createTestMessage()
|
||||
addr = s.dnsProxy.Addr("tcp")
|
||||
@@ -39,6 +48,15 @@ func TestServer(t *testing.T) {
|
||||
}
|
||||
assertResponse(t, reply)
|
||||
|
||||
// check query log and stats again
|
||||
log = s.GetQueryLog()
|
||||
assert.Equal(t, 2, len(log), "Log size")
|
||||
stats = s.GetStatsTop()
|
||||
// Length did not change as we queried the same domain
|
||||
assert.Equal(t, 1, len(stats.Domains), "Top domains length")
|
||||
assert.Equal(t, 0, len(stats.Blocked), "Top blocked length")
|
||||
assert.Equal(t, 1, len(stats.Clients), "Top clients length")
|
||||
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
t.Fatalf("DNS server failed to stop: %s", err)
|
||||
@@ -46,9 +64,8 @@ func TestServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInvalidRequest(t *testing.T) {
|
||||
s := Server{}
|
||||
s.UDPListenAddr = &net.UDPAddr{Port: 0}
|
||||
s.TCPListenAddr = &net.TCPAddr{Port: 0}
|
||||
s := createTestServer(t)
|
||||
defer removeDataDir(t)
|
||||
err := s.Start(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
@@ -67,6 +84,15 @@ func TestInvalidRequest(t *testing.T) {
|
||||
t.Fatalf("got a response to an invalid query")
|
||||
}
|
||||
|
||||
// check query log and stats
|
||||
// invalid requests aren't written to the query log
|
||||
log := s.GetQueryLog()
|
||||
assert.Equal(t, 0, len(log), "Log size")
|
||||
stats := s.GetStatsTop()
|
||||
assert.Equal(t, 0, len(stats.Domains), "Top domains length")
|
||||
assert.Equal(t, 0, len(stats.Blocked), "Top blocked length")
|
||||
assert.Equal(t, 0, len(stats.Clients), "Top clients length")
|
||||
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
t.Fatalf("DNS server failed to stop: %s", err)
|
||||
@@ -74,7 +100,8 @@ func TestInvalidRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBlockedRequest(t *testing.T) {
|
||||
s := createTestServer()
|
||||
s := createTestServer(t)
|
||||
defer removeDataDir(t)
|
||||
err := s.Start(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
@@ -99,6 +126,14 @@ func TestBlockedRequest(t *testing.T) {
|
||||
t.Fatalf("Wrong response: %s", reply.String())
|
||||
}
|
||||
|
||||
// check query log and stats
|
||||
log := s.GetQueryLog()
|
||||
assert.Equal(t, 1, len(log), "Log size")
|
||||
stats := s.GetStatsTop()
|
||||
assert.Equal(t, 1, len(stats.Domains), "Top domains length")
|
||||
assert.Equal(t, 1, len(stats.Blocked), "Top blocked length")
|
||||
assert.Equal(t, 1, len(stats.Clients), "Top clients length")
|
||||
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
t.Fatalf("DNS server failed to stop: %s", err)
|
||||
@@ -106,7 +141,8 @@ func TestBlockedRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBlockedByHosts(t *testing.T) {
|
||||
s := createTestServer()
|
||||
s := createTestServer(t)
|
||||
defer removeDataDir(t)
|
||||
err := s.Start(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
@@ -138,6 +174,14 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||
}
|
||||
|
||||
// check query log and stats
|
||||
log := s.GetQueryLog()
|
||||
assert.Equal(t, 1, len(log), "Log size")
|
||||
stats := s.GetStatsTop()
|
||||
assert.Equal(t, 1, len(stats.Domains), "Top domains length")
|
||||
assert.Equal(t, 1, len(stats.Blocked), "Top blocked length")
|
||||
assert.Equal(t, 1, len(stats.Clients), "Top clients length")
|
||||
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
t.Fatalf("DNS server failed to stop: %s", err)
|
||||
@@ -145,7 +189,8 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
s := createTestServer()
|
||||
s := createTestServer(t)
|
||||
defer removeDataDir(t)
|
||||
err := s.Start(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
@@ -188,16 +233,25 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
|
||||
}
|
||||
|
||||
// check query log and stats
|
||||
log := s.GetQueryLog()
|
||||
assert.Equal(t, 1, len(log), "Log size")
|
||||
stats := s.GetStatsTop()
|
||||
assert.Equal(t, 1, len(stats.Domains), "Top domains length")
|
||||
assert.Equal(t, 1, len(stats.Blocked), "Top blocked length")
|
||||
assert.Equal(t, 1, len(stats.Clients), "Top clients length")
|
||||
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
t.Fatalf("DNS server failed to stop: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createTestServer() *Server {
|
||||
s := Server{}
|
||||
func createTestServer(t *testing.T) *Server {
|
||||
s := NewServer(createDataDir(t))
|
||||
s.UDPListenAddr = &net.UDPAddr{Port: 0}
|
||||
s.TCPListenAddr = &net.TCPAddr{Port: 0}
|
||||
s.QueryLogEnabled = true
|
||||
s.FilteringConfig.FilteringEnabled = true
|
||||
s.FilteringConfig.ProtectionEnabled = true
|
||||
s.FilteringConfig.SafeBrowsingEnabled = true
|
||||
@@ -209,7 +263,24 @@ func createTestServer() *Server {
|
||||
}
|
||||
filter := dnsfilter.Filter{ID: 1, Rules: rules}
|
||||
s.Filters = append(s.Filters, filter)
|
||||
return &s
|
||||
return s
|
||||
}
|
||||
|
||||
func createDataDir(t *testing.T) string {
|
||||
dir := "testData"
|
||||
err := os.MkdirAll(dir, 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot create %s: %s", dir, err)
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
func removeDataDir(t *testing.T) {
|
||||
dir := "testData"
|
||||
err := os.RemoveAll(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot remove %s: %s", dir, err)
|
||||
}
|
||||
}
|
||||
|
||||
func createTestMessage() *dns.Msg {
|
||||
|
||||
Reference in New Issue
Block a user