Merge: * use upstream servers directly for the internal DNS resolver

Close #1212

* Server.Start(config *ServerConfig) -> Start()
+ Server.Prepare(config *ServerConfig)
+ Server.Resolve(host string)
+ Server.Exchange()
* rDNS: use internal DNS resolver
- clients: fix race in WriteDiskConfig()
- fix race: move 'clients' object from 'configuration' to 'HomeContext'
    Go race detector didn't like our 'clients' object in 'configuration'.
+ add AGH startup test
    . Create a configuration file
    . Start AGH instance
    . Check Web server
    . Check DNS server
    . Wait until the filters are downloaded
    . Stop and cleanup
* move module objects from config.* to Context.*
* don't call log.SetLevel() if not necessary
    This helps to avoid Go race detector's warning
* ci.sh: 'make' and then run tests

Squashed commit of the following:

commit 86500c7f749307f37af4cc8c2a1066f679d0cfad
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 18:08:53 2019 +0300

    minor

commit 6e6abb9dca3cd250c458bec23aa30d2250a9eb40
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 18:08:31 2019 +0300

    * ci.sh: 'make' and then run tests

commit 114192eefea6800e565ba9ab238202c006516c27
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 17:50:04 2019 +0300

    fix

commit d426deea7f02cdfd4c7217a38c59e51251956a0f
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 17:46:33 2019 +0300

    tests

commit 7b350edf03027895b4e43dee908d0155a9b0ac9b
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 15:56:12 2019 +0300

    fix test

commit 2f5f116873bbbfdd4bb7f82a596f9e1f5c2bcfd8
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 15:48:56 2019 +0300

    fix tests

commit 3fbdc77f9c34726e2295185279444983652d559e
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 15:45:00 2019 +0300

    linter

commit 9da0b6965a2b6863bcd552fa83a4de2866600bb8
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 15:33:23 2019 +0300

    * config.dnsctx.whois -> Context.whois

commit c71ebdbdf6efd88c877b2f243c69d3bc00a997d7
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 15:31:08 2019 +0300

    * don't call log.SetLevel() if not necessary

    This helps to avoid Go race detector's warning

commit 0f250220133cefdcb0843a50000cb932802b8324
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 15:28:19 2019 +0300

    * rdns: refactor

