Merge remote-tracking branch 'origin/master' into dhcp6

# Conflicts:
#	openapi/openapi.yaml
This commit is contained in:
Simon Zolin
2020-05-29 14:28:18 +03:00
231 changed files with 17877 additions and 13625 deletions

View File

@@ -380,11 +380,10 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
}
}
} else if r.URL.Path == "/favicon.png" ||
strings.HasPrefix(r.URL.Path, "/login.") ||
strings.HasPrefix(r.URL.Path, "/__locales/") {
} else if strings.HasPrefix(r.URL.Path, "/assets/") ||
strings.HasPrefix(r.URL.Path, "/login.") {
// process as usual
// no additional auth requirements
} else if Context.auth != nil && Context.auth.AuthRequired() {
// redirect to login page if not authenticated
ok := false

View File

@@ -11,17 +11,18 @@ import (
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils"
)
const (
clientsUpdatePeriod = 1 * time.Hour
clientsUpdatePeriod = 10 * time.Minute
)
var webHandlersRegistered = false
@@ -41,11 +42,12 @@ type Client struct {
BlockedServices []string
Upstreams []string // list of upstream servers to be used for the client's requests
// Upstream objects:
// Custom upstream config for this client
// nil: not yet initialized
// not nil, but empty: initialized, no good upstreams
// not nil, not empty: Upstreams ready to be used
upstreamObjects []upstream.Upstream
upstreamConfig *proxy.UpstreamConfig
}
type clientSource uint
@@ -276,16 +278,10 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
return c, true
}
func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream {
a2 := make([]upstream.Upstream, len(a))
copy(a2, a)
return a2
}
// FindUpstreams looks for upstreams configured for the client
// If no client found for this IP, or if no custom upstreams are configured,
// this method returns nil
func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -294,22 +290,18 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
return nil
}
if c.upstreamObjects == nil {
c.upstreamObjects = make([]upstream.Upstream, 0)
for _, us := range c.Upstreams {
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
if err != nil {
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
continue
}
c.upstreamObjects = append(c.upstreamObjects, u)
if len(c.Upstreams) == 0 {
return nil
}
if c.upstreamConfig == nil {
config, err := proxy.ParseUpstreamsConfig(c.Upstreams, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
if err == nil {
c.upstreamConfig = &config
}
}
if len(c.upstreamObjects) == 0 {
return nil
}
return upstreamArrayCopy(c.upstreamObjects)
return c.upstreamConfig
}
// Find searches for a client by IP (and does not lock anything)
@@ -540,7 +532,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
}
// update upstreams cache
c.upstreamObjects = nil
c.upstreamConfig = nil
*old = c
return nil

View File

@@ -236,3 +236,31 @@ func TestClientsAddExisting(t *testing.T) {
assert.True(t, ok)
assert.Nil(t, err)
}
func TestClientsCustomUpstream(t *testing.T) {
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
// add client with upstreams
client := Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
Upstreams: []string{
"1.1.1.1",
"[/example.org/]8.8.8.8",
},
}
ok, err := clients.Add(client)
assert.Nil(t, err)
assert.True(t, ok)
config := clients.FindUpstreams("1.2.3.4")
assert.Nil(t, config)
config = clients.FindUpstreams("1.1.1.1")
assert.NotNil(t, config)
assert.Equal(t, 1, len(config.Upstreams))
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
}

View File

