Merge branch 'master' into fix/576
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -55,6 +57,7 @@ func NewServer(baseDir string) *Server {
|
||||
}
|
||||
|
||||
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
|
||||
// The zero FilteringConfig is empty and ready for use.
|
||||
type FilteringConfig struct {
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||
@@ -68,6 +71,13 @@ type FilteringConfig struct {
|
||||
dnsfilter.Config `yaml:",inline"`
|
||||
}
|
||||
|
||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||
type TLSConfig struct {
|
||||
TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"`
|
||||
CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain
|
||||
PrivateKey string `yaml:"private_key" json:"private_key"` // PEM-encoded private key
|
||||
}
|
||||
|
||||
// ServerConfig represents server configuration.
|
||||
// The zero ServerConfig is empty and ready for use.
|
||||
type ServerConfig struct {
|
||||
@@ -77,6 +87,7 @@ type ServerConfig struct {
|
||||
Filters []dnsfilter.Filter // A list of filters to use
|
||||
|
||||
FilteringConfig
|
||||
TLSConfig
|
||||
}
|
||||
|
||||
// if any of ServerConfig values are zero, then default values from below are used
|
||||
@@ -91,7 +102,7 @@ func init() {
|
||||
|
||||
defaultUpstreams := make([]upstream.Upstream, 0)
|
||||
for _, addr := range defaultDNS {
|
||||
u, err := upstream.AddressToUpstream(addr, "", DefaultTimeout)
|
||||
u, err := upstream.AddressToUpstream(addr, upstream.Options{Timeout: DefaultTimeout})
|
||||
if err == nil {
|
||||
defaultUpstreams = append(defaultUpstreams, u)
|
||||
}
|
||||
@@ -154,6 +165,15 @@ func (s *Server) startInternal(config *ServerConfig) error {
|
||||
Handler: s.handleDNSRequest,
|
||||
}
|
||||
|
||||
if s.TLSListenAddr != nil && s.CertificateChain != "" && s.PrivateKey != "" {
|
||||
proxyConfig.TLSListenAddr = s.TLSListenAddr
|
||||
keypair, err := tls.X509KeyPair([]byte(s.CertificateChain), []byte(s.PrivateKey))
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Failed to parse TLS keypair")
|
||||
}
|
||||
proxyConfig.TLSConfig = &tls.Config{Certificates: []tls.Certificate{keypair}}
|
||||
}
|
||||
|
||||
if proxyConfig.UDPListenAddr == nil {
|
||||
proxyConfig.UDPListenAddr = defaultValues.UDPListenAddr
|
||||
}
|
||||
@@ -240,24 +260,38 @@ func (s *Server) Reconfigure(config *ServerConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServeHTTP is a HTTP handler method we use to provide DNS-over-HTTPS
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.RLock()
|
||||
s.dnsProxy.ServeHTTP(w, r)
|
||||
s.RUnlock()
|
||||
}
|
||||
|
||||
// GetQueryLog returns a map with the current query log ready to be converted to a JSON
|
||||
func (s *Server) GetQueryLog() []map[string]interface{} {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
return s.queryLog.getQueryLog()
|
||||
}
|
||||
|
||||
// GetStatsTop returns the current stop stats
|
||||
func (s *Server) GetStatsTop() *StatsTop {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
return s.queryLog.runningTop.getStatsTop()
|
||||
}
|
||||
|
||||
// PurgeStats purges current server stats
|
||||
func (s *Server) PurgeStats() {
|
||||
// TODO: Locks?
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.stats.purgeStats()
|
||||
}
|
||||
|
||||
// GetAggregatedStats returns aggregated stats data for the 24 hours
|
||||
func (s *Server) GetAggregatedStats() map[string]interface{} {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
return s.stats.getAggregatedStats()
|
||||
}
|
||||
|
||||
@@ -267,6 +301,8 @@ func (s *Server) GetAggregatedStats() map[string]interface{} {
|
||||
// end is end of the time range
|
||||
// returns nil if time unit is not supported
|
||||
func (s *Server) GetStatsHistory(timeUnit time.Duration, startTime time.Time, endTime time.Time) (map[string]interface{}, error) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
return s.stats.getStatsHistory(timeUnit, startTime, endTime)
|
||||
}
|
||||
|
||||
@@ -350,9 +386,9 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu
|
||||
|
||||
switch result.Reason {
|
||||
case dnsfilter.FilteredSafeBrowsing:
|
||||
return s.genBlockedHost(m, safeBrowsingBlockHost, d.Upstream)
|
||||
return s.genBlockedHost(m, safeBrowsingBlockHost, d)
|
||||
case dnsfilter.FilteredParental:
|
||||
return s.genBlockedHost(m, parentalBlockHost, d.Upstream)
|
||||
return s.genBlockedHost(m, parentalBlockHost, d)
|
||||
default:
|
||||
if result.IP != nil {
|
||||
return s.genARecord(m, result.IP)
|
||||
@@ -381,22 +417,30 @@ func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg {
|
||||
return &resp
|
||||
}
|
||||
|
||||
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, upstream upstream.Upstream) *dns.Msg {
|
||||
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg {
|
||||
// look up the hostname, TODO: cache
|
||||
replReq := dns.Msg{}
|
||||
replReq.SetQuestion(dns.Fqdn(newAddr), request.Question[0].Qtype)
|
||||
replReq.RecursionDesired = true
|
||||
reply, err := upstream.Exchange(&replReq)
|
||||
|
||||
newContext := &proxy.DNSContext{
|
||||
Proto: d.Proto,
|
||||
Addr: d.Addr,
|
||||
StartTime: time.Now(),
|
||||
Req: &replReq,
|
||||
}
|
||||
|
||||
err := s.dnsProxy.Resolve(newContext)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't look up replacement host '%s' on upstream %s: %s", newAddr, upstream.Address(), err)
|
||||
log.Printf("Couldn't look up replacement host '%s': %s", newAddr, err)
|
||||
return s.genServerFailure(request)
|
||||
}
|
||||
|
||||
resp := dns.Msg{}
|
||||
resp.SetReply(request)
|
||||
resp.Authoritative, resp.RecursionAvailable = true, true
|
||||
if reply != nil {
|
||||
for _, answer := range reply.Answer {
|
||||
if newContext.Res != nil {
|
||||
for _, answer := range newContext.Res.Answer {
|
||||
answer.Header().Name = request.Question[0].Name
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
|
||||
@@ -1,17 +1,34 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
tlsServerName = "testdns.adguard.com"
|
||||
dataDir = "testData"
|
||||
testMessagesCount = 10
|
||||
)
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
s := createTestServer(t)
|
||||
defer removeDataDir(t)
|
||||
@@ -22,7 +39,7 @@ func TestServer(t *testing.T) {
|
||||
|
||||
// message over UDP
|
||||
req := createGoogleATestMessage()
|
||||
addr := s.dnsProxy.Addr("udp")
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
client := dns.Client{Net: "udp"}
|
||||
reply, _, err := client.Exchange(req, addr.String())
|
||||
if err != nil {
|
||||
@@ -63,6 +80,69 @@ func TestServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDotServer(t *testing.T) {
|
||||
// Prepare the proxy server
|
||||
_, certPem, keyPem := createServerTLSConfig(t)
|
||||
s := createTestServer(t)
|
||||
defer removeDataDir(t)
|
||||
|
||||
s.TLSConfig = TLSConfig{
|
||||
TLSListenAddr: &net.TCPAddr{Port: 0},
|
||||
CertificateChain: string(certPem),
|
||||
PrivateKey: string(keyPem),
|
||||
}
|
||||
|
||||
// Starting the server
|
||||
err := s.Start(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
}
|
||||
|
||||
// Add our self-signed generated config to roots
|
||||
roots := x509.NewCertPool()
|
||||
roots.AppendCertsFromPEM(certPem)
|
||||
tlsConfig := &tls.Config{ServerName: tlsServerName, RootCAs: roots}
|
||||
|
||||
// Create a DNS-over-TLS client connection
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoTLS)
|
||||
conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot connect to the proxy: %s", err)
|
||||
}
|
||||
|
||||
sendTestMessages(t, conn)
|
||||
|
||||
// Stop the proxy
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
t.Fatalf("DNS server failed to stop: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerRace(t *testing.T) {
|
||||
s := createTestServer(t)
|
||||
defer removeDataDir(t)
|
||||
err := s.Start(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
}
|
||||
|
||||
// message over UDP
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
conn, err := dns.Dial("udp", addr.String())
|
||||
if err != nil {
|
||||
t.Fatalf("cannot connect to the proxy: %s", err)
|
||||
}
|
||||
|
||||
sendTestMessagesAsync(t, conn)
|
||||
|
||||
// Stop the proxy
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
t.Fatalf("DNS server failed to stop: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
s := createTestServer(t)
|
||||
s.SafeSearchEnabled = true
|
||||
@@ -141,7 +221,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
// server is running, send a message
|
||||
addr := s.dnsProxy.Addr("udp")
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
@@ -175,7 +255,7 @@ func TestBlockedRequest(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
}
|
||||
addr := s.dnsProxy.Addr("udp")
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
//
|
||||
// NXDomain blocking
|
||||
@@ -216,7 +296,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
}
|
||||
addr := s.dnsProxy.Addr("udp")
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
//
|
||||
// Hosts blocking
|
||||
@@ -264,7 +344,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %s", err)
|
||||
}
|
||||
addr := s.dnsProxy.Addr("udp")
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
|
||||
//
|
||||
// Safebrowsing blocking
|
||||
@@ -320,6 +400,7 @@ func createTestServer(t *testing.T) *Server {
|
||||
s := NewServer(createDataDir(t))
|
||||
s.UDPListenAddr = &net.UDPAddr{Port: 0}
|
||||
s.TCPListenAddr = &net.TCPAddr{Port: 0}
|
||||
|
||||
s.QueryLogEnabled = true
|
||||
s.FilteringConfig.FilteringEnabled = true
|
||||
s.FilteringConfig.ProtectionEnabled = true
|
||||
@@ -335,20 +416,111 @@ func createTestServer(t *testing.T) *Server {
|
||||
return s
|
||||
}
|
||||
|
||||
func createDataDir(t *testing.T) string {
|
||||
dir := "testData"
|
||||
err := os.MkdirAll(dir, 0755)
|
||||
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot create %s: %s", dir, err)
|
||||
t.Fatalf("cannot generate RSA key: %s", err)
|
||||
}
|
||||
return dir
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate serial number: %s", err)
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(5 * 365 * time.Hour * 24)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"AdGuard Tests"},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
template.DNSNames = append(template.DNSNames, tlsServerName)
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create certificate: %s", err)
|
||||
}
|
||||
|
||||
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
||||
|
||||
cert, err := tls.X509KeyPair(certPem, keyPem)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create certificate: %s", err)
|
||||
}
|
||||
|
||||
return &tls.Config{Certificates: []tls.Certificate{cert}, ServerName: tlsServerName}, certPem, keyPem
|
||||
}
|
||||
|
||||
func createDataDir(t *testing.T) string {
|
||||
err := os.MkdirAll(dataDir, 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot create %s: %s", dataDir, err)
|
||||
}
|
||||
return dataDir
|
||||
}
|
||||
|
||||
func removeDataDir(t *testing.T) {
|
||||
dir := "testData"
|
||||
err := os.RemoveAll(dir)
|
||||
err := os.RemoveAll(dataDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot remove %s: %s", dir, err)
|
||||
t.Fatalf("Cannot remove %s: %s", dataDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendTestMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup) {
|
||||
defer func() {
|
||||
g.Done()
|
||||
}()
|
||||
|
||||
req := createTestMessage()
|
||||
err := conn.WriteMsg(req)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot write message: %s", err)
|
||||
}
|
||||
|
||||
res, err := conn.ReadMsg()
|
||||
if err != nil {
|
||||
t.Fatalf("cannot read response to message: %s", err)
|
||||
}
|
||||
assertResponse(t, res)
|
||||
}
|
||||
|
||||
// sendTestMessagesAsync sends messages in parallel
|
||||
// so that we could find race issues
|
||||
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
|
||||
g := &sync.WaitGroup{}
|
||||
g.Add(testMessagesCount)
|
||||
|
||||
for i := 0; i < testMessagesCount; i++ {
|
||||
go sendTestMessageAsync(t, conn, g)
|
||||
}
|
||||
|
||||
g.Wait()
|
||||
}
|
||||
|
||||
func sendTestMessages(t *testing.T, conn *dns.Conn) {
|
||||
for i := 0; i < 10; i++ {
|
||||
req := createTestMessage()
|
||||
err := conn.WriteMsg(req)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot write message #%d: %s", i, err)
|
||||
}
|
||||
|
||||
res, err := conn.ReadMsg()
|
||||
if err != nil {
|
||||
t.Fatalf("cannot read response to message #%d: %s", i, err)
|
||||
}
|
||||
assertResponse(t, res)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,3 +563,14 @@ func assertResponse(t *testing.T, reply *dns.Msg, ip string) {
|
||||
t.Fatalf("DNS server returned wrong answer type instead of A: %v", reply.Answer[0])
|
||||
}
|
||||
}
|
||||
|
||||
func publicKey(priv interface{}) interface{} {
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
case *ecdsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user