commit c460d8c9414940dac852e390b6c1b4d4fb38dff9
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 14:08:08 2019 +0300

    Revert: * stats: serialize access to 'limit'

    Use 'conf *Config' and update it atomically, as in querylog module.
    (Note: Race detector still doesn't like it)

commit 488bcb884971276de0d5629384b29e22c59ee7e6
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 13:50:23 2019 +0300

    * config.dnsFilter -> Context.dnsFilter

commit 86c0a6827a450414b50acec7ebfc5220d13b81e4
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 13:45:05 2019 +0300

    * config.dnsServer -> Context.dnsServer

commit ee35ef095ccaabc89e3de0ef52c9b5ed56b36873
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 13:42:10 2019 +0300

    * config.dhcpServer -> Context.dhcpServer

commit 1537001cd211099d5fad01696c0b806ae5d257b1
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 13:39:45 2019 +0300

    * config.queryLog -> Context.queryLog

commit e5955fe4ff1ef6f41763461b37b502ea25a3d04c
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Dec 10 13:03:18 2019 +0300

    * config.httpsServer -> Context.httpsServer

commit 6153c10a9ac173e159d1f05e0db1512579b9203c
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Dec 9 20:12:24 2019 +0300

    * config.httpServer -> Context.httpServer

commit abd021fb94039015cd45c97614e8b78d4694f956
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Dec 9 20:08:05 2019 +0300

    * stats: serialize access to 'limit'

commit 38c2decfd87c712100edcabe62a6d4518719cb53
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Dec 9 19:57:04 2019 +0300

    * config.stats -> Context.stats

commit 6caf8965ad44db9dce9a7a5103aa8fa305ad9a06
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Dec 9 19:45:23 2019 +0300

    fix Restart()

... and 6 more commits
This commit is contained in:
Simon Zolin
2019-12-11 12:38:58 +03:00
parent fe357d04f7
commit 0a66913b4d
23 changed files with 439 additions and 251 deletions

View File

@@ -128,24 +128,30 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
// WriteDiskConfig - write configuration
func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
clientsList := clients.GetList()
for _, cli := range clientsList {
clients.lock.Lock()
for _, cli := range clients.list {
cy := clientObject{
Name: cli.Name,
IDs: cli.IDs,
UseGlobalSettings: !cli.UseOwnSettings,
FilteringEnabled: cli.FilteringEnabled,
ParentalEnabled: cli.ParentalEnabled,
SafeSearchEnabled: cli.SafeSearchEnabled,
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
Name: cli.Name,
UseGlobalSettings: !cli.UseOwnSettings,
FilteringEnabled: cli.FilteringEnabled,
ParentalEnabled: cli.ParentalEnabled,
SafeSearchEnabled: cli.SafeSearchEnabled,
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
BlockedServices: cli.BlockedServices,
Upstreams: cli.Upstreams,
}
cy.IDs = make([]string, len(cli.IDs))
copy(cy.IDs, cli.IDs)
cy.BlockedServices = make([]string, len(cli.BlockedServices))
copy(cy.BlockedServices, cli.BlockedServices)
cy.Upstreams = make([]string, len(cli.Upstreams))
copy(cy.Upstreams, cli.Upstreams)
*objects = append(*objects, cy)
}
clients.lock.Unlock()
}
func (clients *clientsContainer) periodicUpdate() {
@@ -157,11 +163,6 @@ func (clients *clientsContainer) periodicUpdate() {
}
}
// GetList returns the pointer to clients list
func (clients *clientsContainer) GetList() map[string]*Client {
return clients.list
}
// Exists checks if client with this IP already exists
func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
clients.lock.Lock()

View File

@@ -29,6 +29,7 @@ type logSettings struct {
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
}
// HTTPSServer - HTTPS Server
type HTTPSServer struct {
server *http.Server
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
@@ -51,25 +52,15 @@ type configuration struct {
runningAsService bool
disableUpdate bool // If set, don't check for updates
appSignalChannel chan os.Signal
clients clientsContainer // per-client-settings module
controlLock sync.Mutex
transport *http.Transport
client *http.Client
stats stats.Stats // statistics module
queryLog querylog.QueryLog // query log module
auth *Auth // HTTP authentication module
auth *Auth // HTTP authentication module
// cached version.json to avoid hammering github.io for each page reload
versionCheckJSON []byte
versionCheckLastTime time.Time
dnsctx dnsContext
dnsFilter *dnsfilter.Dnsfilter
dnsServer *dnsforward.Server
dhcpServer *dhcpd.Server
httpServer *http.Server
httpsServer HTTPSServer
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
Users []User `yaml:"users"` // Users that can access HTTP server
@@ -296,41 +287,41 @@ func (c *configuration) write() error {
c.Lock()
defer c.Unlock()
config.clients.WriteDiskConfig(&config.Clients)
Context.clients.WriteDiskConfig(&config.Clients)
if config.auth != nil {
config.Users = config.auth.GetUsers()
}
if config.stats != nil {
if Context.stats != nil {
sdc := stats.DiskConfig{}
config.stats.WriteDiskConfig(&sdc)
Context.stats.WriteDiskConfig(&sdc)
config.DNS.StatsInterval = sdc.Interval
}
if config.queryLog != nil {
if Context.queryLog != nil {
dc := querylog.DiskConfig{}
config.queryLog.WriteDiskConfig(&dc)
Context.queryLog.WriteDiskConfig(&dc)
config.DNS.QueryLogEnabled = dc.Enabled
config.DNS.QueryLogInterval = dc.Interval
config.DNS.QueryLogMemSize = dc.MemSize
}
if config.dnsFilter != nil {
if Context.dnsFilter != nil {
c := dnsfilter.Config{}
config.dnsFilter.WriteDiskConfig(&c)
Context.dnsFilter.WriteDiskConfig(&c)
config.DNS.DnsfilterConf = c
}
if config.dnsServer != nil {
if Context.dnsServer != nil {
c := dnsforward.FilteringConfig{}
config.dnsServer.WriteDiskConfig(&c)
Context.dnsServer.WriteDiskConfig(&c)
config.DNS.FilteringConfig = c
}
if config.dhcpServer != nil {
if Context.dhcpServer != nil {
c := dhcpd.ServerConfig{}
config.dhcpServer.WriteDiskConfig(&c)
Context.dhcpServer.WriteDiskConfig(&c)
config.DHCP = c
}

View File

@@ -93,8 +93,8 @@ func getDNSAddresses() []string {
func handleStatus(w http.ResponseWriter, r *http.Request) {
c := dnsforward.FilteringConfig{}
if config.dnsServer != nil {
config.dnsServer.WriteDiskConfig(&c)
if Context.dnsServer != nil {
Context.dnsServer.WriteDiskConfig(&c)
}
data := map[string]interface{}{
"dns_addresses": getDNSAddresses(),
@@ -154,7 +154,7 @@ func handleDOH(w http.ResponseWriter, r *http.Request) {
return
}
config.dnsServer.ServeHTTP(w, r)
Context.dnsServer.ServeHTTP(w, r)
}
// ------------------------

View File

@@ -235,13 +235,19 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
config.DNS.BindHost = newSettings.DNS.IP
config.DNS.Port = newSettings.DNS.Port
initDNSServer()
err = startDNSServer()
if err != nil {
err = initDNSServer()
var err2 error
if err == nil {
err2 = startDNSServer()
}
if err != nil || err2 != nil {
config.firstRun = true
copyInstallSettings(&config, &curConfig)
httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err)
} else {
httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err2)
}
return
}
@@ -261,7 +267,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
if restartHTTP {
go func() {
_ = config.httpServer.Shutdown(context.TODO())
_ = Context.httpServer.Shutdown(context.TODO())
}()
}

View File

@@ -80,7 +80,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
// check if port is available
// BUT: if we are already using this port, no need
alreadyRunning := false
if config.httpsServer.server != nil {
if Context.httpsServer.server != nil {
alreadyRunning = true
}
if !alreadyRunning {
@@ -110,7 +110,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
// check if port is available
// BUT: if we are already using this port, no need
alreadyRunning := false
if config.httpsServer.server != nil {
if Context.httpsServer.server != nil {
alreadyRunning = true
}
if !alreadyRunning {
@@ -145,12 +145,12 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
if restartHTTPS {
go func() {
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
config.httpsServer.cond.L.Lock()
config.httpsServer.cond.Broadcast()
if config.httpsServer.server != nil {
config.httpsServer.server.Shutdown(context.TODO())
Context.httpsServer.cond.L.Lock()
Context.httpsServer.cond.Broadcast()
if Context.httpsServer.server != nil {
Context.httpsServer.server.Shutdown(context.TODO())
}
config.httpsServer.cond.L.Unlock()
Context.httpsServer.cond.L.Unlock()
}()
}
}

View File

@@ -10,12 +10,12 @@ func startDHCPServer() error {
return nil
}
err := config.dhcpServer.Init(config.DHCP)
err := Context.dhcpServer.Init(config.DHCP)
if err != nil {
return errorx.Decorate(err, "Couldn't init DHCP server")
}
err = config.dhcpServer.Start()
err = Context.dhcpServer.Start()
if err != nil {
return errorx.Decorate(err, "Couldn't start DHCP server")
}
@@ -27,7 +27,7 @@ func stopDHCPServer() error {
return nil
}
err := config.dhcpServer.Stop()
err := Context.dhcpServer.Stop()
if err != nil {
return errorx.Decorate(err, "Couldn't stop DHCP server")
}

View File

@@ -15,11 +15,6 @@ import (
"github.com/joomcode/errorx"
)
type dnsContext struct {
rdns *RDNS
whois *Whois
}
// Called by other modules when configuration is changed
func onConfigModified() {
_ = config.write()
@@ -28,12 +23,12 @@ func onConfigModified() {
// initDNSServer creates an instance of the dnsforward.Server
// Please note that we must do it even if we don't start it
// so that we had access to the query log and the stats
func initDNSServer() {
func initDNSServer() error {
baseDir := config.getDataDir()
err := os.MkdirAll(baseDir, 0755)
if err != nil {
log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err)
return fmt.Errorf("Cannot create DNS data dir at %s: %s", baseDir, err)
}
statsConf := stats.Config{
@@ -42,9 +37,9 @@ func initDNSServer() {
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
}
config.stats, err = stats.New(statsConf)
Context.stats, err = stats.New(statsConf)
if err != nil {
log.Fatal("Couldn't initialize statistics module")
return fmt.Errorf("Couldn't initialize statistics module")
}
conf := querylog.Config{
Enabled: config.DNS.QueryLogEnabled,
@@ -54,7 +49,7 @@ func initDNSServer() {
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
}
config.queryLog = querylog.New(conf)
Context.queryLog = querylog.New(conf)
filterConf := config.DNS.DnsfilterConf
bindhost := config.DNS.BindHost
@@ -64,22 +59,28 @@ func initDNSServer() {
filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
filterConf.ConfigModified = onConfigModified
filterConf.HTTPRegister = httpRegister
config.dnsFilter = dnsfilter.New(&filterConf, nil)
Context.dnsFilter = dnsfilter.New(&filterConf, nil)
config.dnsServer = dnsforward.NewServer(config.dnsFilter, config.stats, config.queryLog)
Context.dnsServer = dnsforward.NewServer(Context.dnsFilter, Context.stats, Context.queryLog)
dnsConfig := generateServerConfig()
err = Context.dnsServer.Prepare(&dnsConfig)
if err != nil {
return fmt.Errorf("dnsServer.Prepare: %s", err)
}
sessFilename := filepath.Join(baseDir, "sessions.db")
config.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
config.Users = nil
config.dnsctx.rdns = InitRDNS(&config.clients)
config.dnsctx.whois = initWhois(&config.clients)
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
Context.whois = initWhois(&Context.clients)
initFiltering()
return nil
}
func isRunning() bool {
return config.dnsServer != nil && config.dnsServer.IsRunning()
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
}
// nolint (gocyclo)
@@ -145,14 +146,14 @@ func onDNSRequest(d *proxy.DNSContext) {
ipAddr := net.ParseIP(ip)
if !ipAddr.IsLoopback() {
config.dnsctx.rdns.Begin(ip)
Context.rdns.Begin(ip)
}
if isPublicIP(ipAddr) {
config.dnsctx.whois.Begin(ip)
Context.whois.Begin(ip)
}
}
func generateServerConfig() (dnsforward.ServerConfig, error) {
func generateServerConfig() dnsforward.ServerConfig {
newconfig := dnsforward.ServerConfig{
UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
@@ -171,11 +172,11 @@ func generateServerConfig() (dnsforward.ServerConfig, error) {
newconfig.FilterHandler = applyAdditionalFiltering
newconfig.GetUpstreamsByClient = getUpstreamsByClient
return newconfig, nil
return newconfig
}
func getUpstreamsByClient(clientAddr string) []string {
c, ok := config.clients.Find(clientAddr)
c, ok := Context.clients.Find(clientAddr)
if !ok {
return []string{}
}
@@ -192,7 +193,7 @@ func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteri
return
}
c, ok := config.clients.Find(clientAddr)
c, ok := Context.clients.Find(clientAddr)
if !ok {
return
}
@@ -220,12 +221,7 @@ func startDNSServer() error {
enableFilters(false)
newconfig, err := generateServerConfig()
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
err = config.dnsServer.Start(&newconfig)
err := Context.dnsServer.Start()
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
@@ -233,14 +229,14 @@ func startDNSServer() error {
startFiltering()
const topClientsNumber = 100 // the number of clients to get
topClients := config.stats.GetTopClientsIP(topClientsNumber)
topClients := Context.stats.GetTopClientsIP(topClientsNumber)
for _, ip := range topClients {
ipAddr := net.ParseIP(ip)
if !ipAddr.IsLoopback() {
config.dnsctx.rdns.Begin(ip)
Context.rdns.Begin(ip)
}
if isPublicIP(ipAddr) {
config.dnsctx.whois.Begin(ip)
Context.whois.Begin(ip)
}
}
@@ -248,11 +244,8 @@ func startDNSServer() error {
}
func reconfigureDNSServer() error {
newconfig, err := generateServerConfig()
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
err = config.dnsServer.Reconfigure(&newconfig)
newconfig := generateServerConfig()
err := Context.dnsServer.Reconfigure(&newconfig)
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
@@ -261,26 +254,22 @@ func reconfigureDNSServer() error {
}
func stopDNSServer() error {
if !isRunning() {
return nil
}
err := config.dnsServer.Stop()
err := Context.dnsServer.Stop()
if err != nil {
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
}
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
config.dnsServer.Close()
Context.dnsServer.Close()
config.dnsFilter.Close()
config.dnsFilter = nil
Context.dnsFilter.Close()
Context.dnsFilter = nil
config.stats.Close()
config.stats = nil
Context.stats.Close()
Context.stats = nil
config.queryLog.Close()
config.queryLog = nil
Context.queryLog.Close()
Context.queryLog = nil
config.auth.Close()
config.auth = nil

View File

@@ -1,17 +0,0 @@
package home
import (
"os"
"testing"
)
func TestResolveRDNS(t *testing.T) {
_ = os.RemoveAll(config.getDataDir())
defer func() { _ = os.RemoveAll(config.getDataDir()) }()
config.DNS.BindHost = "1.1.1.1"
initDNSServer()
if r := config.dnsctx.rdns.resolve("1.1.1.1"); r != "one.one.one.one" {
t.Errorf("resolveRDNS(): %s", r)
}
}

View File

@@ -514,5 +514,5 @@ func enableFilters(async bool) {
}
}
_ = config.dnsFilter.SetFilters(filters, async)
_ = Context.dnsFilter.SetFilters(filters, async)
}

View File

@@ -16,7 +16,6 @@ import (
"syscall"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
@@ -118,7 +117,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
return
}
// enforce https?
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && config.httpsServer.server != nil {
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil {
// yes, and we want host from host:port
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
@@ -273,14 +272,8 @@ func customDialContext(ctx context.Context, network, addr string) (net.Conn, err
return con, err
}
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
bindhost = "127.0.0.1"
}
resolverAddr := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
r := upstream.NewResolver(resolverAddr, 30*time.Second)
addrs, e := r.LookupIPAddr(ctx, host)
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
addrs, e := Context.dnsServer.Resolve(host)
log.Debug("dnsServer.Resolve: %s: %v", host, addrs)
if e != nil {
return nil, e
}

View File

@@ -21,6 +21,10 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/golibs/log"
"github.com/NYTimes/gziphandler"
"github.com/gobuffalo/packr"
@@ -40,6 +44,23 @@ var (
const versionCheckPeriod = time.Hour * 8
// Global context
type homeContext struct {
clients clientsContainer // per-client-settings module
stats stats.Stats // statistics module
queryLog querylog.QueryLog // query log module
dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module
whois *Whois // WHOIS module
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
dhcpServer *dhcpd.Server // DHCP module
httpServer *http.Server // HTTP module
httpsServer HTTPSServer // HTTPS module
}
// Context - a global context object
var Context homeContext
// Main is the entry point
func Main(version string, channel string) {
// Init update-related global variables
@@ -122,8 +143,8 @@ func run(args options) {
config.DHCP.WorkDir = config.ourWorkingDir
config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified
config.dhcpServer = dhcpd.Create(config.DHCP)
config.clients.Init(config.Clients, config.dhcpServer)
Context.dhcpServer = dhcpd.Create(config.DHCP)
Context.clients.Init(config.Clients, Context.dhcpServer)
config.Clients = nil
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
@@ -146,7 +167,10 @@ func run(args options) {
log.Fatal(err)
}
initDNSServer()
err = initDNSServer()
if err != nil {
log.Fatalf("%s", err)
}
go func() {
err = startDNSServer()
if err != nil {
@@ -178,21 +202,21 @@ func run(args options) {
registerInstallHandlers()
}
config.httpsServer.cond = sync.NewCond(&config.httpsServer.Mutex)
Context.httpsServer.cond = sync.NewCond(&Context.httpsServer.Mutex)
// for https, we have a separate goroutine loop
go httpServerLoop()
// this loop is used as an ability to change listening host and/or port
for !config.httpsServer.shutdown {
for !Context.httpsServer.shutdown {
printHTTPAddresses("http")
// we need to have new instance, because after Shutdown() the Server is not usable
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
config.httpServer = &http.Server{
Context.httpServer = &http.Server{
Addr: address,
}
err := config.httpServer.ListenAndServe()
err := Context.httpServer.ListenAndServe()
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
@@ -205,14 +229,14 @@ func run(args options) {
}
func httpServerLoop() {
for !config.httpsServer.shutdown {
config.httpsServer.cond.L.Lock()
for !Context.httpsServer.shutdown {
Context.httpsServer.cond.L.Lock()
// this mechanism doesn't let us through until all conditions are met
for config.TLS.Enabled == false ||
config.TLS.PortHTTPS == 0 ||
len(config.TLS.PrivateKeyData) == 0 ||
len(config.TLS.CertificateChainData) == 0 { // sleep until necessary data is supplied
config.httpsServer.cond.Wait()
Context.httpsServer.cond.Wait()
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
// validate current TLS config and update warnings (it could have been loaded from file)
@@ -236,10 +260,10 @@ func httpServerLoop() {
cleanupAlways()
log.Fatal(err)
}
config.httpsServer.cond.L.Unlock()
Context.httpsServer.cond.L.Unlock()
// prepare HTTPS server
config.httpsServer.server = &http.Server{
Context.httpsServer.server = &http.Server{
Addr: address,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
@@ -248,7 +272,7 @@ func httpServerLoop() {
}
printHTTPAddresses("https")
err = config.httpsServer.server.ListenAndServeTLS("", "")
err = Context.httpsServer.server.ListenAndServeTLS("", "")
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
@@ -326,11 +350,10 @@ func configureLogger(args options) {
ls.LogFile = args.logFile
}
level := log.INFO
// log.SetLevel(log.INFO) - default
if ls.Verbose {
level = log.DEBUG
log.SetLevel(log.DEBUG)
}
log.SetLevel(level)
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if nothing else is configured
@@ -378,11 +401,11 @@ func cleanup() {
// Stop HTTP server, possibly waiting for all active connections to be closed
func stopHTTPServer() {
log.Info("Stopping HTTP server...")
config.httpsServer.shutdown = true
if config.httpsServer.server != nil {
config.httpsServer.server.Shutdown(context.TODO())
Context.httpsServer.shutdown = true
if Context.httpsServer.server != nil {
Context.httpsServer.server.Shutdown(context.TODO())
}
config.httpServer.Shutdown(context.TODO())
Context.httpServer.Shutdown(context.TODO())
log.Info("Stopped HTTP server")
}

154
home/home_test.go Normal file
View File

@@ -0,0 +1,154 @@
package home
import (
"context"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/stretchr/testify/assert"
)
const yamlConf = `bind_host: 127.0.0.1
bind_port: 3000
users: []
language: en
rlimit_nofile: 0
web_session_ttl: 720
dns:
bind_host: 127.0.0.1
port: 5354
statistics_interval: 90
querylog_enabled: true
querylog_interval: 90
querylog_memsize: 0
protection_enabled: true
blocking_mode: null_ip
blocked_response_ttl: 0
ratelimit: 100
ratelimit_whitelist: []
refuse_any: false
bootstrap_dns:
- 1.1.1.1:53
all_servers: false
allowed_clients: []
disallowed_clients: []
blocked_hosts: []
parental_block_host: family-block.dns.adguard.com
safebrowsing_block_host: standard-block.dns.adguard.com
cache_size: 0
upstream_dns:
- https://1.1.1.1/dns-query
filtering_enabled: true
filters_update_interval: 168
parental_sensitivity: 13
parental_enabled: true
safesearch_enabled: false
safebrowsing_enabled: false
safebrowsing_cache_size: 1048576
safesearch_cache_size: 1048576
parental_cache_size: 1048576
cache_time: 30
rewrites: []
blocked_services: []
tls:
enabled: false
server_name: www.example.com
force_https: false
port_https: 443
port_dns_over_tls: 853
certificate_chain: ""
private_key: ""
certificate_path: ""
private_key_path: ""
filters:
- enabled: true
url: https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt
name: AdGuard Simplified Domain Names filter
id: 1
- enabled: false
url: https://hosts-file.net/ad_servers.txt
name: hpHosts - Ad and Tracking servers only
id: 2
- enabled: false
url: https://adaway.org/hosts.txt
name: adaway
id: 3
user_rules:
- ""
dhcp:
enabled: false
interface_name: ""
gateway_ip: ""
subnet_mask: ""
range_start: ""
range_end: ""
lease_duration: 86400
icmp_timeout_msec: 1000
clients: []
log_file: ""
verbose: false
schema_version: 5
`
// . Create a configuration file
// . Start AGH instance
// . Check Web server
// . Check DNS server
// . Wait until the filters are downloaded
// . Stop and cleanup
func TestHome(t *testing.T) {
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
fn := filepath.Join(dir, "AdGuardHome.yaml")
assert.True(t, ioutil.WriteFile(fn, []byte(yamlConf), 0644) == nil)
fn, _ = filepath.Abs(fn)
args := options{}
args.configFilename = fn
args.workDir = dir
go run(args)
var err error
var resp *http.Response
h := http.Client{}
for i := 0; i != 5; i++ {
resp, err = h.Get("http://127.0.0.1:3000/")
if err == nil && resp.StatusCode != 404 {
break
}
time.Sleep(1 * time.Second)
}
assert.Truef(t, err == nil, "%s", err)
assert.Equal(t, 200, resp.StatusCode)
resp, err = h.Get("http://127.0.0.1:3000/control/status")
assert.Truef(t, err == nil, "%s", err)
assert.Equal(t, 200, resp.StatusCode)
r := upstream.NewResolver("127.0.0.1:5354", 3*time.Second)
addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com")
assert.Truef(t, err == nil, "%s", err)
haveIP := len(addrs) != 0
assert.True(t, haveIP)
for i := 1; ; i++ {
st, err := os.Stat(filepath.Join(dir, "data", "filters", "1.txt"))
if err == nil && st.Size() != 0 {
break
}
if i == 5 {
assert.True(t, false)
break
}
time.Sleep(1 * time.Second)
}
cleanup()
cleanupAlways()
}

View File

@@ -2,25 +2,20 @@ package home
import (
"encoding/binary"
"fmt"
"strings"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
const (
rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
)
// RDNS - module context
type RDNS struct {
dnsServer *dnsforward.Server
clients *clientsContainer
ipChannel chan string // pass data from DNS request handling thread to rDNS thread
upstream upstream.Upstream // Upstream object for our own DNS server
ipChannel chan string // pass data from DNS request handling thread to rDNS thread
// Contains IP addresses of clients to be resolved by rDNS
// If IP address is resolved, it stays here while it's inside Clients.
@@ -30,25 +25,10 @@ type RDNS struct {
}
// InitRDNS - create module context
func InitRDNS(clients *clientsContainer) *RDNS {
func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS {
r := RDNS{}
r.dnsServer = dnsServer
r.clients = clients
var err error
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
bindhost = "127.0.0.1"
}
resolverAddress := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
opts := upstream.Options{
Timeout: rdnsTimeout,
}
r.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
if err != nil {
log.Error("upstream.AddressToUpstream: %s", err)
return nil
}
cconf := cache.Config{}
cconf.EnableLRU = true
@@ -109,7 +89,7 @@ func (r *RDNS) resolve(ip string) string {
return ""
}
resp, err := r.upstream.Exchange(&req)
resp, err := r.dnsServer.Exchange(&req)
if err != nil {
log.Debug("Error while making an rDNS lookup for %s: %s", ip, err)
return ""
@@ -144,6 +124,6 @@ func (r *RDNS) workerLoop() {
continue
}
_, _ = config.clients.AddHost(ip, host, ClientSourceRDNS)
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
}
}

21
home/rdns_test.go Normal file
View File

@@ -0,0 +1,21 @@
package home
import (
"testing"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/stretchr/testify/assert"
)
func TestResolveRDNS(t *testing.T) {
dns := &dnsforward.Server{}
conf := &dnsforward.ServerConfig{}
conf.UpstreamDNS = []string{"8.8.8.8"}
err := dns.Prepare(conf)
assert.True(t, err == nil, "%s", err)
clients := &clientsContainer{}
rdns := InitRDNS(dns, clients)
r := rdns.resolve("1.1.1.1")
assert.True(t, r == "one.one.one.one", "%s", r)
}