Fix #1069 install: check static ip

Squashed commit of the following:

commit 57466233cb
Merge: 2df5f281 867bf545
Author: Andrey Meshkov <am@adguard.com>
Date:   Thu Feb 13 18:39:15 2020 +0300

    Merge branch 'master' into 1069-install-static-ip

commit 2df5f281c4
Author: Andrey Meshkov <am@adguard.com>
Date:   Thu Feb 13 18:35:54 2020 +0300

    *: lang fix

commit b4649a6b27
Merge: c2785253 f61d5f0f
Author: Andrey Meshkov <am@adguard.com>
Date:   Thu Feb 13 16:47:30 2020 +0300

    *(home): fixed issues with setting static IP on Mac

commit c27852537d
Author: Andrey Meshkov <am@adguard.com>
Date:   Thu Feb 13 14:14:30 2020 +0300

    +(dhcpd): added static IP for MacOS

commit f61d5f0f85
Author: Ildar Kamalov <i.kamalov@adguard.com>
Date:   Thu Feb 13 14:13:35 2020 +0300

    + client: show confirm before setting static IP

commit 7afa16fbe7
Author: Ildar Kamalov <i.kamalov@adguard.com>
Date:   Thu Feb 13 13:51:52 2020 +0300

    - client: fix text

commit 019bff0851
Author: Ildar Kamalov <i.kamalov@adguard.com>
Date:   Thu Feb 13 13:49:16 2020 +0300

    - client: pass all params to the check_config request

commit 194bed72f5
Author: Andrey Meshkov <am@adguard.com>
Date:   Wed Feb 12 17:12:16 2020 +0300

    *: fix home_test

commit 9359f6b55f
Merge: ae299058 c5ca2a77
Author: Andrey Meshkov <am@adguard.com>
Date:   Wed Feb 12 15:54:54 2020 +0300

    Merge with master

commit ae2990582d
Author: Andrey Meshkov <am@adguard.com>
Date:   Wed Feb 12 15:53:36 2020 +0300

    *(global): refactoring - moved runtime properties to Context

commit d8d48c5386
Author: Andrey Meshkov <am@adguard.com>
Date:   Wed Feb 12 15:04:25 2020 +0300

    *(dhcpd): refactoring, use dhcpd/network_utils where possible

commit 8d039c572f
Author: Ildar Kamalov <i.kamalov@adguard.com>
Date:   Fri Feb 7 18:37:39 2020 +0300

    - client: fix button position

commit 26c47e59dd
Author: Ildar Kamalov <i.kamalov@adguard.com>
Date:   Fri Feb 7 18:08:56 2020 +0300

    - client: fix static ip description

commit cb12babc46
Author: Andrey Meshkov <am@adguard.com>
Date:   Fri Feb 7 17:08:39 2020 +0300

    *: lower log level for some commands

commit d9001ff848
Author: Andrey Meshkov <am@adguard.com>
Date:   Fri Feb 7 16:17:59 2020 +0300

    *(documentation): updated openapi

commit 1d213d53c8
Merge: 8406d7d2 80861860
Author: Andrey Meshkov <am@adguard.com>
Date:   Fri Feb 7 15:16:46 2020 +0300

    *: merge with master

commit 8406d7d288
Author: Ildar Kamalov <i.kamalov@adguard.com>
Date:   Fri Jan 31 16:52:22 2020 +0300

    - client: fix locales

commit fb476b0117
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Jan 31 13:29:03 2020 +0300

    linter

commit 84b5708e71
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Jan 31 13:27:53 2020 +0300

    linter

commit 143a86a28a
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Jan 31 13:26:47 2020 +0300

    linter

... and 7 more commits
This commit is contained in:
Andrey Meshkov
2020-02-13 18:42:07 +03:00
parent 867bf5457f
commit 7a3eda02ce
38 changed files with 1319 additions and 781 deletions

View File

