Merge remote-tracking branch 'origin/master' into dhcp6
# Conflicts: # openapi/openapi.yaml
This commit is contained in:
@@ -380,11 +380,10 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
|
||||
}
|
||||
}
|
||||
|
||||
} else if r.URL.Path == "/favicon.png" ||
|
||||
strings.HasPrefix(r.URL.Path, "/login.") ||
|
||||
strings.HasPrefix(r.URL.Path, "/__locales/") {
|
||||
} else if strings.HasPrefix(r.URL.Path, "/assets/") ||
|
||||
strings.HasPrefix(r.URL.Path, "/login.") {
|
||||
// process as usual
|
||||
|
||||
// no additional auth requirements
|
||||
} else if Context.auth != nil && Context.auth.AuthRequired() {
|
||||
// redirect to login page if not authenticated
|
||||
ok := false
|
||||
|
||||
@@ -11,17 +11,18 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
clientsUpdatePeriod = 1 * time.Hour
|
||||
clientsUpdatePeriod = 10 * time.Minute
|
||||
)
|
||||
|
||||
var webHandlersRegistered = false
|
||||
@@ -41,11 +42,12 @@ type Client struct {
|
||||
BlockedServices []string
|
||||
|
||||
Upstreams []string // list of upstream servers to be used for the client's requests
|
||||
// Upstream objects:
|
||||
|
||||
// Custom upstream config for this client
|
||||
// nil: not yet initialized
|
||||
// not nil, but empty: initialized, no good upstreams
|
||||
// not nil, not empty: Upstreams ready to be used
|
||||
upstreamObjects []upstream.Upstream
|
||||
upstreamConfig *proxy.UpstreamConfig
|
||||
}
|
||||
|
||||
type clientSource uint
|
||||
@@ -276,16 +278,10 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
|
||||
return c, true
|
||||
}
|
||||
|
||||
func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream {
|
||||
a2 := make([]upstream.Upstream, len(a))
|
||||
copy(a2, a)
|
||||
return a2
|
||||
}
|
||||
|
||||
// FindUpstreams looks for upstreams configured for the client
|
||||
// If no client found for this IP, or if no custom upstreams are configured,
|
||||
// this method returns nil
|
||||
func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
|
||||
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
@@ -294,22 +290,18 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.upstreamObjects == nil {
|
||||
c.upstreamObjects = make([]upstream.Upstream, 0)
|
||||
for _, us := range c.Upstreams {
|
||||
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
|
||||
if err != nil {
|
||||
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
|
||||
continue
|
||||
}
|
||||
c.upstreamObjects = append(c.upstreamObjects, u)
|
||||
if len(c.Upstreams) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.upstreamConfig == nil {
|
||||
config, err := proxy.ParseUpstreamsConfig(c.Upstreams, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
|
||||
if err == nil {
|
||||
c.upstreamConfig = &config
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.upstreamObjects) == 0 {
|
||||
return nil
|
||||
}
|
||||
return upstreamArrayCopy(c.upstreamObjects)
|
||||
return c.upstreamConfig
|
||||
}
|
||||
|
||||
// Find searches for a client by IP (and does not lock anything)
|
||||
@@ -540,7 +532,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
|
||||
}
|
||||
|
||||
// update upstreams cache
|
||||
c.upstreamObjects = nil
|
||||
c.upstreamConfig = nil
|
||||
|
||||
*old = c
|
||||
return nil
|
||||
|
||||
@@ -236,3 +236,31 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
|
||||
clients.Init(nil, nil, nil)
|
||||
|
||||
// add client with upstreams
|
||||
client := Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||
Name: "client1",
|
||||
Upstreams: []string{
|
||||
"1.1.1.1",
|
||||
"[/example.org/]8.8.8.8",
|
||||
},
|
||||
}
|
||||
ok, err := clients.Add(client)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
config := clients.FindUpstreams("1.2.3.4")
|
||||
assert.Nil(t, config)
|
||||
|
||||
config = clients.FindUpstreams("1.1.1.1")
|
||||
assert.NotNil(t, config)
|
||||
assert.Equal(t, 1, len(config.Upstreams))
|
||||
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
|
||||
}
|
||||
|
||||
@@ -78,10 +78,11 @@ type dnsConfig struct {
|
||||
// time interval for statistics (in days)
|
||||
StatsInterval uint32 `yaml:"statistics_interval"`
|
||||
|
||||
QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled
|
||||
QueryLogInterval uint32 `yaml:"querylog_interval"` // time interval for query log (in days)
|
||||
QueryLogMemSize uint32 `yaml:"querylog_size_memory"` // number of entries kept in memory before they are flushed to disk
|
||||
AnonymizeClientIP bool `yaml:"anonymize_client_ip"` // anonymize clients' IP addresses in logs and stats
|
||||
QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled
|
||||
QueryLogFileEnabled bool `yaml:"querylog_file_enabled"` // if true, query log will be written to a file
|
||||
QueryLogInterval uint32 `yaml:"querylog_interval"` // time interval for query log (in days)
|
||||
QueryLogMemSize uint32 `yaml:"querylog_size_memory"` // number of entries kept in memory before they are flushed to disk
|
||||
AnonymizeClientIP bool `yaml:"anonymize_client_ip"` // anonymize clients' IP addresses in logs and stats
|
||||
|
||||
dnsforward.FilteringConfig `yaml:",inline"`
|
||||
|
||||
@@ -134,6 +135,7 @@ func initConfig() {
|
||||
config.WebSessionTTLHours = 30 * 24
|
||||
|
||||
config.DNS.QueryLogEnabled = true
|
||||
config.DNS.QueryLogFileEnabled = true
|
||||
config.DNS.QueryLogInterval = 90
|
||||
config.DNS.QueryLogMemSize = 1000
|
||||
|
||||
@@ -235,9 +237,10 @@ func (c *configuration) write() error {
|
||||
}
|
||||
|
||||
if Context.queryLog != nil {
|
||||
dc := querylog.DiskConfig{}
|
||||
dc := querylog.Config{}
|
||||
Context.queryLog.WriteDiskConfig(&dc)
|
||||
config.DNS.QueryLogEnabled = dc.Enabled
|
||||
config.DNS.QueryLogFileEnabled = dc.FileEnabled
|
||||
config.DNS.QueryLogInterval = dc.Interval
|
||||
config.DNS.QueryLogMemSize = dc.MemSize
|
||||
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
|
||||
|
||||
@@ -189,7 +189,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
|
||||
|
||||
if Context.firstRun &&
|
||||
!strings.HasPrefix(r.URL.Path, "/install.") &&
|
||||
r.URL.Path != "/favicon.png" {
|
||||
!strings.HasPrefix(r.URL.Path, "/assets/") {
|
||||
http.Redirect(w, r, "/install.html", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
@@ -41,6 +40,7 @@ func initDNSServer() error {
|
||||
}
|
||||
conf := querylog.Config{
|
||||
Enabled: config.DNS.QueryLogEnabled,
|
||||
FileEnabled: config.DNS.QueryLogFileEnabled,
|
||||
BaseDir: baseDir,
|
||||
Interval: config.DNS.QueryLogInterval,
|
||||
MemSize: config.DNS.QueryLogMemSize,
|
||||
@@ -176,7 +176,7 @@ func generateServerConfig() dnsforward.ServerConfig {
|
||||
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||
|
||||
newconfig.FilterHandler = applyAdditionalFiltering
|
||||
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
||||
newconfig.GetCustomUpstreamByClient = Context.clients.FindUpstreams
|
||||
return newconfig
|
||||
}
|
||||
|
||||
@@ -222,10 +222,6 @@ func getDNSAddresses() []string {
|
||||
return dnsAddresses
|
||||
}
|
||||
|
||||
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
||||
return Context.clients.FindUpstreams(clientAddr)
|
||||
}
|
||||
|
||||
// If a client has his own settings, apply them
|
||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||
|
||||
@@ -461,26 +461,29 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line))
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
//
|
||||
|
||||
if line[0] == '!' {
|
||||
} else if line[0] == '!' {
|
||||
m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1)
|
||||
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
|
||||
name = m[0][1]
|
||||
seenTitle = true
|
||||
}
|
||||
|
||||
} else if line[0] == '#' {
|
||||
//
|
||||
|
||||
} else {
|
||||
rulesCount++
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return rulesCount, checksum, name
|
||||
@@ -595,7 +598,7 @@ func (f *Filtering) updateIntl(filter *filter) (bool, error) {
|
||||
|
||||
log.Printf("Filter %d has been updated: %d bytes, %d rules",
|
||||
filter.ID, total, rulesCount)
|
||||
if filterName != "" {
|
||||
if len(filter.Name) == 0 {
|
||||
filter.Name = filterName
|
||||
}
|
||||
filter.RulesCount = rulesCount
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
func testStartFilterListener() net.Listener {
|
||||
http.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
|
||||
content := `||example.org^$third-party
|
||||
# Inline comment example
|
||||
||example.com^$third-party
|
||||
0.0.0.0 example.com
|
||||
`
|
||||
|
||||
21
home/home.go
21
home/home.go
@@ -208,6 +208,11 @@ func run(args options) {
|
||||
}
|
||||
}
|
||||
|
||||
// 'clients' module uses 'dnsfilter' module's static data (dnsfilter.BlockedSvcKnown()),
|
||||
// so we have to initialize dnsfilter's static data first,
|
||||
// but also avoid relying on automatic Go init() function
|
||||
dnsfilter.InitModule()
|
||||
|
||||
config.DHCP.WorkDir = Context.workDir
|
||||
config.DHCP.HTTPRegister = httpRegister
|
||||
config.DHCP.ConfigModified = onConfigModified
|
||||
@@ -335,6 +340,8 @@ func requireAdminRights() {
|
||||
admin, _ := util.HaveAdminRights()
|
||||
if //noinspection ALL
|
||||
admin || isdelve.Enabled {
|
||||
// Don't forget that for this to work you need to add "delve" tag explicitly
|
||||
// https://stackoverflow.com/questions/47879070/how-can-i-see-if-the-goland-debugger-is-running-in-the-program
|
||||
return
|
||||
}
|
||||
|
||||
@@ -589,28 +596,34 @@ func printHTTPAddresses(proto string) {
|
||||
if Context.tls != nil {
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
}
|
||||
|
||||
port := strconv.Itoa(config.BindPort)
|
||||
if proto == "https" {
|
||||
port = strconv.Itoa(tlsConf.PortHTTPS)
|
||||
}
|
||||
|
||||
if proto == "https" && tlsConf.ServerName != "" {
|
||||
if tlsConf.PortHTTPS == 443 {
|
||||
log.Printf("Go to https://%s", tlsConf.ServerName)
|
||||
} else {
|
||||
log.Printf("Go to https://%s:%d", tlsConf.ServerName, tlsConf.PortHTTPS)
|
||||
log.Printf("Go to https://%s:%s", tlsConf.ServerName, port)
|
||||
}
|
||||
} else if config.BindHost == "0.0.0.0" {
|
||||
log.Println("AdGuard Home is available on the following addresses:")
|
||||
ifaces, err := util.GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
// That's weird, but we'll ignore it
|
||||
address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
||||
address = net.JoinHostPort(config.BindHost, port)
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
return
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
address = net.JoinHostPort(iface.Addresses[0], strconv.Itoa(config.BindPort))
|
||||
address = net.JoinHostPort(iface.Addresses[0], port)
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
}
|
||||
} else {
|
||||
address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
||||
address = net.JoinHostPort(config.BindHost, port)
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
}
|
||||
}
|
||||
|
||||
14
home/tls.go
14
home/tls.go
@@ -39,7 +39,14 @@ func tlsCreate(conf tlsConfigSettings) *TLSMod {
|
||||
t.conf = conf
|
||||
if t.conf.Enabled {
|
||||
if !t.load() {
|
||||
return nil
|
||||
// Something is not valid - return an empty TLS config
|
||||
return &TLSMod{conf: tlsConfigSettings{
|
||||
Enabled: conf.Enabled,
|
||||
ServerName: conf.ServerName,
|
||||
PortHTTPS: conf.PortHTTPS,
|
||||
PortDNSOverTLS: conf.PortDNSOverTLS,
|
||||
AllowUnencryptedDOH: conf.AllowUnencryptedDOH,
|
||||
}}
|
||||
}
|
||||
t.setCertFileTime()
|
||||
}
|
||||
@@ -48,13 +55,14 @@ func tlsCreate(conf tlsConfigSettings) *TLSMod {
|
||||
|
||||
func (t *TLSMod) load() bool {
|
||||
if !tlsLoadConfig(&t.conf, &t.status) {
|
||||
log.Error("failed to load TLS config: %s", t.status.WarningValidation)
|
||||
return false
|
||||
}
|
||||
|
||||
// validate current TLS config and update warnings (it could have been loaded from file)
|
||||
data := validateCertificates(string(t.conf.CertificateChainData), string(t.conf.PrivateKeyData), t.conf.ServerName)
|
||||
if !data.ValidPair {
|
||||
log.Error(data.WarningValidation)
|
||||
log.Error("failed to validate certificate: %s", data.WarningValidation)
|
||||
return false
|
||||
}
|
||||
t.status = data
|
||||
@@ -191,7 +199,7 @@ type tlsConfig struct {
|
||||
tlsConfigStatus `json:",inline"`
|
||||
}
|
||||
|
||||
func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
t.confLock.Lock()
|
||||
data := tlsConfig{
|
||||
tlsConfigSettings: t.conf,
|
||||
|
||||
25
home/web.go
25
home/web.go
@@ -3,6 +3,7 @@ package home
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
golog "log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -38,6 +39,17 @@ type Web struct {
|
||||
portHTTPS int
|
||||
httpServer *http.Server // HTTP module
|
||||
httpsServer HTTPSServer // HTTPS module
|
||||
errLogger *golog.Logger
|
||||
}
|
||||
|
||||
// Proxy between Go's "log" and "golibs/log"
|
||||
type logWriter struct {
|
||||
}
|
||||
|
||||
// HTTP server calls this function to log an error
|
||||
func (w *logWriter) Write(p []byte) (int, error) {
|
||||
log.Debug("Web: %s", string(p))
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// CreateWeb - create module
|
||||
@@ -47,6 +59,9 @@ func CreateWeb(conf *WebConfig) *Web {
|
||||
w := Web{}
|
||||
w.conf = conf
|
||||
|
||||
lw := logWriter{}
|
||||
w.errLogger = golog.New(&lw, "", 0)
|
||||
|
||||
// Initialize and run the admin Web interface
|
||||
box := packr.NewBox("../build/static")
|
||||
|
||||
@@ -115,7 +130,7 @@ func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
|
||||
// Start - start serving HTTP requests
|
||||
func (web *Web) Start() {
|
||||
// for https, we have a separate goroutine loop
|
||||
go web.httpServerLoop()
|
||||
go web.tlsServerLoop()
|
||||
|
||||
// this loop is used as an ability to change listening host and/or port
|
||||
for !web.httpsServer.shutdown {
|
||||
@@ -124,7 +139,8 @@ func (web *Web) Start() {
|
||||
// we need to have new instance, because after Shutdown() the Server is not usable
|
||||
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BindPort))
|
||||
web.httpServer = &http.Server{
|
||||
Addr: address,
|
||||
ErrorLog: web.errLogger,
|
||||
Addr: address,
|
||||
}
|
||||
err := web.httpServer.ListenAndServe()
|
||||
if err != http.ErrServerClosed {
|
||||
@@ -151,7 +167,7 @@ func (web *Web) Close() {
|
||||
log.Info("Stopped HTTP server")
|
||||
}
|
||||
|
||||
func (web *Web) httpServerLoop() {
|
||||
func (web *Web) tlsServerLoop() {
|
||||
for {
|
||||
web.httpsServer.cond.L.Lock()
|
||||
if web.httpsServer.shutdown {
|
||||
@@ -173,7 +189,8 @@ func (web *Web) httpServerLoop() {
|
||||
// prepare HTTPS server
|
||||
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.PortHTTPS))
|
||||
web.httpsServer.server = &http.Server{
|
||||
Addr: address,
|
||||
ErrorLog: web.errLogger,
|
||||
Addr: address,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{web.httpsServer.cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
@@ -120,7 +121,7 @@ func (w *Whois) query(target string, serverAddr string) (string, error) {
|
||||
if addr == "whois.arin.net" {
|
||||
target = "n + " + target
|
||||
}
|
||||
conn, err := net.DialTimeout("tcp", serverAddr, time.Duration(w.timeoutMsec)*time.Millisecond)
|
||||
conn, err := customDialContext(context.TODO(), "tcp", serverAddr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -2,11 +2,29 @@ package home
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func prepareTestDNSServer() error {
|
||||
config.DNS.Port = 1234
|
||||
Context.dnsServer = dnsforward.NewServer(nil, nil, nil)
|
||||
conf := &dnsforward.ServerConfig{}
|
||||
uc, err := proxy.ParseUpstreamsConfig([]string{"1.1.1.1"}, nil, time.Second*5)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conf.UpstreamConfig = &uc
|
||||
return Context.dnsServer.Prepare(conf)
|
||||
}
|
||||
|
||||
func TestWhois(t *testing.T) {
|
||||
err := prepareTestDNSServer()
|
||||
assert.Nil(t, err)
|
||||
|
||||
w := Whois{timeoutMsec: 5000}
|
||||
resp, err := w.queryAll("8.8.8.8")
|
||||
assert.True(t, err == nil)
|
||||
|
||||
Reference in New Issue
Block a user