@@ -78,10 +78,11 @@ type dnsConfig struct {
// time interval for statistics (in days)
StatsInterval uint32 `yaml:"statistics_interval"`
QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled
QueryLogInterval uint32 `yaml:"querylog_interval"` // time interval for query log (in days)
QueryLogMemSize uint32 `yaml:"querylog_size_memory"` // number of entries kept in memory before they are flushed to disk
AnonymizeClientIP bool `yaml:"anonymize_client_ip"` // anonymize clients' IP addresses in logs and stats
QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled
QueryLogFileEnabled bool `yaml:"querylog_file_enabled"` // if true, query log will be written to a file
QueryLogInterval uint32 `yaml:"querylog_interval"` // time interval for query log (in days)
QueryLogMemSize uint32 `yaml:"querylog_size_memory"` // number of entries kept in memory before they are flushed to disk
AnonymizeClientIP bool `yaml:"anonymize_client_ip"` // anonymize clients' IP addresses in logs and stats
dnsforward.FilteringConfig `yaml:",inline"`
@@ -134,6 +135,7 @@ func initConfig() {
config.WebSessionTTLHours = 30 * 24
config.DNS.QueryLogEnabled = true
config.DNS.QueryLogFileEnabled = true
config.DNS.QueryLogInterval = 90
config.DNS.QueryLogMemSize = 1000
@@ -235,9 +237,10 @@ func (c *configuration) write() error {
}
if Context.queryLog != nil {
dc := querylog.DiskConfig{}
dc := querylog.Config{}
Context.queryLog.WriteDiskConfig(&dc)
config.DNS.QueryLogEnabled = dc.Enabled
config.DNS.QueryLogFileEnabled = dc.FileEnabled
config.DNS.QueryLogInterval = dc.Interval
config.DNS.QueryLogMemSize = dc.MemSize
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP

View File

@@ -189,7 +189,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
if Context.firstRun &&
!strings.HasPrefix(r.URL.Path, "/install.") &&
r.URL.Path != "/favicon.png" {
!strings.HasPrefix(r.URL.Path, "/assets/") {
http.Redirect(w, r, "/install.html", http.StatusFound)
return
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
@@ -41,6 +40,7 @@ func initDNSServer() error {
}
conf := querylog.Config{
Enabled: config.DNS.QueryLogEnabled,
FileEnabled: config.DNS.QueryLogFileEnabled,
BaseDir: baseDir,
Interval: config.DNS.QueryLogInterval,
MemSize: config.DNS.QueryLogMemSize,
@@ -176,7 +176,7 @@ func generateServerConfig() dnsforward.ServerConfig {
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
newconfig.FilterHandler = applyAdditionalFiltering
newconfig.GetUpstreamsByClient = getUpstreamsByClient
newconfig.GetCustomUpstreamByClient = Context.clients.FindUpstreams
return newconfig
}
@@ -222,10 +222,6 @@ func getDNSAddresses() []string {
return dnsAddresses
}
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
return Context.clients.FindUpstreams(clientAddr)
}
// If a client has his own settings, apply them
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)

View File

@@ -461,26 +461,29 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
for {
line, err := r.ReadString('\n')
if err != nil {
break
}
checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line))
line = strings.TrimSpace(line)
if len(line) == 0 {
continue
}
//
if line[0] == '!' {
} else if line[0] == '!' {
m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1)
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
name = m[0][1]
seenTitle = true
}
} else if line[0] == '#' {
//
} else {
rulesCount++
}
if err != nil {
break
}
}
return rulesCount, checksum, name
@@ -595,7 +598,7 @@ func (f *Filtering) updateIntl(filter *filter) (bool, error) {
log.Printf("Filter %d has been updated: %d bytes, %d rules",
filter.ID, total, rulesCount)
if filterName != "" {
if len(filter.Name) == 0 {
filter.Name = filterName
}
filter.RulesCount = rulesCount

View File

@@ -14,6 +14,7 @@ import (
func testStartFilterListener() net.Listener {
http.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
content := `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
`

View File

@@ -208,6 +208,11 @@ func run(args options) {
}
}
// 'clients' module uses 'dnsfilter' module's static data (dnsfilter.BlockedSvcKnown()),
// so we have to initialize dnsfilter's static data first,
// but also avoid relying on automatic Go init() function
dnsfilter.InitModule()
config.DHCP.WorkDir = Context.workDir
config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified
@@ -335,6 +340,8 @@ func requireAdminRights() {
admin, _ := util.HaveAdminRights()
if //noinspection ALL
admin || isdelve.Enabled {
// Don't forget that for this to work you need to add "delve" tag explicitly
// https://stackoverflow.com/questions/47879070/how-can-i-see-if-the-goland-debugger-is-running-in-the-program
return
}
@@ -589,28 +596,34 @@ func printHTTPAddresses(proto string) {
if Context.tls != nil {
Context.tls.WriteDiskConfig(&tlsConf)
}
port := strconv.Itoa(config.BindPort)
if proto == "https" {
port = strconv.Itoa(tlsConf.PortHTTPS)
}
if proto == "https" && tlsConf.ServerName != "" {
if tlsConf.PortHTTPS == 443 {
log.Printf("Go to https://%s", tlsConf.ServerName)
} else {
log.Printf("Go to https://%s:%d", tlsConf.ServerName, tlsConf.PortHTTPS)
log.Printf("Go to https://%s:%s", tlsConf.ServerName, port)
}
} else if config.BindHost == "0.0.0.0" {
log.Println("AdGuard Home is available on the following addresses:")
ifaces, err := util.GetValidNetInterfacesForWeb()
if err != nil {
// That's weird, but we'll ignore it
address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
address = net.JoinHostPort(config.BindHost, port)
log.Printf("Go to %s://%s", proto, address)
return
}
for _, iface := range ifaces {
address = net.JoinHostPort(iface.Addresses[0], strconv.Itoa(config.BindPort))
address = net.JoinHostPort(iface.Addresses[0], port)
log.Printf("Go to %s://%s", proto, address)
}
} else {
address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
address = net.JoinHostPort(config.BindHost, port)
log.Printf("Go to %s://%s", proto, address)
}
}

View File

@@ -39,7 +39,14 @@ func tlsCreate(conf tlsConfigSettings) *TLSMod {
t.conf = conf
if t.conf.Enabled {
if !t.load() {
return nil
// Something is not valid - return an empty TLS config
return &TLSMod{conf: tlsConfigSettings{
Enabled: conf.Enabled,
ServerName: conf.ServerName,
PortHTTPS: conf.PortHTTPS,
PortDNSOverTLS: conf.PortDNSOverTLS,
AllowUnencryptedDOH: conf.AllowUnencryptedDOH,
}}
}
t.setCertFileTime()
}
@@ -48,13 +55,14 @@ func tlsCreate(conf tlsConfigSettings) *TLSMod {
func (t *TLSMod) load() bool {
if !tlsLoadConfig(&t.conf, &t.status) {
log.Error("failed to load TLS config: %s", t.status.WarningValidation)
return false
}
// validate current TLS config and update warnings (it could have been loaded from file)
data := validateCertificates(string(t.conf.CertificateChainData), string(t.conf.PrivateKeyData), t.conf.ServerName)
if !data.ValidPair {
log.Error(data.WarningValidation)
log.Error("failed to validate certificate: %s", data.WarningValidation)
return false
}
t.status = data
@@ -191,7 +199,7 @@ type tlsConfig struct {
tlsConfigStatus `json:",inline"`
}
func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, _ *http.Request) {
t.confLock.Lock()
data := tlsConfig{
tlsConfigSettings: t.conf,

View File

@@ -3,6 +3,7 @@ package home
import (
"context"
"crypto/tls"
golog "log"
"net"
"net/http"
"strconv"
@@ -38,6 +39,17 @@ type Web struct {
portHTTPS int
httpServer *http.Server // HTTP module
httpsServer HTTPSServer // HTTPS module
errLogger *golog.Logger
}
// Proxy between Go's "log" and "golibs/log"
type logWriter struct {
}
// HTTP server calls this function to log an error
func (w *logWriter) Write(p []byte) (int, error) {
log.Debug("Web: %s", string(p))
return 0, nil
}
// CreateWeb - create module
@@ -47,6 +59,9 @@ func CreateWeb(conf *WebConfig) *Web {
w := Web{}
w.conf = conf
lw := logWriter{}
w.errLogger = golog.New(&lw, "", 0)
// Initialize and run the admin Web interface
box := packr.NewBox("../build/static")
@@ -115,7 +130,7 @@ func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
// Start - start serving HTTP requests
func (web *Web) Start() {
// for https, we have a separate goroutine loop
go web.httpServerLoop()
go web.tlsServerLoop()
// this loop is used as an ability to change listening host and/or port
for !web.httpsServer.shutdown {
@@ -124,7 +139,8 @@ func (web *Web) Start() {
// we need to have new instance, because after Shutdown() the Server is not usable
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BindPort))
web.httpServer = &http.Server{
Addr: address,
ErrorLog: web.errLogger,
Addr: address,
}
err := web.httpServer.ListenAndServe()
if err != http.ErrServerClosed {
@@ -151,7 +167,7 @@ func (web *Web) Close() {
log.Info("Stopped HTTP server")
}
func (web *Web) httpServerLoop() {
func (web *Web) tlsServerLoop() {
for {
web.httpsServer.cond.L.Lock()
if web.httpsServer.shutdown {
@@ -173,7 +189,8 @@ func (web *Web) httpServerLoop() {
// prepare HTTPS server
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.PortHTTPS))
web.httpsServer.server = &http.Server{
Addr: address,
ErrorLog: web.errLogger,
Addr: address,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert},
MinVersion: tls.VersionTLS12,

View File

@@ -1,6 +1,7 @@
package home
import (
"context"
"encoding/binary"
"fmt"
"io/ioutil"
@@ -120,7 +121,7 @@ func (w *Whois) query(target string, serverAddr string) (string, error) {
if addr == "whois.arin.net" {
target = "n + " + target
}
conn, err := net.DialTimeout("tcp", serverAddr, time.Duration(w.timeoutMsec)*time.Millisecond)
conn, err := customDialContext(context.TODO(), "tcp", serverAddr)
if err != nil {
return "", err
}

View File

@@ -2,11 +2,29 @@ package home
import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/stretchr/testify/assert"
)
func prepareTestDNSServer() error {
config.DNS.Port = 1234
Context.dnsServer = dnsforward.NewServer(nil, nil, nil)
conf := &dnsforward.ServerConfig{}
uc, err := proxy.ParseUpstreamsConfig([]string{"1.1.1.1"}, nil, time.Second*5)
if err != nil {
return err
}
conf.UpstreamConfig = &uc
return Context.dnsServer.Prepare(conf)
}
func TestWhois(t *testing.T) {
err := prepareTestDNSServer()
assert.Nil(t, err)
w := Whois{timeoutMsec: 5000}
resp, err := w.queryAll("8.8.8.8")
assert.True(t, err == nil)