@@ -152,7 +152,7 @@ func (a *Auth) addSession(data []byte, s *session) {
a.sessions[name] = s
a.lock.Unlock()
if a.storeSession(data, s) {
log.Info("Auth: created session %s: expire=%d", name, s.expire)
log.Debug("Auth: created session %s: expire=%d", name, s.expire)
}
}
@@ -307,7 +307,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
return
}
cookie := config.auth.httpCookie(req)
cookie := Context.auth.httpCookie(req)
if len(cookie) == 0 {
log.Info("Auth: invalid user name or password: name='%s'", req.Name)
time.Sleep(1 * time.Second)
@@ -328,7 +328,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
cookie := r.Header.Get("Cookie")
sess := parseCookie(cookie)
config.auth.RemoveSession(sess)
Context.auth.RemoveSession(sess)
w.Header().Set("Location", "/login.html")
@@ -365,10 +365,10 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
if r.URL.Path == "/login.html" {
// redirect to dashboard if already authenticated
authRequired := config.auth != nil && config.auth.AuthRequired()
authRequired := Context.auth != nil && Context.auth.AuthRequired()
cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil {
r := config.auth.CheckSession(cookie.Value)
r := Context.auth.CheckSession(cookie.Value)
if r == 0 {
w.Header().Set("Location", "/")
w.WriteHeader(http.StatusFound)
@@ -383,12 +383,12 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
strings.HasPrefix(r.URL.Path, "/__locales/") {
// process as usual
} else if config.auth != nil && config.auth.AuthRequired() {
} else if Context.auth != nil && Context.auth.AuthRequired() {
// redirect to login page if not authenticated
ok := false
cookie, err := r.Cookie(sessionCookieName)
if err == nil {
r := config.auth.CheckSession(cookie.Value)
r := Context.auth.CheckSession(cookie.Value)
if r == 0 {
ok = true
} else if r < 0 {
@@ -398,7 +398,7 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
// there's no Cookie, check Basic authentication
user, pass, ok2 := r.BasicAuth()
if ok2 {
u := config.auth.UserFind(user, pass)
u := Context.auth.UserFind(user, pass)
if len(u.Name) != 0 {
ok = true
} else {
@@ -474,7 +474,7 @@ func (a *Auth) GetCurrentUser(r *http.Request) User {
// there's no Cookie, check Basic authentication
user, pass, ok := r.BasicAuth()
if ok {
u := config.auth.UserFind(user, pass)
u := Context.auth.UserFind(user, pass)
return u
}
return User{}

View File

@@ -100,7 +100,7 @@ func TestAuthHTTP(t *testing.T) {
users := []User{
User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
}
config.auth = InitAuth(fn, users, 60)
Context.auth = InitAuth(fn, users, 60)
handlerCalled := false
handler := func(w http.ResponseWriter, r *http.Request) {
@@ -129,7 +129,7 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled)
// perform login
cookie := config.auth.httpCookie(loginJSON{Name: "name", Password: "password"})
cookie := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"})
assert.True(t, cookie != "")
// get /
@@ -173,5 +173,5 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled)
r.Header.Del("Cookie")
config.auth.Close()
Context.auth.Close()
}

View File

@@ -44,19 +44,6 @@ type configuration struct {
// It's reset after config is parsed
fileData []byte
ourConfigFilename string // Config filename (can be overridden via the command line arguments)
ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
pidFileName string // PID file name. Empty if no PID file was created.
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
disableUpdate bool // If set, don't check for updates
appSignalChannel chan os.Signal
controlLock sync.Mutex
transport *http.Transport
client *http.Client
auth *Auth // HTTP authentication module
// cached version.json to avoid hammering github.io for each page reload
versionCheckJSON []byte
versionCheckLastTime time.Time
@@ -152,9 +139,8 @@ type tlsConfig struct {
// initialize to default values, will be changed later when reading config or parsing command line
var config = configuration{
ourConfigFilename: "AdGuardHome.yaml",
BindPort: 3000,
BindHost: "0.0.0.0",
BindPort: 3000,
BindHost: "0.0.0.0",
DNS: dnsConfig{
BindHost: "0.0.0.0",
Port: 53,
@@ -185,14 +171,6 @@ var config = configuration{
// initConfig initializes default configuration for the current OS&ARCH
func initConfig() {
config.transport = &http.Transport{
DialContext: customDialContext,
}
config.client = &http.Client{
Timeout: time.Minute * 5,
Transport: config.transport,
}
config.WebSessionTTLHours = 30 * 24
config.DNS.QueryLogEnabled = true
@@ -209,24 +187,19 @@ func initConfig() {
// getConfigFilename returns path to the current config file
func (c *configuration) getConfigFilename() string {
configFile, err := filepath.EvalSymlinks(config.ourConfigFilename)
configFile, err := filepath.EvalSymlinks(Context.configFilename)
if err != nil {
if !os.IsNotExist(err) {
log.Error("unexpected error while config file path evaluation: %s", err)
}
configFile = config.ourConfigFilename
configFile = Context.configFilename
}
if !filepath.IsAbs(configFile) {
configFile = filepath.Join(config.ourWorkingDir, configFile)
configFile = filepath.Join(Context.workDir, configFile)
}
return configFile
}
// getDataDir returns path to the directory where we store databases and filters
func (c *configuration) getDataDir() string {
return filepath.Join(c.ourWorkingDir, dataDir)
}
// getLogSettings reads logging settings from the config file.
// we do it in a separate method in order to configure logger before the actual configuration is parsed and applied.
func getLogSettings() logSettings {
@@ -292,8 +265,8 @@ func (c *configuration) write() error {
Context.clients.WriteDiskConfig(&config.Clients)
if config.auth != nil {
config.Users = config.auth.GetUsers()
if Context.auth != nil {
config.Users = Context.auth.GetUsers()
}
if Context.stats != nil {

View File

@@ -3,7 +3,13 @@ package home
import (
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/golibs/log"
@@ -54,8 +60,7 @@ func getDNSAddresses() []string {
dnsAddresses := []string{}
if config.DNS.BindHost == "0.0.0.0" {
ifaces, e := getValidNetInterfacesForWeb()
ifaces, e := util.GetValidNetInterfacesForWeb()
if e != nil {
log.Error("Couldn't get network interfaces: %v", e)
return []string{}
@@ -66,7 +71,6 @@ func getDNSAddresses() []string {
addDNSAddress(&dnsAddresses, addr)
}
}
} else {
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
}
@@ -129,7 +133,7 @@ type profileJSON struct {
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
pj := profileJSON{}
u := config.auth.GetCurrentUser(r)
u := Context.auth.GetCurrentUser(r)
pj.Name = u.Name
data, err := json.Marshal(pj)
@@ -180,3 +184,118 @@ func registerControlHandlers() {
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
}
// ----------------------------------
// helper functions for HTTP handlers
// ----------------------------------
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
log.Debug("%s %v", r.Method, r.URL)
if r.Method != method {
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
return
}
if method == "POST" || method == "PUT" || method == "DELETE" {
Context.controlLock.Lock()
defer Context.controlLock.Unlock()
}
handler(w, r)
}
}
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure("POST", handler)
}
func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure("GET", handler)
}
// Bridge between http.Handler object and Go function
type httpHandler struct {
handler func(http.ResponseWriter, *http.Request)
}
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.handler(w, r)
}
func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler {
h := httpHandler{}
h.handler = ensure(method, handler)
return &h
}
// preInstall lets the handler run only if firstRun is true, no redirects
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if !Context.firstRun {
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
handler(w, r)
}
}
// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary
type preInstallHandlerStruct struct {
handler http.Handler
}
func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
preInstall(p.handler.ServeHTTP)(w, r)
}
// preInstallHandler returns http.Handler interface for preInstall wrapper
func preInstallHandler(handler http.Handler) http.Handler {
return &preInstallHandlerStruct{handler}
}
// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise
// it also enforces HTTPS if it is enabled and configured
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if Context.firstRun &&
!strings.HasPrefix(r.URL.Path, "/install.") &&
r.URL.Path != "/favicon.png" {
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable
return
}
// enforce https?
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil {
// yes, and we want host from host:port
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
// no port in host
host = r.Host
}
// construct new URL to redirect to
newURL := url.URL{
Scheme: "https",
Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)),
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
}
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
handler(w, r)
}
}
type postInstallHandlerStruct struct {
handler http.Handler
}
func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
postInstall(p.handler.ServeHTTP)(w, r)
}
func postInstallHandler(handler http.Handler) http.Handler {
return &postInstallHandlerStruct{handler}
}

View File

@@ -210,9 +210,9 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
}
func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
config.controlLock.Unlock()
Context.controlLock.Unlock()
nUpdated, err := refreshFilters()
config.controlLock.Lock()
Context.controlLock.Lock()
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return

View File

@@ -13,6 +13,10 @@ import (
"runtime"
"strconv"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/golibs/log"
)
@@ -22,13 +26,21 @@ type firstRunData struct {
Interfaces map[string]interface{} `json:"interfaces"`
}
type netInterfaceJSON struct {
Name string `json:"name"`
MTU int `json:"mtu"`
HardwareAddr string `json:"hardware_address"`
Addresses []string `json:"ip_addresses"`
Flags string `json:"flags"`
}
// Get initial installation settings
func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
data := firstRunData{}
data.WebPort = 80
data.DNSPort = 53
ifaces, err := getValidNetInterfacesForWeb()
ifaces, err := util.GetValidNetInterfacesForWeb()
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return
@@ -36,7 +48,14 @@ func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
data.Interfaces = make(map[string]interface{})
for _, iface := range ifaces {
data.Interfaces[iface.Name] = iface
ifaceJSON := netInterfaceJSON{
Name: iface.Name,
MTU: iface.MTU,
HardwareAddr: iface.HardwareAddr,
Addresses: iface.Addresses,
Flags: iface.Flags,
}
data.Interfaces[iface.Name] = ifaceJSON
}
w.Header().Set("Content-Type", "application/json")
@@ -53,17 +72,24 @@ type checkConfigReqEnt struct {
Autofix bool `json:"autofix"`
}
type checkConfigReq struct {
Web checkConfigReqEnt `json:"web"`
DNS checkConfigReqEnt `json:"dns"`
Web checkConfigReqEnt `json:"web"`
DNS checkConfigReqEnt `json:"dns"`
SetStaticIP bool `json:"set_static_ip"`
}
type checkConfigRespEnt struct {
Status string `json:"status"`
CanAutofix bool `json:"can_autofix"`
}
type staticIPJSON struct {
Static string `json:"static"`
IP string `json:"ip"`
Error string `json:"error"`
}
type checkConfigResp struct {
Web checkConfigRespEnt `json:"web"`
DNS checkConfigRespEnt `json:"dns"`
Web checkConfigRespEnt `json:"web"`
DNS checkConfigRespEnt `json:"dns"`
StaticIP staticIPJSON `json:"static_ip"`
}
// Check if ports are available, respond with results
@@ -77,16 +103,16 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
}
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort {
err = checkPortAvailable(reqData.Web.IP, reqData.Web.Port)
err = util.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port)
if err != nil {
respData.Web.Status = fmt.Sprintf("%v", err)
}
}
if reqData.DNS.Port != 0 {
err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
if errorIsAddrInUse(err) {
if util.ErrorIsAddrInUse(err) {
canAutofix := checkDNSStubListener()
if canAutofix && reqData.DNS.Autofix {
@@ -95,7 +121,7 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
log.Error("Couldn't disable DNSStubListener: %s", err)
}
err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
canAutofix = false
}
@@ -103,11 +129,13 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
}
if err == nil {
err = checkPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
err = util.CheckPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
}
if err != nil {
respData.DNS.Status = fmt.Sprintf("%v", err)
} else {
respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP)
}
}
@@ -119,6 +147,46 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
}
}
// handleStaticIP - handles static IP request
// It either checks if we have a static IP
// Or if set=true, it tries to set it
func handleStaticIP(ip string, set bool) staticIPJSON {
resp := staticIPJSON{}
interfaceName := util.GetInterfaceByIP(ip)
resp.Static = "no"
if len(interfaceName) == 0 {
resp.Static = "error"
resp.Error = fmt.Sprintf("Couldn't find network interface by IP %s", ip)
return resp
}
if set {
// Try to set static IP for the specified interface
err := dhcpd.SetStaticIP(interfaceName)
if err != nil {
resp.Static = "error"
resp.Error = err.Error()
return resp
}
}
// Fallthrough here even if we set static IP
// Check if we have a static IP and return the details
isStaticIP, err := dhcpd.HasStaticIP(interfaceName)
if err != nil {
resp.Static = "error"
resp.Error = err.Error()
} else {
if isStaticIP {
resp.Static = "yes"
}
resp.IP = util.GetSubnet(interfaceName)
}
return resp
}
// Check if DNSStubListener is active
func checkDNSStubListener() bool {
if runtime.GOOS != "linux" {
@@ -129,7 +197,7 @@ func checkDNSStubListener() bool {
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
_, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Error("command %s has failed: %v code:%d",
log.Info("command %s has failed: %v code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
return false
}
@@ -138,7 +206,7 @@ func checkDNSStubListener() bool {
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
_, err = cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Error("command %s has failed: %v code:%d",
log.Info("command %s has failed: %v code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
return false
}
@@ -228,7 +296,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
// validate that hosts and ports are bindable
if restartHTTP {
err = checkPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s",
net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err)
@@ -236,13 +304,13 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
}
}
err = checkPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
err = util.CheckPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
err = checkPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
err = util.CheckPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
@@ -251,7 +319,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
var curConfig configuration
copyInstallSettings(&curConfig, &config)
config.firstRun = false
Context.firstRun = false
config.BindHost = newSettings.Web.IP
config.BindPort = newSettings.Web.Port
config.DNS.BindHost = newSettings.DNS.IP
@@ -266,7 +334,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
}
}
if err != nil || err2 != nil {
config.firstRun = true
Context.firstRun = true
copyInstallSettings(&config, &curConfig)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't initialize DNS server: %s", err)
@@ -278,11 +346,11 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
u := User{}
u.Name = newSettings.Username
config.auth.UserAdd(&u, newSettings.Password)
Context.auth.UserAdd(&u, newSettings.Password)
err = config.write()
if err != nil {
config.firstRun = true
Context.firstRun = true
copyInstallSettings(&config, &curConfig)
httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err)
return

View File

@@ -20,6 +20,8 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
@@ -84,7 +86,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
alreadyRunning = true
}
if !alreadyRunning {
err = checkPortAvailable(config.BindHost, data.PortHTTPS)
err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS)
if err != nil {
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
return
@@ -114,7 +116,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
alreadyRunning = true
}
if !alreadyRunning {
err = checkPortAvailable(config.BindHost, data.PortHTTPS)
err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS)
if err != nil {
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
return

View File

@@ -17,6 +17,8 @@ import (
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log"
)
@@ -64,7 +66,7 @@ type getVersionJSONRequest struct {
// Get the latest available version from the Internet
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
if config.disableUpdate {
if Context.disableUpdate {
return
}
@@ -77,10 +79,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
now := time.Now()
if !req.RecheckNow {
config.controlLock.Lock()
Context.controlLock.Lock()
cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0
data := config.versionCheckJSON
config.controlLock.Unlock()
Context.controlLock.Unlock()
if cached {
log.Tracef("Returning cached data")
@@ -93,7 +95,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
var resp *http.Response
for i := 0; i != 3; i++ {
log.Tracef("Downloading data from %s", versionCheckURL)
resp, err = config.client.Get(versionCheckURL)
resp, err = Context.client.Get(versionCheckURL)
if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") {
// This case may happen while we're restarting DNS server
// https://github.com/AdguardTeam/AdGuardHome/issues/934
@@ -116,10 +118,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
return
}
config.controlLock.Lock()
Context.controlLock.Lock()
config.versionCheckLastTime = now
config.versionCheckJSON = body
config.controlLock.Unlock()
Context.controlLock.Unlock()
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(getVersionResp(body))
@@ -158,7 +160,7 @@ type updateInfo struct {
func getUpdateInfo(jsonData []byte) (*updateInfo, error) {
var u updateInfo
workDir := config.ourWorkingDir
workDir := Context.workDir
versionJSON := make(map[string]interface{})
err := json.Unmarshal(jsonData, &versionJSON)
@@ -196,7 +198,7 @@ func getUpdateInfo(jsonData []byte) (*updateInfo, error) {
binName = "AdGuardHome.exe"
}
u.curBinName = filepath.Join(workDir, binName)
if !fileExists(u.curBinName) {
if !util.FileExists(u.curBinName) {
return nil, fmt.Errorf("Executable file %s doesn't exist", u.curBinName)
}
u.bkpBinName = filepath.Join(u.backupDir, binName)
@@ -365,7 +367,7 @@ func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly,
// Download package file and save it to disk
func getPackageFile(u *updateInfo) error {
resp, err := config.client.Get(u.pkgURL)
resp, err := Context.client.Get(u.pkgURL)
if err != nil {
return fmt.Errorf("HTTP request failed: %s", err)
}
@@ -436,17 +438,17 @@ func doUpdate(u *updateInfo) error {
}
// ./README.md -> backup/README.md
err = copySupportingFiles(files, config.ourWorkingDir, u.backupDir, true, true)
err = copySupportingFiles(files, Context.workDir, u.backupDir, true, true)
if err != nil {
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s",
config.ourWorkingDir, u.backupDir, err)
Context.workDir, u.backupDir, err)
}
// update/[AdGuardHome/]README.md -> ./README.md
err = copySupportingFiles(files, u.updateDir, config.ourWorkingDir, false, true)
err = copySupportingFiles(files, u.updateDir, Context.workDir, false, true)
if err != nil {
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s",
u.updateDir, config.ourWorkingDir, err)
u.updateDir, Context.workDir, err)
}
log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName)
@@ -478,8 +480,7 @@ func finishUpdate(u *updateInfo) {
cleanupAlways()
if runtime.GOOS == "windows" {
if config.runningAsService {
if Context.runningAsService {
// Note:
// we can't restart the service via "kardianos/service" package - it kills the process first
// we can't start a new instance - Windows doesn't allow it

View File

@@ -8,9 +8,8 @@ import (
)
func TestDoUpdate(t *testing.T) {
config.DNS.Port = 0
config.ourWorkingDir = "..." // set absolute path
Context.workDir = "..." // set absolute path
newver := "v0.96"
data := `{
@@ -35,15 +34,15 @@ func TestDoUpdate(t *testing.T) {
u := updateInfo{
pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/" + newver + "/AdGuardHome_linux_amd64.tar.gz",
pkgName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz",
pkgName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz",
newVer: newver,
updateDir: config.ourWorkingDir + "/agh-update-" + newver,
backupDir: config.ourWorkingDir + "/agh-backup",
configName: config.ourWorkingDir + "/AdGuardHome.yaml",
updateConfigName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome.yaml",
curBinName: config.ourWorkingDir + "/AdGuardHome",
bkpBinName: config.ourWorkingDir + "/agh-backup/AdGuardHome",
newBinName: config.ourWorkingDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome",
updateDir: Context.workDir + "/agh-update-" + newver,
backupDir: Context.workDir + "/agh-backup",
configName: Context.workDir + "/AdGuardHome.yaml",
updateConfigName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome.yaml",
curBinName: Context.workDir + "/AdGuardHome",
bkpBinName: Context.workDir + "/agh-backup/AdGuardHome",
newBinName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/AdGuardHome",
}
if uu.pkgURL != u.pkgURL ||

View File

@@ -25,7 +25,7 @@ func onConfigModified() {
// Please note that we must do it even if we don't start it
// so that we had access to the query log and the stats
func initDNSServer() error {
baseDir := config.getDataDir()
baseDir := Context.getDataDir()
err := os.MkdirAll(baseDir, 0755)
if err != nil {
@@ -71,8 +71,8 @@ func initDNSServer() error {
}
sessFilename := filepath.Join(baseDir, "sessions.db")
config.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
if config.auth == nil {
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
if Context.auth == nil {
closeDNSServer()
return fmt.Errorf("Couldn't initialize Auth module")
}
@@ -294,9 +294,9 @@ func closeDNSServer() {
Context.queryLog = nil
}
if config.auth != nil {
config.auth.Close()
config.auth = nil
if Context.auth != nil {
Context.auth.Close()
Context.auth = nil
}
log.Debug("Closed all DNS modules")

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
)
@@ -401,7 +402,7 @@ func parseFilterContents(contents []byte) (int, string) {
// Count lines in the filter
for len(data) != 0 {
line := SplitNext(&data, '\n')
line := util.SplitNext(&data, '\n')
if len(line) == 0 {
continue
}
@@ -424,7 +425,7 @@ func parseFilterContents(contents []byte) (int, string) {
func (filter *filter) update() (bool, error) {
log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
resp, err := config.client.Get(filter.URL)
resp, err := Context.client.Get(filter.URL)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
@@ -538,7 +539,7 @@ func (filter *filter) unload() {
// Path to the filter contents
func (filter *filter) Path() string {
return filepath.Join(config.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
}
// LastTimeUpdated returns the time when the filter was last time updated

View File

@@ -10,7 +10,12 @@ import (
)
func TestFilters(t *testing.T) {
config.client = &http.Client{
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
Context = homeContext{}
Context.workDir = dir
Context.client = &http.Client{
Timeout: time.Minute * 5,
}
@@ -33,5 +38,5 @@ func TestFilters(t *testing.T) {
assert.True(t, err == nil)
f.unload()
os.Remove(f.Path())
_ = os.Remove(f.Path())
}

View File

@@ -1,380 +0,0 @@
package home
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
// ----------------------------------
// helper functions for HTTP handlers
// ----------------------------------
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
log.Debug("%s %v", r.Method, r.URL)
if r.Method != method {
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
return
}
if method == "POST" || method == "PUT" || method == "DELETE" {
config.controlLock.Lock()
defer config.controlLock.Unlock()
}
handler(w, r)
}
}
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure("POST", handler)
}
func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure("GET", handler)
}
// Bridge between http.Handler object and Go function
type httpHandler struct {
handler func(http.ResponseWriter, *http.Request)
}
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.handler(w, r)
}
func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler {
h := httpHandler{}
h.handler = ensure(method, handler)
return &h
}
// -------------------
// first run / install
// -------------------
func detectFirstRun() bool {
configfile := config.ourConfigFilename
if !filepath.IsAbs(configfile) {
configfile = filepath.Join(config.ourWorkingDir, config.ourConfigFilename)
}
_, err := os.Stat(configfile)
if !os.IsNotExist(err) {
// do nothing, file exists
return false
}
return true
}
// preInstall lets the handler run only if firstRun is true, no redirects
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if !config.firstRun {
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
handler(w, r)
}
}
// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary
type preInstallHandlerStruct struct {
handler http.Handler
}
func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
preInstall(p.handler.ServeHTTP)(w, r)
}
// preInstallHandler returns http.Handler interface for preInstall wrapper
func preInstallHandler(handler http.Handler) http.Handler {
return &preInstallHandlerStruct{handler}
}
// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise
// it also enforces HTTPS if it is enabled and configured
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if config.firstRun &&
!(strings.HasPrefix(r.URL.Path, "/install.") ||
strings.HasPrefix(r.URL.Path, "/__locales/") ||
r.URL.Path == "/favicon.png") {
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable
return
}
// enforce https?
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil {
// yes, and we want host from host:port
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
// no port in host
host = r.Host
}
// construct new URL to redirect to
newURL := url.URL{
Scheme: "https",
Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)),
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
}
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
handler(w, r)
}
}
type postInstallHandlerStruct struct {
handler http.Handler
}
func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
postInstall(p.handler.ServeHTTP)(w, r)
}
func postInstallHandler(handler http.Handler) http.Handler {
return &postInstallHandlerStruct{handler}
}
// ------------------
// network interfaces
// ------------------
type netInterface struct {
Name string `json:"name"`
MTU int `json:"mtu"`
HardwareAddr string `json:"hardware_address"`
Addresses []string `json:"ip_addresses"`
Flags string `json:"flags"`
}
// getValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP
// invalid interface is a ppp interface or the one that doesn't allow broadcasts
func getValidNetInterfaces() ([]net.Interface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err)
}
netIfaces := []net.Interface{}
for i := range ifaces {
if ifaces[i].Flags&net.FlagPointToPoint != 0 {
// this interface is ppp, we're not interested in this one
continue
}
iface := ifaces[i]
netIfaces = append(netIfaces, iface)
}
return netIfaces, nil
}
// getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only
// we do not return link-local addresses here
func getValidNetInterfacesForWeb() ([]netInterface, error) {
ifaces, err := getValidNetInterfaces()
if err != nil {
return nil, errorx.Decorate(err, "Couldn't get interfaces")
}
if len(ifaces) == 0 {
return nil, errors.New("couldn't find any legible interface")
}
var netInterfaces []netInterface
for _, iface := range ifaces {
addrs, e := iface.Addrs()
if e != nil {
return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name)
}
netIface := netInterface{
Name: iface.Name,
MTU: iface.MTU,
HardwareAddr: iface.HardwareAddr.String(),
}
if iface.Flags != 0 {
netIface.Flags = iface.Flags.String()
}
// we don't want link-local addresses in json, so skip them
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
// not an IPNet, should not happen
return nil, fmt.Errorf("SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
}
// ignore link-local
if ipnet.IP.IsLinkLocalUnicast() {
continue
}
netIface.Addresses = append(netIface.Addresses, ipnet.IP.String())
}
if len(netIface.Addresses) != 0 {
netInterfaces = append(netInterfaces, netIface)
}
}
return netInterfaces, nil
}
// checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily
func checkPortAvailable(host string, port int) error {
ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port)))
if err != nil {
return err
}
_ = ln.Close()
// It seems that net.Listener.Close() doesn't close file descriptors right away.
// We wait for some time and hope that this fd will be closed.
time.Sleep(100 * time.Millisecond)
return nil
}
func checkPacketPortAvailable(host string, port int) error {
ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port)))
if err != nil {
return err
}
_ = ln.Close()
// It seems that net.Listener.Close() doesn't close file descriptors right away.
// We wait for some time and hope that this fd will be closed.
time.Sleep(100 * time.Millisecond)
return err
}
// Connect to a remote server resolving hostname using our own DNS server
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
log.Tracef("network:%v addr:%v", network, addr)
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
dialer := &net.Dialer{
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
con, err := dialer.DialContext(ctx, network, addr)
return con, err
}
addrs, e := Context.dnsServer.Resolve(host)
log.Debug("dnsServer.Resolve: %s: %v", host, addrs)
if e != nil {
return nil, e
}
if len(addrs) == 0 {
return nil, fmt.Errorf("couldn't lookup host: %s", host)
}
var dialErrs []error
for _, a := range addrs {
addr = net.JoinHostPort(a.String(), port)
con, err := dialer.DialContext(ctx, network, addr)
if err != nil {
dialErrs = append(dialErrs, err)
continue
}
return con, err
}
return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
}
// check if error is "address already in use"
func errorIsAddrInUse(err error) bool {
errOpError, ok := err.(*net.OpError)
if !ok {
return false
}
errSyscallError, ok := errOpError.Err.(*os.SyscallError)
if !ok {
return false
}
errErrno, ok := errSyscallError.Err.(syscall.Errno)
if !ok {
return false
}
if runtime.GOOS == "windows" {
const WSAEADDRINUSE = 10048
return errErrno == WSAEADDRINUSE
}
return errErrno == syscall.EADDRINUSE
}
// ---------------------
// general helpers
// ---------------------
// fileExists returns TRUE if file exists
func fileExists(fn string) bool {
_, err := os.Stat(fn)
if err != nil {
return false
}
return true
}
// runCommand runs shell command
func runCommand(command string, arguments ...string) (int, string, error) {
cmd := exec.Command(command, arguments...)
out, err := cmd.Output()
if err != nil {
return 1, "", fmt.Errorf("exec.Command(%s) failed: %s", command, err)
}
return cmd.ProcessState.ExitCode(), string(out), nil
}
// ---------------------
// debug logging helpers
// ---------------------
func _Func() string {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
return path.Base(f.Name())
}
// SplitNext - split string by a byte and return the first chunk
// Whitespace is trimmed
func SplitNext(str *string, splitBy byte) string {
i := strings.IndexByte(*str, splitBy)
s := ""
if i != -1 {
s = (*str)[0:i]
*str = (*str)[i+1:]
} else {
s = *str
*str = ""
}
return strings.TrimSpace(s)
}

View File

@@ -1,33 +0,0 @@
package home
import (
"testing"
"github.com/AdguardTeam/golibs/log"
"github.com/stretchr/testify/assert"
)
func TestGetValidNetInterfacesForWeb(t *testing.T) {
ifaces, err := getValidNetInterfacesForWeb()
if err != nil {
t.Fatalf("Cannot get net interfaces: %s", err)
}
if len(ifaces) == 0 {
t.Fatalf("No net interfaces found")
}
for _, iface := range ifaces {
if len(iface.Addresses) == 0 {
t.Fatalf("No addresses found for %s", iface.Name)
}
log.Printf("%v", iface)
}
}
func TestSplitNext(t *testing.T) {
s := " a,b , c "
assert.True(t, SplitNext(&s, ',') == "a")
assert.True(t, SplitNext(&s, ',') == "b")
assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0)
}

View File

@@ -20,6 +20,10 @@ import (
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/joomcode/errorx"
"github.com/AdguardTeam/AdGuardHome/isdelve"
"github.com/AdguardTeam/AdGuardHome/dhcpd"
@@ -49,6 +53,9 @@ const versionCheckPeriod = time.Hour * 8
// Global context
type homeContext struct {
// Modules
// --
clients clientsContainer // per-client-settings module
stats stats.Stats // statistics module
queryLog querylog.QueryLog // query log module
@@ -57,8 +64,29 @@ type homeContext struct {
whois *Whois // WHOIS module
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
dhcpServer *dhcpd.Server // DHCP module
auth *Auth // HTTP authentication module
httpServer *http.Server // HTTP module
httpsServer HTTPSServer // HTTPS module
// Runtime properties
// --
configFilename string // Config filename (can be overridden via the command line arguments)
workDir string // Location of our directory, used to protect against CWD being somewhere else
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
pidFileName string // PID file name. Empty if no PID file was created.
disableUpdate bool // If set, don't check for updates
controlLock sync.Mutex
transport *http.Transport
client *http.Client
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
}
// getDataDir returns path to the directory where we store databases and filters
func (c *homeContext) getDataDir() string {
return filepath.Join(c.workDir, dataDir)
}
// Context - a global context object
@@ -81,17 +109,38 @@ func Main(version string, channel string, armVer string) {
return
}
Context.appSignalChannel = make(chan os.Signal)
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-Context.appSignalChannel
cleanup()
cleanupAlways()
os.Exit(0)
}()
// run the protection
run(args)
}
// run initializes configuration and runs the AdGuard Home
// run is a blocking method and it won't exit until the service is stopped!
// run is a blocking method!
// nolint
func run(args options) {
// config file path can be overridden by command-line arguments:
if args.configFilename != "" {
config.ourConfigFilename = args.configFilename
Context.configFilename = args.configFilename
} else {
// Default config file name
Context.configFilename = "AdGuardHome.yaml"
}
// Init some of the Context fields right away
Context.transport = &http.Transport{
DialContext: customDialContext,
}
Context.client = &http.Client{
Timeout: time.Minute * 5,
Transport: Context.transport,
}
// configure working dir and config path
@@ -106,31 +155,22 @@ func run(args options) {
msg = msg + " v" + ARMVersion
}
log.Printf(msg, versionString, updateChannel, runtime.GOOS, runtime.GOARCH, ARMVersion)
log.Debug("Current working directory is %s", config.ourWorkingDir)
log.Debug("Current working directory is %s", Context.workDir)
if args.runningAsService {
log.Info("AdGuard Home is running as a service")
}
config.runningAsService = args.runningAsService
config.disableUpdate = args.disableUpdate
Context.runningAsService = args.runningAsService
Context.disableUpdate = args.disableUpdate
config.firstRun = detectFirstRun()
if config.firstRun {
Context.firstRun = detectFirstRun()
if Context.firstRun {
requireAdminRights()
}
config.appSignalChannel = make(chan os.Signal)
signal.Notify(config.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-config.appSignalChannel
cleanup()
cleanupAlways()
os.Exit(0)
}()
initConfig()
initServices()
if !config.firstRun {
if !Context.firstRun {
// Do the upgrade if necessary
err := upgradeConfig()
if err != nil {
@@ -148,7 +188,7 @@ func run(args options) {
}
}
config.DHCP.WorkDir = config.ourWorkingDir
config.DHCP.WorkDir = Context.workDir
config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified
Context.dhcpServer = dhcpd.Create(config.DHCP)
@@ -157,7 +197,7 @@ func run(args options) {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
config.RlimitNoFile != 0 {
setRlimit(config.RlimitNoFile)
util.SetRlimit(config.RlimitNoFile)
}
// override bind host/port from the console
@@ -168,7 +208,7 @@ func run(args options) {
config.BindPort = args.bindPort
}
if !config.firstRun {
if !Context.firstRun {
// Save the updated config
err := config.write()
if err != nil {
@@ -193,7 +233,7 @@ func run(args options) {
}
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
config.pidFileName = args.pidFile
Context.pidFileName = args.pidFile
}
// Initialize and run the admin Web interface
@@ -204,7 +244,7 @@ func run(args options) {
registerControlHandlers()
// add handlers for /install paths, we only need them when we're not configured yet
if config.firstRun {
if Context.firstRun {
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
registerInstallHandlers()
@@ -291,7 +331,7 @@ func httpServerLoop() {
// Check if the current user has root (administrator) rights
// and if not, ask and try to run as root
func requireAdminRights() {
admin, _ := haveAdminRights()
admin, _ := util.HaveAdminRights()
if //noinspection ALL
admin || isdelve.Enabled {
return
@@ -331,7 +371,7 @@ func writePIDFile(fn string) bool {
return true
}
// initWorkingDir initializes the ourWorkingDir
// initWorkingDir initializes the workDir
// if no command-line arguments specified, we use the directory where our binary file is located
func initWorkingDir(args options) {
execPath, err := os.Executable()
@@ -341,9 +381,9 @@ func initWorkingDir(args options) {
if args.workDir != "" {
// If there is a custom config file, use it's directory as our working dir
config.ourWorkingDir = args.workDir
Context.workDir = args.workDir
} else {
config.ourWorkingDir = filepath.Dir(execPath)
Context.workDir = filepath.Dir(execPath)
}
}
@@ -376,12 +416,12 @@ func configureLogger(args options) {
if ls.LogFile == configSyslog {
// Use syslog where it is possible and eventlog on Windows
err := configureSyslog()
err := util.ConfigureSyslog(serviceName)
if err != nil {
log.Fatalf("cannot initialize syslog: %s", err)
}
} else {
logFilePath := filepath.Join(config.ourWorkingDir, ls.LogFile)
logFilePath := filepath.Join(Context.workDir, ls.LogFile)
if filepath.IsAbs(ls.LogFile) {
logFilePath = ls.LogFile
}
@@ -420,8 +460,8 @@ func stopHTTPServer() {
// This function is called before application exits
func cleanupAlways() {
if len(config.pidFileName) != 0 {
_ = os.Remove(config.pidFileName)
if len(Context.pidFileName) != 0 {
_ = os.Remove(Context.pidFileName)
}
log.Info("Stopped")
}
@@ -544,7 +584,7 @@ func printHTTPAddresses(proto string) {
}
} else if config.BindHost == "0.0.0.0" {
log.Println("AdGuard Home is available on the following addresses:")
ifaces, err := getValidNetInterfacesForWeb()
ifaces, err := util.GetValidNetInterfacesForWeb()
if err != nil {
// That's weird, but we'll ignore it
address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
@@ -561,3 +601,60 @@ func printHTTPAddresses(proto string) {
log.Printf("Go to %s://%s", proto, address)
}
}
// -------------------
// first run / install
// -------------------
func detectFirstRun() bool {
configfile := Context.configFilename
if !filepath.IsAbs(configfile) {
configfile = filepath.Join(Context.workDir, Context.configFilename)
}
_, err := os.Stat(configfile)
if !os.IsNotExist(err) {
// do nothing, file exists
return false
}
return true
}
// Connect to a remote server resolving hostname using our own DNS server
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
log.Tracef("network:%v addr:%v", network, addr)
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
dialer := &net.Dialer{
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
con, err := dialer.DialContext(ctx, network, addr)
return con, err
}
addrs, e := Context.dnsServer.Resolve(host)
log.Debug("dnsServer.Resolve: %s: %v", host, addrs)
if e != nil {
return nil, e
}
if len(addrs) == 0 {
return nil, fmt.Errorf("couldn't lookup host: %s", host)
}
var dialErrs []error
for _, a := range addrs {
addr = net.JoinHostPort(a.String(), port)
con, err := dialer.DialContext(ctx, network, addr)
if err != nil {
dialErrs = append(dialErrs, err)
continue
}
return con, err
}
return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
}

View File

@@ -107,6 +107,9 @@ schema_version: 5
// . Wait until the filters are downloaded
// . Stop and cleanup
func TestHome(t *testing.T) {
// Reinit context
Context = homeContext{}
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
fn := filepath.Join(dir, "AdGuardHome.yaml")
@@ -123,12 +126,12 @@ func TestHome(t *testing.T) {
var err error
var resp *http.Response
h := http.Client{}
for i := 0; i != 5; i++ {
for i := 0; i != 50; i++ {
resp, err = h.Get("http://127.0.0.1:3000/")
if err == nil && resp.StatusCode != 404 {
break
}
time.Sleep(1 * time.Second)
time.Sleep(100 * time.Millisecond)
}
assert.Truef(t, err == nil, "%s", err)
assert.Equal(t, 200, resp.StatusCode)
@@ -140,7 +143,7 @@ func TestHome(t *testing.T) {
// test DNS over UDP
r := upstream.NewResolver("127.0.0.1:5354", 3*time.Second)
addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com")
assert.Truef(t, err == nil, "%s", err)
assert.Nil(t, err)
haveIP := len(addrs) != 0
assert.True(t, haveIP)

View File

@@ -1,27 +0,0 @@
// +build freebsd
package home
import (
"os"
"syscall"
"github.com/AdguardTeam/golibs/log"
)
// Set user-specified limit of how many fd's we can use
// https://github.com/AdguardTeam/AdGuardHome/issues/659
func setRlimit(val uint) {
var rlim syscall.Rlimit
rlim.Max = int64(val)
rlim.Cur = int64(val)
err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlim)
if err != nil {
log.Error("Setrlimit() failed: %v", err)
}
}
// Check if the current user has root (administrator) rights
func haveAdminRights() (bool, error) {
return os.Getuid() == 0, nil
}

View File

@@ -1,27 +0,0 @@
// +build aix darwin dragonfly linux netbsd openbsd solaris
package home
import (
"os"
"syscall"
"github.com/AdguardTeam/golibs/log"
)
// Set user-specified limit of how many fd's we can use
// https://github.com/AdguardTeam/AdGuardHome/issues/659
func setRlimit(val uint) {
var rlim syscall.Rlimit
rlim.Max = uint64(val)
rlim.Cur = uint64(val)
err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlim)
if err != nil {
log.Error("Setrlimit() failed: %v", err)
}
}
// Check if the current user has root (administrator) rights
func haveAdminRights() (bool, error) {
return os.Getuid() == 0, nil
}

View File

@@ -1,28 +0,0 @@
package home
import "golang.org/x/sys/windows"
// Set user-specified limit of how many fd's we can use
func setRlimit(val uint) {
}
func haveAdminRights() (bool, error) {
var token windows.Token
h, _ := windows.GetCurrentProcess()
err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &token)
if err != nil {
return false, err
}
info := make([]byte, 4)
var returnedLen uint32
err = windows.GetTokenInformation(token, windows.TokenElevation, &info[0], uint32(len(info)), &returnedLen)
token.Close()
if err != nil {
return false, err
}
if info[0] == 0 {
return false, nil
}
return true, nil
}

View File

@@ -7,6 +7,7 @@ import (
"strings"
"syscall"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log"
"github.com/kardianos/service"
)
@@ -34,10 +35,10 @@ func (p *program) Start(s service.Service) error {
// Stop stops the program
func (p *program) Stop(s service.Service) error {
// Stop should not block. Return with a few seconds.
if config.appSignalChannel == nil {
if Context.appSignalChannel == nil {
os.Exit(0)
}
config.appSignalChannel <- syscall.SIGINT
Context.appSignalChannel <- syscall.SIGINT
return nil
}
@@ -229,7 +230,7 @@ func configureService(c *service.Config) {
// returns command code or error if any
func runInitdCommand(action string) (int, error) {
confPath := "/etc/init.d/" + serviceName
code, _, err := runCommand("sh", "-c", confPath+" "+action)
code, _, err := util.RunCommand("sh", "-c", confPath+" "+action)
return code, err
}

View File

@@ -1,18 +0,0 @@
// +build !windows,!nacl,!plan9
package home
import (
"log"
"log/syslog"
)
// configureSyslog reroutes standard logger output to syslog
func configureSyslog() error {
w, err := syslog.New(syslog.LOG_NOTICE|syslog.LOG_USER, serviceName)
if err != nil {
return err
}
log.SetOutput(w)
return nil
}

View File

@@ -1,39 +0,0 @@
package home
import (
"log"
"strings"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc/eventlog"
)
type eventLogWriter struct {
el *eventlog.Log
}
// Write sends a log message to the Event Log.
func (w *eventLogWriter) Write(b []byte) (int, error) {
return len(b), w.el.Info(1, string(b))
}
func configureSyslog() error {
// Note that the eventlog src is the same as the service name
// Otherwise, we will get "the description for event id cannot be found" warning in every log record
// Continue if we receive "registry key already exists" or if we get
// ERROR_ACCESS_DENIED so that we can log without administrative permissions
// for pre-existing eventlog sources.
if err := eventlog.InstallAsEventCreate(serviceName, eventlog.Info|eventlog.Warning|eventlog.Error); err != nil {
if !strings.Contains(err.Error(), "registry key already exists") && err != windows.ERROR_ACCESS_DENIED {
return err
}
}
el, err := eventlog.Open(serviceName)
if err != nil {
return err
}
log.SetOutput(&eventLogWriter{el: el})
return nil
}

View File

@@ -5,6 +5,8 @@ import (
"os"
"path/filepath"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/crypto/bcrypt"
@@ -114,9 +116,9 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err
// The first schema upgrade:
// No more "dnsfilter.txt", filters are now kept in data/filters/
func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
log.Printf("%s(): called", util.FuncName())
dnsFilterPath := filepath.Join(config.ourWorkingDir, "dnsfilter.txt")
dnsFilterPath := filepath.Join(Context.workDir, "dnsfilter.txt")
if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) {
log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath)
err = os.Remove(dnsFilterPath)
@@ -135,9 +137,9 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
// coredns is now dns in config
// delete 'Corefile', since we don't use that anymore
func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
log.Printf("%s(): called", util.FuncName())
coreFilePath := filepath.Join(config.ourWorkingDir, "Corefile")
coreFilePath := filepath.Join(Context.workDir, "Corefile")
if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) {
log.Printf("Deleting %s as we don't need it anymore", coreFilePath)
err = os.Remove(coreFilePath)
@@ -159,7 +161,7 @@ func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
// Third schema upgrade:
// Bootstrap DNS becomes an array
func upgradeSchema2to3(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
log.Printf("%s(): called", util.FuncName())
// Let's read dns configuration from diskConfig
dnsConfig, ok := (*diskConfig)["dns"]
@@ -196,7 +198,7 @@ func upgradeSchema2to3(diskConfig *map[string]interface{}) error {
// Add use_global_blocked_services=true setting for existing "clients" array
func upgradeSchema3to4(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
log.Printf("%s(): called", util.FuncName())
(*diskConfig)["schema_version"] = 4
@@ -233,7 +235,7 @@ func upgradeSchema3to4(diskConfig *map[string]interface{}) error {
// password: "..."
// ...
func upgradeSchema4to5(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
log.Printf("%s(): called", util.FuncName())
(*diskConfig)["schema_version"] = 5
@@ -288,7 +290,7 @@ func upgradeSchema4to5(diskConfig *map[string]interface{}) error {
// - 127.0.0.1
// - ...
func upgradeSchema5to6(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
log.Printf("%s(): called", util.FuncName())
(*diskConfig)["schema_version"] = 6

View File

@@ -8,6 +8,8 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
)
@@ -61,7 +63,7 @@ func whoisParse(data string) map[string]string {
descr := ""
netname := ""
for len(data) != 0 {
ln := SplitNext(&data, '\n')
ln := util.SplitNext(&data, '\n')
if len(ln) == 0 || ln[0] == '#' || ln[0] == '%' {
continue
}