Pull request: * all: move internal Go packages to internal/

Merge in DNS/adguard-home from 2234-move-to-internal to master

Squashed commit of the following:

commit d26a288cabeac86f9483fab307677b1027c78524
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Oct 30 12:44:18 2020 +0300

    * all: move internal Go packages to internal/

    Closes #2234.
This commit is contained in:
Ainar Garipov
2020-10-30 13:32:02 +03:00
parent df3fa595a2
commit ae8de95d89
125 changed files with 85 additions and 85 deletions

536
internal/home/auth.go Normal file
View File

@@ -0,0 +1,536 @@
package home
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strings"
"sync"
"time"
"github.com/AdguardTeam/golibs/log"
"go.etcd.io/bbolt"
"golang.org/x/crypto/bcrypt"
)
const cookieTTL = 365 * 24 // in hours
const sessionCookieName = "agh_session"
type session struct {
userName string
expire uint32 // expiration time (in seconds)
}
/*
expire byte[4]
name_len byte[2]
name byte[]
*/
func (s *session) serialize() []byte {
var data []byte
data = make([]byte, 4+2+len(s.userName))
binary.BigEndian.PutUint32(data[0:4], s.expire)
binary.BigEndian.PutUint16(data[4:6], uint16(len(s.userName)))
copy(data[6:], []byte(s.userName))
return data
}
func (s *session) deserialize(data []byte) bool {
if len(data) < 4+2 {
return false
}
s.expire = binary.BigEndian.Uint32(data[0:4])
nameLen := binary.BigEndian.Uint16(data[4:6])
data = data[6:]
if len(data) < int(nameLen) {
return false
}
s.userName = string(data)
return true
}
// Auth - global object
type Auth struct {
db *bbolt.DB
sessions map[string]*session // session name -> session data
lock sync.Mutex
users []User
sessionTTL uint32 // in seconds
}
// User object
type User struct {
Name string `yaml:"name"`
PasswordHash string `yaml:"password"` // bcrypt hash
}
// InitAuth - create a global object
func InitAuth(dbFilename string, users []User, sessionTTL uint32) *Auth {
log.Info("Initializing auth module: %s", dbFilename)
a := Auth{}
a.sessionTTL = sessionTTL
a.sessions = make(map[string]*session)
rand.Seed(time.Now().UTC().Unix())
var err error
a.db, err = bbolt.Open(dbFilename, 0644, nil)
if err != nil {
log.Error("Auth: open DB: %s: %s", dbFilename, err)
if err.Error() == "invalid argument" {
log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/internal/wiki/Getting-Started#limitations")
}
return nil
}
a.loadSessions()
a.users = users
log.Info("Auth: initialized. users:%d sessions:%d", len(a.users), len(a.sessions))
return &a
}
// Close - close module
func (a *Auth) Close() {
_ = a.db.Close()
}
func bucketName() []byte {
return []byte("sessions-2")
}
// load sessions from file, remove expired sessions
func (a *Auth) loadSessions() {
tx, err := a.db.Begin(true)
if err != nil {
log.Error("Auth: bbolt.Begin: %s", err)
return
}
defer func() {
_ = tx.Rollback()
}()
bkt := tx.Bucket(bucketName())
if bkt == nil {
return
}
removed := 0
if tx.Bucket([]byte("sessions")) != nil {
_ = tx.DeleteBucket([]byte("sessions"))
removed = 1
}
now := uint32(time.Now().UTC().Unix())
forEach := func(k, v []byte) error {
s := session{}
if !s.deserialize(v) || s.expire <= now {
err = bkt.Delete(k)
if err != nil {
log.Error("Auth: bbolt.Delete: %s", err)
} else {
removed++
}
return nil
}
a.sessions[hex.EncodeToString(k)] = &s
return nil
}
_ = bkt.ForEach(forEach)
if removed != 0 {
err = tx.Commit()
if err != nil {
log.Error("bolt.Commit(): %s", err)
}
}
log.Debug("Auth: loaded %d sessions from DB (removed %d expired)", len(a.sessions), removed)
}
// store session data in file
func (a *Auth) addSession(data []byte, s *session) {
name := hex.EncodeToString(data)
a.lock.Lock()
a.sessions[name] = s
a.lock.Unlock()
if a.storeSession(data, s) {
log.Debug("Auth: created session %s: expire=%d", name, s.expire)
}
}
// store session data in file
func (a *Auth) storeSession(data []byte, s *session) bool {
tx, err := a.db.Begin(true)
if err != nil {
log.Error("Auth: bbolt.Begin: %s", err)
return false
}
defer func() {
_ = tx.Rollback()
}()
bkt, err := tx.CreateBucketIfNotExists(bucketName())
if err != nil {
log.Error("Auth: bbolt.CreateBucketIfNotExists: %s", err)
return false
}
err = bkt.Put(data, s.serialize())
if err != nil {
log.Error("Auth: bbolt.Put: %s", err)
return false
}
err = tx.Commit()
if err != nil {
log.Error("Auth: bbolt.Commit: %s", err)
return false
}
return true
}
// remove session from file
func (a *Auth) removeSession(sess []byte) {
tx, err := a.db.Begin(true)
if err != nil {
log.Error("Auth: bbolt.Begin: %s", err)
return
}
defer func() {
_ = tx.Rollback()
}()
bkt := tx.Bucket(bucketName())
if bkt == nil {
log.Error("Auth: bbolt.Bucket")
return
}
err = bkt.Delete(sess)
if err != nil {
log.Error("Auth: bbolt.Put: %s", err)
return
}
err = tx.Commit()
if err != nil {
log.Error("Auth: bbolt.Commit: %s", err)
return
}
log.Debug("Auth: removed session from DB")
}
// CheckSession - check if session is valid
// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired
func (a *Auth) CheckSession(sess string) int {
now := uint32(time.Now().UTC().Unix())
update := false
a.lock.Lock()
s, ok := a.sessions[sess]
if !ok {
a.lock.Unlock()
return -1
}
if s.expire <= now {
delete(a.sessions, sess)
key, _ := hex.DecodeString(sess)
a.removeSession(key)
a.lock.Unlock()
return 1
}
newExpire := now + a.sessionTTL
if s.expire/(24*60*60) != newExpire/(24*60*60) {
// update expiration time once a day
update = true
s.expire = newExpire
}
a.lock.Unlock()
if update {
key, _ := hex.DecodeString(sess)
if a.storeSession(key, s) {
log.Debug("Auth: updated session %s: expire=%d", sess, s.expire)
}
}
return 0
}
// RemoveSession - remove session
func (a *Auth) RemoveSession(sess string) {
key, _ := hex.DecodeString(sess)
a.lock.Lock()
delete(a.sessions, sess)
a.lock.Unlock()
a.removeSession(key)
}
type loginJSON struct {
Name string `json:"name"`
Password string `json:"password"`
}
func getSession(u *User) []byte {
// the developers don't currently believe that using a
// non-cryptographic RNG for the session hash salt is
// insecure
salt := rand.Uint32() //nolint:gosec
d := []byte(fmt.Sprintf("%d%s%s", salt, u.Name, u.PasswordHash))
hash := sha256.Sum256(d)
return hash[:]
}
func (a *Auth) httpCookie(req loginJSON) string {
u := a.UserFind(req.Name, req.Password)
if len(u.Name) == 0 {
return ""
}
sess := getSession(&u)
now := time.Now().UTC()
expire := now.Add(cookieTTL * time.Hour)
expstr := expire.Format(time.RFC1123)
expstr = expstr[:len(expstr)-len("UTC")] // "UTC" -> "GMT"
expstr += "GMT"
s := session{}
s.userName = u.Name
s.expire = uint32(now.Unix()) + a.sessionTTL
a.addSession(sess, &s)
return fmt.Sprintf("%s=%s; Path=/; HttpOnly; Expires=%s",
sessionCookieName, hex.EncodeToString(sess), expstr)
}
func handleLogin(w http.ResponseWriter, r *http.Request) {
req := loginJSON{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
cookie := Context.auth.httpCookie(req)
if len(cookie) == 0 {
log.Info("Auth: invalid user name or password: name='%s'", req.Name)
time.Sleep(1 * time.Second)
http.Error(w, "invalid user name or password", http.StatusBadRequest)
return
}
w.Header().Set("Set-Cookie", cookie)
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, proxy-revalidate")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")
returnOK(w)
}
func handleLogout(w http.ResponseWriter, r *http.Request) {
cookie := r.Header.Get("Cookie")
sess := parseCookie(cookie)
Context.auth.RemoveSession(sess)
w.Header().Set("Location", "/login.html")
s := fmt.Sprintf("%s=; Path=/; HttpOnly; Expires=Thu, 01 Jan 1970 00:00:00 GMT",
sessionCookieName)
w.Header().Set("Set-Cookie", s)
w.WriteHeader(http.StatusFound)
}
// RegisterAuthHandlers - register handlers
func RegisterAuthHandlers() {
http.Handle("/control/login", postInstallHandler(ensureHandler("POST", handleLogin)))
httpRegister("GET", "/control/logout", handleLogout)
}
func parseCookie(cookie string) string {
pairs := strings.Split(cookie, ";")
for _, pair := range pairs {
pair = strings.TrimSpace(pair)
kv := strings.SplitN(pair, "=", 2)
if len(kv) != 2 {
continue
}
if kv[0] == sessionCookieName {
return kv[1]
}
}
return ""
}
// nolint(gocyclo)
func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/login.html" {
// redirect to dashboard if already authenticated
authRequired := Context.auth != nil && Context.auth.AuthRequired()
cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil {
r := Context.auth.CheckSession(cookie.Value)
if r == 0 {
w.Header().Set("Location", "/")
w.WriteHeader(http.StatusFound)
return
} else if r < 0 {
log.Debug("Auth: invalid cookie value: %s", cookie)
}
}
} else if strings.HasPrefix(r.URL.Path, "/assets/") ||
strings.HasPrefix(r.URL.Path, "/login.") {
// process as usual
// no additional auth requirements
} else if Context.auth != nil && Context.auth.AuthRequired() {
// redirect to login page if not authenticated
ok := false
cookie, err := r.Cookie(sessionCookieName)
if glProcessCookie(r) {
log.Debug("Auth: authentification was handled by GL-Inet submodule")
ok = true
} else if err == nil {
r := Context.auth.CheckSession(cookie.Value)
if r == 0 {
ok = true
} else if r < 0 {
log.Debug("Auth: invalid cookie value: %s", cookie)
}
} else {
// there's no Cookie, check Basic authentication
user, pass, ok2 := r.BasicAuth()
if ok2 {
u := Context.auth.UserFind(user, pass)
if len(u.Name) != 0 {
ok = true
} else {
log.Info("Auth: invalid Basic Authorization value")
}
}
}
if !ok {
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
if glProcessRedirect(w, r) {
log.Debug("Auth: redirected to login page by GL-Inet submodule")
} else {
w.Header().Set("Location", "/login.html")
w.WriteHeader(http.StatusFound)
}
} else {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte("Forbidden"))
}
return
}
}
handler(w, r)
}
}
type authHandler struct {
handler http.Handler
}
func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
optionalAuth(a.handler.ServeHTTP)(w, r)
}
func optionalAuthHandler(handler http.Handler) http.Handler {
return &authHandler{handler}
}
// UserAdd - add new user
func (a *Auth) UserAdd(u *User, password string) {
if len(password) == 0 {
return
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
log.Error("bcrypt.GenerateFromPassword: %s", err)
return
}
u.PasswordHash = string(hash)
a.lock.Lock()
a.users = append(a.users, *u)
a.lock.Unlock()
log.Debug("Auth: added user: %s", u.Name)
}
// UserFind - find a user
func (a *Auth) UserFind(login string, password string) User {
a.lock.Lock()
defer a.lock.Unlock()
for _, u := range a.users {
if u.Name == login &&
bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil {
return u
}
}
return User{}
}
// GetCurrentUser - get the current user
func (a *Auth) GetCurrentUser(r *http.Request) User {
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
// there's no Cookie, check Basic authentication
user, pass, ok := r.BasicAuth()
if ok {
u := Context.auth.UserFind(user, pass)
return u
}
return User{}
}
a.lock.Lock()
s, ok := a.sessions[cookie.Value]
if !ok {
a.lock.Unlock()
return User{}
}
for _, u := range a.users {
if u.Name == s.userName {
a.lock.Unlock()
return u
}
}
a.lock.Unlock()
return User{}
}
// GetUsers - get users
func (a *Auth) GetUsers() []User {
a.lock.Lock()
users := a.users
a.lock.Unlock()
return users
}
// AuthRequired - if authentication is required
func (a *Auth) AuthRequired() bool {
if GLMode {
return true
}
a.lock.Lock()
r := (len(a.users) != 0)
a.lock.Unlock()
return r
}

View File

@@ -0,0 +1,102 @@
package home
import (
"bytes"
"encoding/binary"
"io/ioutil"
"net"
"net/http"
"os"
"time"
"unsafe"
"github.com/AdguardTeam/golibs/log"
)
// GLMode - enable GL-Inet compatibility mode
var GLMode bool
var glFilePrefix = "/tmp/gl_token_"
const glTokenTimeoutSeconds = 3600
const glCookieName = "Admin-Token"
func glProcessRedirect(w http.ResponseWriter, r *http.Request) bool {
if !GLMode {
return false
}
// redirect to gl-inet login
host, _, _ := net.SplitHostPort(r.Host)
url := "http://" + host
log.Debug("Auth: redirecting to %s", url)
http.Redirect(w, r, url, http.StatusFound)
return true
}
func glProcessCookie(r *http.Request) bool {
if !GLMode {
return false
}
glCookie, glerr := r.Cookie(glCookieName)
if glerr != nil {
return false
}
log.Debug("Auth: GL cookie value: %s", glCookie.Value)
if glCheckToken(glCookie.Value) {
return true
}
log.Info("Auth: invalid GL cookie value: %s", glCookie)
return false
}
func glCheckToken(sess string) bool {
tokenName := glFilePrefix + sess
_, err := os.Stat(tokenName)
if err != nil {
log.Error("os.Stat: %s", err)
return false
}
tokenDate := glGetTokenDate(tokenName)
now := uint32(time.Now().UTC().Unix())
return now <= (tokenDate + glTokenTimeoutSeconds)
}
func archIsLittleEndian() bool {
var i int32 = 0x01020304
u := unsafe.Pointer(&i)
pb := (*byte)(u)
b := *pb
return (b == 0x04)
}
func glGetTokenDate(file string) uint32 {
f, err := os.Open(file)
if err != nil {
log.Error("os.Open: %s", err)
return 0
}
var dateToken uint32
bs, err := ioutil.ReadAll(f)
if err != nil {
log.Error("ioutil.ReadAll: %s", err)
return 0
}
buf := bytes.NewBuffer(bs)
if archIsLittleEndian() {
err := binary.Read(buf, binary.LittleEndian, &dateToken)
if err != nil {
log.Error("binary.Read: %s", err)
return 0
}
} else {
err := binary.Read(buf, binary.BigEndian, &dateToken)
if err != nil {
log.Error("binary.Read: %s", err)
return 0
}
}
return dateToken
}

View File

@@ -0,0 +1,43 @@
package home
import (
"encoding/binary"
"io/ioutil"
"net/http"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestAuthGL(t *testing.T) {
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
GLMode = true
glFilePrefix = dir + "/gl_token_"
tval := uint32(1)
data := make([]byte, 4)
if archIsLittleEndian() {
binary.LittleEndian.PutUint32(data, tval)
} else {
binary.BigEndian.PutUint32(data, tval)
}
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0644))
assert.False(t, glCheckToken("test"))
tval = uint32(time.Now().UTC().Unix() + 60)
data = make([]byte, 4)
if archIsLittleEndian() {
binary.LittleEndian.PutUint32(data, tval)
} else {
binary.BigEndian.PutUint32(data, tval)
}
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0644))
r, _ := http.NewRequest("GET", "http://localhost/", nil)
r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"})
assert.True(t, glProcessCookie(r))
GLMode = false
}

177
internal/home/auth_test.go Normal file
View File

@@ -0,0 +1,177 @@
package home
import (
"encoding/hex"
"net/http"
"net/url"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func prepareTestDir() string {
const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0755)
return dir
}
func TestAuth(t *testing.T) {
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
fn := filepath.Join(dir, "sessions.db")
users := []User{
User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
}
a := InitAuth(fn, nil, 60)
s := session{}
user := User{Name: "name"}
a.UserAdd(&user, "password")
assert.True(t, a.CheckSession("notfound") == -1)
a.RemoveSession("notfound")
sess := getSession(&users[0])
sessStr := hex.EncodeToString(sess)
now := time.Now().UTC().Unix()
// check expiration
s.expire = uint32(now)
a.addSession(sess, &s)
assert.True(t, a.CheckSession(sessStr) == 1)
// add session with TTL = 2 sec
s = session{}
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.addSession(sess, &s)
assert.True(t, a.CheckSession(sessStr) == 0)
a.Close()
// load saved session
a = InitAuth(fn, users, 60)
// the session is still alive
assert.True(t, a.CheckSession(sessStr) == 0)
// reset our expiration time because CheckSession() has just updated it
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.storeSession(sess, &s)
a.Close()
u := a.UserFind("name", "password")
assert.True(t, len(u.Name) != 0)
time.Sleep(3 * time.Second)
// load and remove expired sessions
a = InitAuth(fn, users, 60)
assert.True(t, a.CheckSession(sessStr) == -1)
a.Close()
os.Remove(fn)
}
// implements http.ResponseWriter
type testResponseWriter struct {
hdr http.Header
statusCode int
}
func (w *testResponseWriter) Header() http.Header {
return w.hdr
}
func (w *testResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
func (w *testResponseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
}
func TestAuthHTTP(t *testing.T) {
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
fn := filepath.Join(dir, "sessions.db")
users := []User{
User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
}
Context.auth = InitAuth(fn, users, 60)
handlerCalled := false
handler := func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
}
handler2 := optionalAuth(handler)
w := testResponseWriter{}
w.hdr = make(http.Header)
r := http.Request{}
r.Header = make(http.Header)
r.Method = "GET"
// get / - we're redirected to login page
r.URL = &url.URL{Path: "/"}
handlerCalled = false
handler2(&w, &r)
assert.True(t, w.statusCode == http.StatusFound)
assert.True(t, w.hdr.Get("Location") != "")
assert.True(t, !handlerCalled)
// go to login page
loginURL := w.hdr.Get("Location")
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
// perform login
cookie := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"})
assert.True(t, cookie != "")
// get /
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.Header.Set("Cookie", cookie)
r.URL = &url.URL{Path: "/"}
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
r.Header.Del("Cookie")
// get / with basic auth
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.URL = &url.URL{Path: "/"}
r.SetBasicAuth("name", "password")
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
r.Header.Del("Authorization")
// get login page with a valid cookie - we're redirected to /
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.Header.Set("Cookie", cookie)
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)
assert.True(t, w.hdr.Get("Location") != "")
assert.True(t, !handlerCalled)
r.Header.Del("Cookie")
// get login page with an invalid cookie
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.Header.Set("Cookie", "bad")
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
r.Header.Del("Cookie")
Context.auth.Close()
}

704
internal/home/clients.go Normal file
View File

@@ -0,0 +1,704 @@
package home
import (
"bytes"
"fmt"
"net"
"os/exec"
"runtime"
"sort"
"strings"
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils"
)
const (
clientsUpdatePeriod = 10 * time.Minute
)
var webHandlersRegistered = false
// Client information
type Client struct {
IDs []string
Tags []string
Name string
UseOwnSettings bool // false: use global settings
FilteringEnabled bool
SafeSearchEnabled bool
SafeBrowsingEnabled bool
ParentalEnabled bool
UseOwnBlockedServices bool // false: use global settings
BlockedServices []string
Upstreams []string // list of upstream servers to be used for the client's requests
// Custom upstream config for this client
// nil: not yet initialized
// not nil, but empty: initialized, no good upstreams
// not nil, not empty: Upstreams ready to be used
upstreamConfig *proxy.UpstreamConfig
}
type clientSource uint
// Client sources
const (
// Priority: etc/hosts > DHCP > ARP > rDNS > WHOIS
ClientSourceWHOIS clientSource = iota // from WHOIS
ClientSourceRDNS // from rDNS
ClientSourceDHCP // from DHCP
ClientSourceARP // from 'arp -a'
ClientSourceHostsFile // from /etc/hosts
)
// ClientHost information
type ClientHost struct {
Host string
Source clientSource
WhoisInfo [][]string // [[key,value], ...]
}
type clientsContainer struct {
list map[string]*Client // name -> client
idIndex map[string]*Client // IP -> client
ipHost map[string]*ClientHost // IP -> Hostname
lock sync.Mutex
allTags map[string]bool
// dhcpServer is used for looking up clients IP addresses by MAC addresses
dhcpServer *dhcpd.Server
// dnsServer is used for checking clients IP status access list status
dnsServer *dnsforward.Server
autoHosts *util.AutoHosts // get entries from system hosts-files
testing bool // if TRUE, this object is used for internal tests
}
// Init initializes clients container
// dhcpServer: optional
// Note: this function must be called only once
func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.Server, autoHosts *util.AutoHosts) {
if clients.list != nil {
log.Fatal("clients.list != nil")
}
clients.list = make(map[string]*Client)
clients.idIndex = make(map[string]*Client)
clients.ipHost = make(map[string]*ClientHost)
clients.allTags = make(map[string]bool)
for _, t := range clientTags {
clients.allTags[t] = false
}
clients.dhcpServer = dhcpServer
clients.autoHosts = autoHosts
clients.addFromConfig(objects)
if !clients.testing {
clients.addFromDHCP()
if clients.dhcpServer != nil {
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
}
clients.autoHosts.SetOnChanged(clients.onHostsChanged)
}
}
// Start - start the module
func (clients *clientsContainer) Start() {
if !clients.testing {
if !webHandlersRegistered {
webHandlersRegistered = true
clients.registerWebHandlers()
}
go clients.periodicUpdate()
}
}
// Reload - reload auto-clients
func (clients *clientsContainer) Reload() {
clients.addFromSystemARP()
}
type clientObject struct {
Name string `yaml:"name"`
Tags []string `yaml:"tags"`
IDs []string `yaml:"ids"`
UseGlobalSettings bool `yaml:"use_global_settings"`
FilteringEnabled bool `yaml:"filtering_enabled"`
ParentalEnabled bool `yaml:"parental_enabled"`
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
UseGlobalBlockedServices bool `yaml:"use_global_blocked_services"`
BlockedServices []string `yaml:"blocked_services"`
Upstreams []string `yaml:"upstreams"`
}
func (clients *clientsContainer) tagKnown(tag string) bool {
_, ok := clients.allTags[tag]
return ok
}
func (clients *clientsContainer) addFromConfig(objects []clientObject) {
for _, cy := range objects {
cli := Client{
Name: cy.Name,
IDs: cy.IDs,
UseOwnSettings: !cy.UseGlobalSettings,
FilteringEnabled: cy.FilteringEnabled,
ParentalEnabled: cy.ParentalEnabled,
SafeSearchEnabled: cy.SafeSearchEnabled,
SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
UseOwnBlockedServices: !cy.UseGlobalBlockedServices,
Upstreams: cy.Upstreams,
}
for _, s := range cy.BlockedServices {
if !dnsfilter.BlockedSvcKnown(s) {
log.Debug("Clients: skipping unknown blocked-service '%s'", s)
continue
}
cli.BlockedServices = append(cli.BlockedServices, s)
}
for _, t := range cy.Tags {
if !clients.tagKnown(t) {
log.Debug("Clients: skipping unknown tag '%s'", t)
continue
}
cli.Tags = append(cli.Tags, t)
}
sort.Strings(cli.Tags)
_, err := clients.Add(cli)
if err != nil {
log.Tracef("clientAdd: %s", err)
}
}
}
// WriteDiskConfig - write configuration
func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
clients.lock.Lock()
for _, cli := range clients.list {
cy := clientObject{
Name: cli.Name,
UseGlobalSettings: !cli.UseOwnSettings,
FilteringEnabled: cli.FilteringEnabled,
ParentalEnabled: cli.ParentalEnabled,
SafeSearchEnabled: cli.SafeSearchEnabled,
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
}
cy.Tags = stringArrayDup(cli.Tags)
cy.IDs = stringArrayDup(cli.IDs)
cy.BlockedServices = stringArrayDup(cli.BlockedServices)
cy.Upstreams = stringArrayDup(cli.Upstreams)
*objects = append(*objects, cy)
}
clients.lock.Unlock()
}
func (clients *clientsContainer) periodicUpdate() {
for {
clients.Reload()
time.Sleep(clientsUpdatePeriod)
}
}
func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
switch flags {
case dhcpd.LeaseChangedAdded,
dhcpd.LeaseChangedAddedStatic,
dhcpd.LeaseChangedRemovedStatic:
clients.addFromDHCP()
}
}
func (clients *clientsContainer) onHostsChanged() {
clients.addFromHostsFile()
}
// Exists checks if client with this IP already exists
func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findByIP(ip)
if ok {
return true
}
ch, ok := clients.ipHost[ip]
if !ok {
return false
}
if source > ch.Source {
return false // we're going to overwrite this client's info with a stronger source
}
return true
}
func stringArrayDup(a []string) []string {
a2 := make([]string, len(a))
copy(a2, a)
return a2
}
// Find searches for a client by IP
func (clients *clientsContainer) Find(ip string) (Client, bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.findByIP(ip)
if !ok {
return Client{}, false
}
c.IDs = stringArrayDup(c.IDs)
c.Tags = stringArrayDup(c.Tags)
c.BlockedServices = stringArrayDup(c.BlockedServices)
c.Upstreams = stringArrayDup(c.Upstreams)
return c, true
}
// FindUpstreams looks for upstreams configured for the client
// If no client found for this IP, or if no custom upstreams are configured,
// this method returns nil
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.findByIP(ip)
if !ok {
return nil
}
if len(c.Upstreams) == 0 {
return nil
}
if c.upstreamConfig == nil {
config, err := proxy.ParseUpstreamsConfig(c.Upstreams, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
if err == nil {
c.upstreamConfig = &config
}
}
return c.upstreamConfig
}
// Find searches for a client by IP (and does not lock anything)
func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
ipAddr := net.ParseIP(ip)
if ipAddr == nil {
return Client{}, false
}
c, ok := clients.idIndex[ip]
if ok {
return *c, true
}
for _, c = range clients.list {
for _, id := range c.IDs {
_, ipnet, err := net.ParseCIDR(id)
if err != nil {
continue
}
if ipnet.Contains(ipAddr) {
return *c, true
}
}
}
if clients.dhcpServer == nil {
return Client{}, false
}
macFound := clients.dhcpServer.FindMACbyIP(ipAddr)
if macFound == nil {
return Client{}, false
}
for _, c = range clients.list {
for _, id := range c.IDs {
hwAddr, err := net.ParseMAC(id)
if err != nil {
continue
}
if bytes.Equal(hwAddr, macFound) {
return *c, true
}
}
}
return Client{}, false
}
// FindAutoClient - search for an auto-client by IP
func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) {
ipAddr := net.ParseIP(ip)
if ipAddr == nil {
return ClientHost{}, false
}
clients.lock.Lock()
defer clients.lock.Unlock()
ch, ok := clients.ipHost[ip]
if ok {
return *ch, true
}
return ClientHost{}, false
}
// Check if Client object's fields are correct
func (clients *clientsContainer) check(c *Client) error {
if len(c.Name) == 0 {
return fmt.Errorf("invalid Name")
}
if len(c.IDs) == 0 {
return fmt.Errorf("ID required")
}
for i, id := range c.IDs {
ip := net.ParseIP(id)
if ip != nil {
c.IDs[i] = ip.String() // normalize IP address
continue
}
_, _, err := net.ParseCIDR(id)
if err == nil {
continue
}
_, err = net.ParseMAC(id)
if err == nil {
continue
}
return fmt.Errorf("invalid ID: %s", id)
}
for _, t := range c.Tags {
if !clients.tagKnown(t) {
return fmt.Errorf("invalid tag: %s", t)
}
}
sort.Strings(c.Tags)
err := dnsforward.ValidateUpstreams(c.Upstreams)
if err != nil {
return fmt.Errorf("invalid upstream servers: %s", err)
}
return nil
}
// Add a new client object
// Return true: success; false: client exists.
func (clients *clientsContainer) Add(c Client) (bool, error) {
e := clients.check(&c)
if e != nil {
return false, e
}
clients.lock.Lock()
defer clients.lock.Unlock()
// check Name index
_, ok := clients.list[c.Name]
if ok {
return false, nil
}
// check ID index
for _, id := range c.IDs {
c2, ok := clients.idIndex[id]
if ok {
return false, fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name)
}
}
// update Name index
clients.list[c.Name] = &c
// update ID index
for _, id := range c.IDs {
clients.idIndex[id] = &c
}
log.Debug("Clients: added '%s': ID:%v [%d]", c.Name, c.IDs, len(clients.list))
return true, nil
}
// Del removes a client
func (clients *clientsContainer) Del(name string) bool {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.list[name]
if !ok {
return false
}
// update Name index
delete(clients.list, name)
// update ID index
for _, id := range c.IDs {
delete(clients.idIndex, id)
}
return true
}
// Return TRUE if arrays are equal
func arraysEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := 0; i != len(a); i++ {
if a[i] != b[i] {
return false
}
}
return true
}
// Update a client
func (clients *clientsContainer) Update(name string, c Client) error {
err := clients.check(&c)
if err != nil {
return err
}
clients.lock.Lock()
defer clients.lock.Unlock()
old, ok := clients.list[name]
if !ok {
return fmt.Errorf("client not found")
}
// check Name index
if old.Name != c.Name {
_, ok = clients.list[c.Name]
if ok {
return fmt.Errorf("client already exists")
}
}
// check IP index
if !arraysEqual(old.IDs, c.IDs) {
for _, id := range c.IDs {
c2, ok := clients.idIndex[id]
if ok && c2 != old {
return fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name)
}
}
// update ID index
for _, id := range old.IDs {
delete(clients.idIndex, id)
}
for _, id := range c.IDs {
clients.idIndex[id] = old
}
}
// update Name index
if old.Name != c.Name {
delete(clients.list, old.Name)
clients.list[c.Name] = old
}
// update upstreams cache
c.upstreamConfig = nil
*old = c
return nil
}
// SetWhoisInfo - associate WHOIS information with a client
func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findByIP(ip)
if ok {
log.Debug("Clients: client for %s is already created, ignore WHOIS info", ip)
return
}
ch, ok := clients.ipHost[ip]
if ok {
ch.WhoisInfo = info
log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo)
return
}
// Create a ClientHost implicitly so that we don't do this check again
ch = &ClientHost{
Source: ClientSourceWHOIS,
}
ch.WhoisInfo = info
clients.ipHost[ip] = ch
log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo)
}
// AddHost adds new IP -> Host pair
// Use priority of the source (etc/hosts > ARP > rDNS)
// so we overwrite existing entries with an equal or higher priority
func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
clients.lock.Lock()
b, e := clients.addHost(ip, host, source)
clients.lock.Unlock()
return b, e
}
func (clients *clientsContainer) addHost(ip, host string, source clientSource) (bool, error) {
// check auto-clients index
ch, ok := clients.ipHost[ip]
if ok && ch.Source > source {
return false, nil
} else if ok {
ch.Source = source
} else {
ch = &ClientHost{
Host: host,
Source: source,
}
clients.ipHost[ip] = ch
}
log.Debug("Clients: added '%s' -> '%s' [%d]", ip, host, len(clients.ipHost))
return true, nil
}
// Remove all entries that match the specified source
func (clients *clientsContainer) rmHosts(source clientSource) int {
n := 0
for k, v := range clients.ipHost {
if v.Source == source {
delete(clients.ipHost, k)
n++
}
}
log.Debug("Clients: removed %d client aliases", n)
return n
}
// Fill clients array from system hosts-file
func (clients *clientsContainer) addFromHostsFile() {
hosts := clients.autoHosts.List()
clients.lock.Lock()
defer clients.lock.Unlock()
_ = clients.rmHosts(ClientSourceHostsFile)
n := 0
for ip, name := range hosts {
ok, err := clients.addHost(ip, name, ClientSourceHostsFile)
if err != nil {
log.Debug("Clients: %s", err)
}
if ok {
n++
}
}
log.Debug("Clients: added %d client aliases from system hosts-file", n)
}
// Add IP -> Host pairs from the system's `arp -a` command output
// The command's output is:
// HOST (IP) at MAC on IFACE
func (clients *clientsContainer) addFromSystemARP() {
if runtime.GOOS == "windows" {
return
}
cmd := exec.Command("arp", "-a")
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
data, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Debug("command %s has failed: %v code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
_ = clients.rmHosts(ClientSourceARP)
n := 0
lines := strings.Split(string(data), "\n")
for _, ln := range lines {
open := strings.Index(ln, " (")
close := strings.Index(ln, ") ")
if open == -1 || close == -1 || open >= close {
continue
}
host := ln[:open]
ip := ln[open+2 : close]
if utils.IsValidHostname(host) != nil || net.ParseIP(ip) == nil {
continue
}
ok, e := clients.addHost(ip, host, ClientSourceARP)
if e != nil {
log.Tracef("%s", e)
}
if ok {
n++
}
}
log.Debug("Clients: added %d client aliases from 'arp -a' command output", n)
}
// Add clients from DHCP that have non-empty Hostname property
func (clients *clientsContainer) addFromDHCP() {
if clients.dhcpServer == nil {
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
_ = clients.rmHosts(ClientSourceDHCP)
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
n := 0
for _, l := range leases {
if len(l.Hostname) == 0 {
continue
}
ok, _ := clients.addHost(l.IP.String(), l.Hostname, ClientSourceDHCP)
if ok {
n++
}
}
log.Debug("Clients: added %d client aliases from DHCP", n)
}

View File

@@ -0,0 +1,296 @@
package home
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
)
type clientJSON struct {
IDs []string `json:"ids"`
Tags []string `json:"tags"`
Name string `json:"name"`
UseGlobalSettings bool `json:"use_global_settings"`
FilteringEnabled bool `json:"filtering_enabled"`
ParentalEnabled bool `json:"parental_enabled"`
SafeSearchEnabled bool `json:"safesearch_enabled"`
SafeBrowsingEnabled bool `json:"safebrowsing_enabled"`
UseGlobalBlockedServices bool `json:"use_global_blocked_services"`
BlockedServices []string `json:"blocked_services"`
Upstreams []string `json:"upstreams"`
WhoisInfo map[string]interface{} `json:"whois_info"`
// Disallowed - if true -- client's IP is not disallowed
// Otherwise, it is blocked.
Disallowed bool `json:"disallowed"`
// DisallowedRule - the rule due to which the client is disallowed
// If Disallowed is true, and this string is empty - it means that the client IP
// is disallowed by the "allowed IP list", i.e. it is not included in allowed.
DisallowedRule string `json:"disallowed_rule"`
}
type clientHostJSON struct {
IP string `json:"ip"`
Name string `json:"name"`
Source string `json:"source"`
WhoisInfo map[string]interface{} `json:"whois_info"`
}
type clientListJSON struct {
Clients []clientJSON `json:"clients"`
AutoClients []clientHostJSON `json:"auto_clients"`
Tags []string `json:"supported_tags"`
}
// respond with information about configured clients
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http.Request) {
data := clientListJSON{}
clients.lock.Lock()
for _, c := range clients.list {
cj := clientToJSON(c)
data.Clients = append(data.Clients, cj)
}
for ip, ch := range clients.ipHost {
cj := clientHostJSON{
IP: ip,
Name: ch.Host,
}
cj.Source = "etc/hosts"
switch ch.Source {
case ClientSourceDHCP:
cj.Source = "DHCP"
case ClientSourceRDNS:
cj.Source = "rDNS"
case ClientSourceARP:
cj.Source = "ARP"
case ClientSourceWHOIS:
cj.Source = "WHOIS"
}
cj.WhoisInfo = make(map[string]interface{})
for _, wi := range ch.WhoisInfo {
cj.WhoisInfo[wi[0]] = wi[1]
}
data.AutoClients = append(data.AutoClients, cj)
}
clients.lock.Unlock()
data.Tags = clientTags
w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data)
if e != nil {
httpError(w, http.StatusInternalServerError, "Failed to encode to json: %v", e)
return
}
}
// Convert JSON object to Client object
func jsonToClient(cj clientJSON) (*Client, error) {
c := Client{
Name: cj.Name,
IDs: cj.IDs,
Tags: cj.Tags,
UseOwnSettings: !cj.UseGlobalSettings,
FilteringEnabled: cj.FilteringEnabled,
ParentalEnabled: cj.ParentalEnabled,
SafeSearchEnabled: cj.SafeSearchEnabled,
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
BlockedServices: cj.BlockedServices,
Upstreams: cj.Upstreams,
}
return &c, nil
}
// Convert Client object to JSON
func clientToJSON(c *Client) clientJSON {
cj := clientJSON{
Name: c.Name,
IDs: c.IDs,
Tags: c.Tags,
UseGlobalSettings: !c.UseOwnSettings,
FilteringEnabled: c.FilteringEnabled,
ParentalEnabled: c.ParentalEnabled,
SafeSearchEnabled: c.SafeSearchEnabled,
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
BlockedServices: c.BlockedServices,
Upstreams: c.Upstreams,
}
return cj
}
// Convert ClientHost object to JSON
func clientHostToJSON(ip string, ch ClientHost) clientJSON {
cj := clientJSON{
Name: ch.Host,
IDs: []string{ip},
}
cj.WhoisInfo = make(map[string]interface{})
for _, wi := range ch.WhoisInfo {
cj.WhoisInfo[wi[0]] = wi[1]
}
return cj
}
// Add a new client
func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
return
}
cj := clientJSON{}
err = json.Unmarshal(body, &cj)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
return
}
c, err := jsonToClient(cj)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
ok, err := clients.Add(*c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
if !ok {
httpError(w, http.StatusBadRequest, "Client already exists")
return
}
onConfigModified()
}
// Remove client
func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
return
}
cj := clientJSON{}
err = json.Unmarshal(body, &cj)
if err != nil || len(cj.Name) == 0 {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
return
}
if !clients.Del(cj.Name) {
httpError(w, http.StatusBadRequest, "Client not found")
return
}
onConfigModified()
}
type updateJSON struct {
Name string `json:"name"`
Data clientJSON `json:"data"`
}
// Update client's properties
func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
return
}
var dj updateJSON
err = json.Unmarshal(body, &dj)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
return
}
if len(dj.Name) == 0 {
httpError(w, http.StatusBadRequest, "Invalid request")
return
}
c, err := jsonToClient(dj.Data)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
err = clients.Update(dj.Name, *c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
onConfigModified()
}
// Get the list of clients by IP address list
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
data := []map[string]interface{}{}
for i := 0; ; i++ {
ip := q.Get(fmt.Sprintf("ip%d", i))
if len(ip) == 0 {
break
}
el := map[string]interface{}{}
c, ok := clients.Find(ip)
if !ok {
ch, ok := clients.FindAutoClient(ip)
if !ok {
continue // a client with this IP isn't found
}
cj := clientHostToJSON(ip, ch)
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
el[ip] = cj
} else {
cj := clientToJSON(&c)
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
el[ip] = cj
}
data = append(data, el)
}
js, err := json.Marshal(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(js)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write response: %s", err)
}
}
// RegisterClientsHandlers registers HTTP handlers
func (clients *clientsContainer) registerWebHandlers() {
httpRegister("GET", "/control/clients", clients.handleGetClients)
httpRegister("POST", "/control/clients/add", clients.handleAddClient)
httpRegister("POST", "/control/clients/delete", clients.handleDelClient)
httpRegister("POST", "/control/clients/update", clients.handleUpdateClient)
httpRegister("GET", "/control/clients/find", clients.handleFindClient)
}

View File

@@ -0,0 +1,27 @@
package home
var clientTags = []string{
"device_audio",
"device_camera",
"device_gameconsole",
"device_laptop",
"device_nas", // Network-attached Storage
"device_other",
"device_pc",
"device_phone",
"device_printer",
"device_securityalarm",
"device_tablet",
"device_tv",
"os_android",
"os_ios",
"os_linux",
"os_macos",
"os_other",
"os_windows",
"user_admin",
"user_child",
"user_regular",
}

View File

@@ -0,0 +1,266 @@
package home
import (
"net"
"os"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/stretchr/testify/assert"
)
func TestClients(t *testing.T) {
var c Client
var e error
var b bool
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
// add
c = Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
}
b, e = clients.Add(c)
if !b || e != nil {
t.Fatalf("Add #1")
}
// add #2
c = Client{
IDs: []string{"2.2.2.2"},
Name: "client2",
}
b, e = clients.Add(c)
if !b || e != nil {
t.Fatalf("Add #2")
}
c, b = clients.Find("1.1.1.1")
assert.True(t, b && c.Name == "client1")
c, b = clients.Find("1:2:3::4")
assert.True(t, b && c.Name == "client1")
c, b = clients.Find("2.2.2.2")
assert.True(t, b && c.Name == "client2")
// failed add - name in use
c = Client{
IDs: []string{"1.2.3.5"},
Name: "client1",
}
b, _ = clients.Add(c)
if b {
t.Fatalf("Add - name in use")
}
// failed add - ip in use
c = Client{
IDs: []string{"2.2.2.2"},
Name: "client3",
}
b, e = clients.Add(c)
if b || e == nil {
t.Fatalf("Add - ip in use")
}
// get
assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
// failed update - no such name
c.IDs = []string{"1.2.3.0"}
c.Name = "client3"
if clients.Update("client3", c) == nil {
t.Fatalf("Update")
}
// failed update - name in use
c.IDs = []string{"1.2.3.0"}
c.Name = "client2"
if clients.Update("client1", c) == nil {
t.Fatalf("Update - name in use")
}
// failed update - ip in use
c.IDs = []string{"2.2.2.2"}
c.Name = "client1"
if clients.Update("client1", c) == nil {
t.Fatalf("Update - ip in use")
}
// update
c.IDs = []string{"1.1.1.2"}
c.Name = "client1"
if clients.Update("client1", c) != nil {
t.Fatalf("Update")
}
// get after update
assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
// update - rename
c.IDs = []string{"1.1.1.2"}
c.Name = "client1-renamed"
c.UseOwnSettings = true
assert.True(t, clients.Update("client1", c) == nil)
c = Client{}
c, b = clients.Find("1.1.1.2")
assert.True(t, b && c.Name == "client1-renamed" && c.IDs[0] == "1.1.1.2" && c.UseOwnSettings)
assert.True(t, clients.list["client1"] == nil)
// failed remove - no such name
if clients.Del("client3") {
t.Fatalf("Del - no such name")
}
// remove
assert.True(t, !(!clients.Del("client1-renamed") || clients.Exists("1.1.1.2", ClientSourceHostsFile)))
// add host client
b, e = clients.AddHost("1.1.1.1", "host", ClientSourceARP)
if !b || e != nil {
t.Fatalf("clientAddHost")
}
// failed add - ip exists
b, e = clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
if b || e != nil {
t.Fatalf("clientAddHost - ip exists")
}
// overwrite with new data
b, e = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
if !b || e != nil {
t.Fatalf("clientAddHost - overwrite with new data")
}
// overwrite with new data (higher priority)
b, e = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
if !b || e != nil {
t.Fatalf("clientAddHost - overwrite with new data (higher priority)")
}
// get
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
}
func TestClientsWhois(t *testing.T) {
var c Client
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
// set whois info on new client
clients.SetWhoisInfo("1.1.1.255", whois)
assert.True(t, clients.ipHost["1.1.1.255"].WhoisInfo[0][1] == "orgname-val")
// set whois info on existing auto-client
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
clients.SetWhoisInfo("1.1.1.1", whois)
assert.True(t, clients.ipHost["1.1.1.1"].WhoisInfo[0][1] == "orgname-val")
// Check that we cannot set whois info on a manually-added client
c = Client{
IDs: []string{"1.1.1.2"},
Name: "client1",
}
_, _ = clients.Add(c)
clients.SetWhoisInfo("1.1.1.2", whois)
assert.True(t, clients.ipHost["1.1.1.2"] == nil)
_ = clients.Del("client1")
}
func TestClientsAddExisting(t *testing.T) {
var c Client
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
// some test variables
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
testIP := "1.2.3.4"
// add a client
c = Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
Name: "client1",
}
ok, err := clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
// add an auto-client with the same IP - it's allowed
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
assert.True(t, ok)
assert.Nil(t, err)
// now some more complicated stuff
// first, init a DHCP server with a single static lease
config := dhcpd.ServerConfig{
DBFilePath: "leases.db",
}
defer func() { _ = os.Remove("leases.db") }()
clients.dhcpServer = dhcpd.Create(config)
err = clients.dhcpServer.AddStaticLease(dhcpd.Lease{
HWAddr: mac,
IP: net.ParseIP(testIP).To4(),
Hostname: "testhost",
Expiry: time.Now().Add(time.Hour),
})
assert.Nil(t, err)
// add a new client with the same IP as for a client with MAC
c = Client{
IDs: []string{testIP},
Name: "client2",
}
ok, err = clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
// add a new client with the IP from the client1's IP range
c = Client{
IDs: []string{"2.2.2.2"},
Name: "client3",
}
ok, err = clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
}
func TestClientsCustomUpstream(t *testing.T) {
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
// add client with upstreams
client := Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
Upstreams: []string{
"1.1.1.1",
"[/example.org/]8.8.8.8",
},
}
ok, err := clients.Add(client)
assert.Nil(t, err)
assert.True(t, ok)
config := clients.FindUpstreams("1.2.3.4")
assert.Nil(t, config)
config = clients.FindUpstreams("1.1.1.1")
assert.NotNil(t, config)
assert.Equal(t, 1, len(config.Upstreams))
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
}

295
internal/home/config.go Normal file
View File

@@ -0,0 +1,295 @@
package home
import (
"io/ioutil"
"os"
"path/filepath"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
yaml "gopkg.in/yaml.v2"
)
const (
dataDir = "data" // data storage
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
)
// logSettings
type logSettings struct {
LogCompress bool `yaml:"log_compress"` // Compress determines if the rotated log files should be compressed using gzip (default: false)
LogLocalTime bool `yaml:"log_localtime"` // If the time used for formatting the timestamps in is the computer's local time (default: false [UTC])
LogMaxBackups int `yaml:"log_max_backups"` // Maximum number of old log files to retain (MaxAge may still cause them to get deleted)
LogMaxSize int `yaml:"log_max_size"` // Maximum size in megabytes of the log file before it gets rotated (default 100 MB)
LogMaxAge int `yaml:"log_max_age"` // MaxAge is the maximum number of days to retain old log files
LogFile string `yaml:"log_file"` // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
}
// configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here
type configuration struct {
// Raw file data to avoid re-reading of configuration file
// It's reset after config is parsed
fileData []byte
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
Users []User `yaml:"users"` // Users that can access HTTP server
ProxyURL string `yaml:"http_proxy"` // Proxy address for our HTTP client
Language string `yaml:"language"` // two-letter ISO 639-1 language code
RlimitNoFile uint `yaml:"rlimit_nofile"` // Maximum number of opened fd's per process (0: default)
DebugPProf bool `yaml:"debug_pprof"` // Enable pprof HTTP server on port 6060
// TTL for a web session (in hours)
// An active session is automatically refreshed once a day.
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
DNS dnsConfig `yaml:"dns"`
TLS tlsConfigSettings `yaml:"tls"`
Filters []filter `yaml:"filters"`
WhitelistFilters []filter `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
DHCP dhcpd.ServerConfig `yaml:"dhcp"`
// Note: this array is filled only before file read/write and then it's cleared
Clients []clientObject `yaml:"clients"`
logSettings `yaml:",inline"`
sync.RWMutex `yaml:"-"`
SchemaVersion int `yaml:"schema_version"` // keeping last so that users will be less tempted to change it -- used when upgrading between versions
}
// field ordering is important -- yaml fields will mirror ordering from here
type dnsConfig struct {
BindHost string `yaml:"bind_host"`
Port int `yaml:"port"`
// time interval for statistics (in days)
StatsInterval uint32 `yaml:"statistics_interval"`
QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled
QueryLogFileEnabled bool `yaml:"querylog_file_enabled"` // if true, query log will be written to a file
QueryLogInterval uint32 `yaml:"querylog_interval"` // time interval for query log (in days)
QueryLogMemSize uint32 `yaml:"querylog_size_memory"` // number of entries kept in memory before they are flushed to disk
AnonymizeClientIP bool `yaml:"anonymize_client_ip"` // anonymize clients' IP addresses in logs and stats
dnsforward.FilteringConfig `yaml:",inline"`
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
DnsfilterConf dnsfilter.Config `yaml:",inline"`
}
type tlsConfigSettings struct {
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DOT/DOH/HTTPS) status
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
ForceHTTPS bool `yaml:"force_https" json:"force_https,omitempty"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
PortHTTPS int `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
PortDNSOverTLS int `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DOT will be disabled
PortDNSOverQUIC uint16 `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"` // DNS-over-QUIC port. If 0, DoQ will be disabled
// Allow DOH queries via unencrypted HTTP (e.g. for reverse proxying)
AllowUnencryptedDOH bool `yaml:"allow_unencrypted_doh" json:"allow_unencrypted_doh"`
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
}
// initialize to default values, will be changed later when reading config or parsing command line
var config = configuration{
BindPort: 3000,
BindHost: "0.0.0.0",
DNS: dnsConfig{
BindHost: "0.0.0.0",
Port: 53,
StatsInterval: 1,
FilteringConfig: dnsforward.FilteringConfig{
ProtectionEnabled: true, // whether or not use any of dnsfilter features
BlockingMode: "default", // mode how to answer filtered requests
BlockedResponseTTL: 10, // in seconds
Ratelimit: 20,
RefuseAny: true,
AllServers: false,
},
FilteringEnabled: true, // whether or not use filter lists
FiltersUpdateIntervalHours: 24,
},
TLS: tlsConfigSettings{
PortHTTPS: 443,
PortDNSOverTLS: 853, // needs to be passed through to dnsproxy
PortDNSOverQUIC: 784,
},
logSettings: logSettings{
LogCompress: false,
LogLocalTime: false,
LogMaxBackups: 0,
LogMaxSize: 100,
LogMaxAge: 3,
},
SchemaVersion: currentSchemaVersion,
}
// initConfig initializes default configuration for the current OS&ARCH
func initConfig() {
config.WebSessionTTLHours = 30 * 24
config.DNS.QueryLogEnabled = true
config.DNS.QueryLogFileEnabled = true
config.DNS.QueryLogInterval = 90
config.DNS.QueryLogMemSize = 1000
config.DNS.CacheSize = 4 * 1024 * 1024
config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.CacheTime = 30
config.Filters = defaultFilters()
config.DHCP.Conf4.LeaseDuration = 86400
config.DHCP.Conf4.ICMPTimeout = 1000
config.DHCP.Conf6.LeaseDuration = 86400
}
// getConfigFilename returns path to the current config file
func (c *configuration) getConfigFilename() string {
configFile, err := filepath.EvalSymlinks(Context.configFilename)
if err != nil {
if !os.IsNotExist(err) {
log.Error("unexpected error while config file path evaluation: %s", err)
}
configFile = Context.configFilename
}
if !filepath.IsAbs(configFile) {
configFile = filepath.Join(Context.workDir, configFile)
}
return configFile
}
// getLogSettings reads logging settings from the config file.
// we do it in a separate method in order to configure logger before the actual configuration is parsed and applied.
func getLogSettings() logSettings {
l := logSettings{}
yamlFile, err := readConfigFile()
if err != nil {
return l
}
err = yaml.Unmarshal(yamlFile, &l)
if err != nil {
log.Error("Couldn't get logging settings from the configuration: %s", err)
}
return l
}
// parseConfig loads configuration from the YAML file
func parseConfig() error {
configFile := config.getConfigFilename()
log.Debug("Reading config file: %s", configFile)
yamlFile, err := readConfigFile()
if err != nil {
return err
}
config.fileData = nil
err = yaml.Unmarshal(yamlFile, &config)
if err != nil {
log.Error("Couldn't parse config file: %s", err)
return err
}
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
config.DNS.FiltersUpdateIntervalHours = 24
}
return nil
}
// readConfigFile reads config file contents if it exists
func readConfigFile() ([]byte, error) {
if len(config.fileData) != 0 {
return config.fileData, nil
}
configFile := config.getConfigFilename()
d, err := ioutil.ReadFile(configFile)
if err != nil {
log.Error("Couldn't read config file %s: %s", configFile, err)
return nil, err
}
return d, nil
}
// Saves configuration to the YAML file and also saves the user filter contents to a file
func (c *configuration) write() error {
c.Lock()
defer c.Unlock()
Context.clients.WriteDiskConfig(&config.Clients)
if Context.auth != nil {
config.Users = Context.auth.GetUsers()
}
if Context.tls != nil {
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
config.TLS = tlsConf
}
if Context.stats != nil {
sdc := stats.DiskConfig{}
Context.stats.WriteDiskConfig(&sdc)
config.DNS.StatsInterval = sdc.Interval
}
if Context.queryLog != nil {
dc := querylog.Config{}
Context.queryLog.WriteDiskConfig(&dc)
config.DNS.QueryLogEnabled = dc.Enabled
config.DNS.QueryLogFileEnabled = dc.FileEnabled
config.DNS.QueryLogInterval = dc.Interval
config.DNS.QueryLogMemSize = dc.MemSize
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
}
if Context.dnsFilter != nil {
c := dnsfilter.Config{}
Context.dnsFilter.WriteDiskConfig(&c)
config.DNS.DnsfilterConf = c
}
if Context.dnsServer != nil {
c := dnsforward.FilteringConfig{}
Context.dnsServer.WriteDiskConfig(&c)
config.DNS.FilteringConfig = c
}
if Context.dhcpServer != nil {
c := dhcpd.ServerConfig{}
Context.dhcpServer.WriteDiskConfig(&c)
config.DHCP = c
}
configFile := config.getConfigFilename()
log.Debug("Writing YAML file: %s", configFile)
yamlText, err := yaml.Marshal(&config)
config.Clients = nil
if err != nil {
log.Error("Couldn't generate YAML file: %s", err)
return err
}
err = file.SafeWrite(configFile, yamlText)
if err != nil {
log.Error("Couldn't save YAML config: %s", err)
return err
}
return nil
}

234
internal/home/control.go Normal file
View File

@@ -0,0 +1,234 @@
package home
import (
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/log"
"github.com/NYTimes/gziphandler"
)
// ----------------
// helper functions
// ----------------
func returnOK(w http.ResponseWriter) {
_, err := fmt.Fprintf(w, "OK\n")
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info(text)
http.Error(w, text, code)
}
// ---------------
// dns run control
// ---------------
func addDNSAddress(dnsAddresses *[]string, addr string) {
if config.DNS.Port != 53 {
addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port)
}
*dnsAddresses = append(*dnsAddresses, addr)
}
func handleStatus(w http.ResponseWriter, r *http.Request) {
c := dnsforward.FilteringConfig{}
if Context.dnsServer != nil {
Context.dnsServer.WriteDiskConfig(&c)
}
data := map[string]interface{}{
"dns_addresses": getDNSAddresses(),
"http_port": config.BindPort,
"dns_port": config.DNS.Port,
"running": isRunning(),
"version": versionString,
"language": config.Language,
"protection_enabled": c.ProtectionEnabled,
}
data["dhcp_available"] = (Context.dhcpServer != nil)
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
type profileJSON struct {
Name string `json:"name"`
}
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
pj := profileJSON{}
u := Context.auth.GetCurrentUser(r)
pj.Name = u.Name
data, err := json.Marshal(pj)
if err != nil {
httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err)
return
}
_, _ = w.Write(data)
}
// ------------------------
// registration of handlers
// ------------------------
func registerControlHandlers() {
httpRegister(http.MethodGet, "/control/status", handleStatus)
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
httpRegister(http.MethodPost, "/control/update", handleUpdate)
httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
// No auth is necessary for DOH/DOT configurations
http.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoh))
http.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDot))
RegisterAuthHandlers()
}
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
if len(method) == 0 {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
http.HandleFunc(url, postInstall(handler))
return
}
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
}
// ----------------------------------
// helper functions for HTTP handlers
// ----------------------------------
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
log.Debug("%s %v", r.Method, r.URL)
if r.Method != method {
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
return
}
if method == "POST" || method == "PUT" || method == "DELETE" {
Context.controlLock.Lock()
defer Context.controlLock.Unlock()
}
handler(w, r)
}
}
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure("POST", handler)
}
func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure("GET", handler)
}
// Bridge between http.Handler object and Go function
type httpHandler struct {
handler func(http.ResponseWriter, *http.Request)
}
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.handler(w, r)
}
func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler {
h := httpHandler{}
h.handler = ensure(method, handler)
return &h
}
// preInstall lets the handler run only if firstRun is true, no redirects
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if !Context.firstRun {
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
handler(w, r)
}
}
// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary
type preInstallHandlerStruct struct {
handler http.Handler
}
func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
preInstall(p.handler.ServeHTTP)(w, r)
}
// preInstallHandler returns http.Handler interface for preInstall wrapper
func preInstallHandler(handler http.Handler) http.Handler {
return &preInstallHandlerStruct{handler}
}
// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise
// it also enforces HTTPS if it is enabled and configured
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if Context.firstRun &&
!strings.HasPrefix(r.URL.Path, "/install.") &&
!strings.HasPrefix(r.URL.Path, "/assets/") {
http.Redirect(w, r, "/install.html", http.StatusFound)
return
}
// enforce https?
if r.TLS == nil && Context.web.forceHTTPS && Context.web.httpsServer.server != nil {
// yes, and we want host from host:port
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
// no port in host
host = r.Host
}
// construct new URL to redirect to
newURL := url.URL{
Scheme: "https",
Host: net.JoinHostPort(host, strconv.Itoa(Context.web.portHTTPS)),
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
}
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
handler(w, r)
}
}
type postInstallHandlerStruct struct {
handler http.Handler
}
func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
postInstall(p.handler.ServeHTTP)(w, r)
}
func postInstallHandler(handler http.Handler) http.Handler {
return &postInstallHandlerStruct{handler}
}

View File

@@ -0,0 +1,391 @@
package home
import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// isValidURL - return TRUE if URL or file path is valid
func isValidURL(rawurl string) bool {
if filepath.IsAbs(rawurl) {
// this is a file path
return util.FileExists(rawurl)
}
url, err := url.ParseRequestURI(rawurl)
if err != nil {
return false //Couldn't even parse the rawurl
}
if len(url.Scheme) == 0 {
return false //No Scheme found
}
return true
}
type filterAddJSON struct {
Name string `json:"name"`
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
}
func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
fj := filterAddJSON{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return
}
if !isValidURL(fj.URL) {
http.Error(w, "Invalid URL or file path", http.StatusBadRequest)
return
}
// Check for duplicates
if filterExists(fj.URL) {
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
return
}
// Set necessary properties
filt := filter{
Enabled: true,
URL: fj.URL,
Name: fj.Name,
white: fj.Whitelist,
}
filt.ID = assignUniqueFilterID()
// Download the filter contents
ok, err := f.update(&filt)
if err != nil {
httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err)
return
}
if !ok {
httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL)
return
}
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
if !filterAdd(filt) {
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
return
}
onConfigModified()
enableFilters(true)
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
type request struct {
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
}
req := request{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return
}
// go through each element and delete if url matches
config.Lock()
newFilters := []filter{}
filters := &config.Filters
if req.Whitelist {
filters = &config.WhitelistFilters
}
for _, filter := range *filters {
if filter.URL != req.URL {
newFilters = append(newFilters, filter)
} else {
err := os.Rename(filter.Path(), filter.Path()+".old")
if err != nil {
log.Error("os.Rename: %s: %s", filter.Path(), err)
}
}
}
// Update the configuration after removing filter files
*filters = newFilters
config.Unlock()
onConfigModified()
enableFilters(true)
// Note: the old files "filter.txt.old" aren't deleted - it's not really necessary,
// but will require the additional code to run after enableFilters() is finished: i.e. complicated
}
type filterURLJSON struct {
Name string `json:"name"`
URL string `json:"url"`
Enabled bool `json:"enabled"`
}
type filterURLReq struct {
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
Data filterURLJSON `json:"data"`
}
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
fj := filterURLReq{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
if !isValidURL(fj.Data.URL) {
http.Error(w, "invalid URL or file path", http.StatusBadRequest)
return
}
filt := filter{
Enabled: fj.Data.Enabled,
Name: fj.Data.Name,
URL: fj.Data.URL,
}
status := f.filterSetProperties(fj.URL, filt, fj.Whitelist)
if (status & statusFound) == 0 {
http.Error(w, "URL doesn't exist", http.StatusBadRequest)
return
}
if (status & statusURLExists) != 0 {
http.Error(w, "URL already exists", http.StatusBadRequest)
return
}
onConfigModified()
restart := false
if (status & statusEnabledChanged) != 0 {
// we must add or remove filter rules
restart = true
}
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
// download new filter and apply its rules
flags := FilterRefreshBlocklists
if fj.Whitelist {
flags = FilterRefreshAllowlists
}
nUpdated, _ := f.refreshFilters(flags, true)
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
// if not - we restart the filtering ourselves
restart = false
if nUpdated == 0 {
restart = true
}
}
if restart {
enableFilters(true)
}
}
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)
return
}
config.UserRules = strings.Split(string(body), "\n")
onConfigModified()
enableFilters(true)
}
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
type Req struct {
White bool `json:"whitelist"`
}
type Resp struct {
Updated int `json:"updated"`
}
resp := Resp{}
var err error
req := Req{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
Context.controlLock.Unlock()
flags := FilterRefreshBlocklists
if req.White {
flags = FilterRefreshAllowlists
}
resp.Updated, err = f.refreshFilters(flags|FilterRefreshForce, false)
Context.controlLock.Lock()
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return
}
js, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
type filterJSON struct {
ID int64 `json:"id"`
Enabled bool `json:"enabled"`
URL string `json:"url"`
Name string `json:"name"`
RulesCount uint32 `json:"rules_count"`
LastUpdated string `json:"last_updated"`
}
type filteringConfig struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"` // in hours
Filters []filterJSON `json:"filters"`
WhitelistFilters []filterJSON `json:"whitelist_filters"`
UserRules []string `json:"user_rules"`
}
func filterToJSON(f filter) filterJSON {
fj := filterJSON{
ID: f.ID,
Enabled: f.Enabled,
URL: f.URL,
Name: f.Name,
RulesCount: uint32(f.RulesCount),
}
if !f.LastUpdated.IsZero() {
fj.LastUpdated = f.LastUpdated.Format(time.RFC3339)
}
return fj
}
// Get filtering configuration
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
resp := filteringConfig{}
config.RLock()
resp.Enabled = config.DNS.FilteringEnabled
resp.Interval = config.DNS.FiltersUpdateIntervalHours
for _, f := range config.Filters {
fj := filterToJSON(f)
resp.Filters = append(resp.Filters, fj)
}
for _, f := range config.WhitelistFilters {
fj := filterToJSON(f)
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
}
resp.UserRules = config.UserRules
config.RUnlock()
jsonVal, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(w, http.StatusInternalServerError, "http write: %s", err)
}
}
// Set filtering configuration
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
req := filteringConfig{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
if !checkFiltersUpdateIntervalHours(req.Interval) {
httpError(w, http.StatusBadRequest, "Unsupported interval")
return
}
config.DNS.FilteringEnabled = req.Enabled
config.DNS.FiltersUpdateIntervalHours = req.Interval
onConfigModified()
enableFilters(true)
}
type checkHostResp struct {
Reason string `json:"reason"`
FilterID int64 `json:"filter_id"`
Rule string `json:"rule"`
// for FilteredBlockedService:
SvcName string `json:"service_name"`
// for ReasonRewrite:
CanonName string `json:"cname"` // CNAME value
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
}
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
host := q.Get("name")
setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil {
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
return
}
resp := checkHostResp{}
resp.Reason = result.Reason.String()
resp.FilterID = result.FilterID
resp.Rule = result.Rule
resp.SvcName = result.ServiceName
resp.CanonName = result.CanonName
resp.IPList = result.IPList
js, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
// RegisterFilteringHandlers - register handlers
func (f *Filtering) RegisterFilteringHandlers() {
httpRegister("GET", "/control/filtering/status", f.handleFilteringStatus)
httpRegister("POST", "/control/filtering/config", f.handleFilteringConfig)
httpRegister("POST", "/control/filtering/add_url", f.handleFilteringAddURL)
httpRegister("POST", "/control/filtering/remove_url", f.handleFilteringRemoveURL)
httpRegister("POST", "/control/filtering/set_url", f.handleFilteringSetURL)
httpRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh)
httpRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules)
httpRegister("GET", "/control/filtering/check_host", f.handleCheckHost)
}
func checkFiltersUpdateIntervalHours(i uint32) bool {
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
}

View File

@@ -0,0 +1,372 @@
package home
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/golibs/log"
)
type firstRunData struct {
WebPort int `json:"web_port"`
DNSPort int `json:"dns_port"`
Interfaces map[string]interface{} `json:"interfaces"`
}
type netInterfaceJSON struct {
Name string `json:"name"`
MTU int `json:"mtu"`
HardwareAddr string `json:"hardware_address"`
Addresses []string `json:"ip_addresses"`
Flags string `json:"flags"`
}
// Get initial installation settings
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
data := firstRunData{}
data.WebPort = 80
data.DNSPort = 53
ifaces, err := util.GetValidNetInterfacesForWeb()
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return
}
data.Interfaces = make(map[string]interface{})
for _, iface := range ifaces {
ifaceJSON := netInterfaceJSON{
Name: iface.Name,
MTU: iface.MTU,
HardwareAddr: iface.HardwareAddr,
Addresses: iface.Addresses,
Flags: iface.Flags,
}
data.Interfaces[iface.Name] = ifaceJSON
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err)
return
}
}
type checkConfigReqEnt struct {
Port int `json:"port"`
IP string `json:"ip"`
Autofix bool `json:"autofix"`
}
type checkConfigReq struct {
Web checkConfigReqEnt `json:"web"`
DNS checkConfigReqEnt `json:"dns"`
SetStaticIP bool `json:"set_static_ip"`
}
type checkConfigRespEnt struct {
Status string `json:"status"`
CanAutofix bool `json:"can_autofix"`
}
type staticIPJSON struct {
Static string `json:"static"`
IP string `json:"ip"`
Error string `json:"error"`
}
type checkConfigResp struct {
Web checkConfigRespEnt `json:"web"`
DNS checkConfigRespEnt `json:"dns"`
StaticIP staticIPJSON `json:"static_ip"`
}
// Check if ports are available, respond with results
func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
reqData := checkConfigReq{}
respData := checkConfigResp{}
err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
return
}
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort {
err = util.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port)
if err != nil {
respData.Web.Status = fmt.Sprintf("%v", err)
}
}
if reqData.DNS.Port != 0 {
err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
if util.ErrorIsAddrInUse(err) {
canAutofix := checkDNSStubListener()
if canAutofix && reqData.DNS.Autofix {
err = disableDNSStubListener()
if err != nil {
log.Error("Couldn't disable DNSStubListener: %s", err)
}
err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
canAutofix = false
}
respData.DNS.CanAutofix = canAutofix
}
if err == nil {
err = util.CheckPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
}
if err != nil {
respData.DNS.Status = fmt.Sprintf("%v", err)
} else if reqData.DNS.IP != "0.0.0.0" {
respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP)
}
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(respData)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal JSON: %s", err)
return
}
}
// handleStaticIP - handles static IP request
// It either checks if we have a static IP
// Or if set=true, it tries to set it
func handleStaticIP(ip string, set bool) staticIPJSON {
resp := staticIPJSON{}
interfaceName := util.GetInterfaceByIP(ip)
resp.Static = "no"
if len(interfaceName) == 0 {
resp.Static = "error"
resp.Error = fmt.Sprintf("Couldn't find network interface by IP %s", ip)
return resp
}
if set {
// Try to set static IP for the specified interface
err := dhcpd.SetStaticIP(interfaceName)
if err != nil {
resp.Static = "error"
resp.Error = err.Error()
return resp
}
}
// Fallthrough here even if we set static IP
// Check if we have a static IP and return the details
isStaticIP, err := dhcpd.HasStaticIP(interfaceName)
if err != nil {
resp.Static = "error"
resp.Error = err.Error()
} else {
if isStaticIP {
resp.Static = "yes"
}
resp.IP = util.GetSubnet(interfaceName)
}
return resp
}
// Check if DNSStubListener is active
func checkDNSStubListener() bool {
if runtime.GOOS != "linux" {
return false
}
cmd := exec.Command("systemctl", "is-enabled", "systemd-resolved")
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
_, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Info("command %s has failed: %v code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
return false
}
cmd = exec.Command("grep", "-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf")
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
_, err = cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Info("command %s has failed: %v code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
return false
}
return true
}
const resolvedConfPath = "/etc/systemd/resolved.conf.d/adguardhome.conf"
const resolvedConfData = `[Resolve]
DNS=127.0.0.1
DNSStubListener=no
`
const resolvConfPath = "/etc/resolv.conf"
// Deactivate DNSStubListener
func disableDNSStubListener() error {
dir := filepath.Dir(resolvedConfPath)
err := os.MkdirAll(dir, 0755)
if err != nil {
return fmt.Errorf("os.MkdirAll: %s: %s", dir, err)
}
err = ioutil.WriteFile(resolvedConfPath, []byte(resolvedConfData), 0644)
if err != nil {
return fmt.Errorf("ioutil.WriteFile: %s: %s", resolvedConfPath, err)
}
_ = os.Rename(resolvConfPath, resolvConfPath+".backup")
err = os.Symlink("/run/systemd/resolve/resolv.conf", resolvConfPath)
if err != nil {
_ = os.Remove(resolvedConfPath) // remove the file we've just created
return fmt.Errorf("os.Symlink: %s: %s", resolvConfPath, err)
}
cmd := exec.Command("systemctl", "reload-or-restart", "systemd-resolved")
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
_, err = cmd.Output()
if err != nil {
return err
}
if cmd.ProcessState.ExitCode() != 0 {
return fmt.Errorf("process %s exited with an error: %d",
cmd.Path, cmd.ProcessState.ExitCode())
}
return nil
}
type applyConfigReqEnt struct {
IP string `json:"ip"`
Port int `json:"port"`
}
type applyConfigReq struct {
Web applyConfigReqEnt `json:"web"`
DNS applyConfigReqEnt `json:"dns"`
Username string `json:"username"`
Password string `json:"password"`
}
// Copy installation parameters between two configuration objects
func copyInstallSettings(dst *configuration, src *configuration) {
dst.BindHost = src.BindHost
dst.BindPort = src.BindPort
dst.DNS.BindHost = src.DNS.BindHost
dst.DNS.Port = src.DNS.Port
}
// Apply new configuration, start DNS server, restart Web server
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
newSettings := applyConfigReq{}
err := json.NewDecoder(r.Body).Decode(&newSettings)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'configure' JSON: %s", err)
return
}
if newSettings.Web.Port == 0 || newSettings.DNS.Port == 0 {
httpError(w, http.StatusBadRequest, "port value can't be 0")
return
}
restartHTTP := true
if config.BindHost == newSettings.Web.IP && config.BindPort == newSettings.Web.Port {
// no need to rebind
restartHTTP = false
}
// validate that hosts and ports are bindable
if restartHTTP {
err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s",
net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err)
return
}
}
err = util.CheckPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
err = util.CheckPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
var curConfig configuration
copyInstallSettings(&curConfig, &config)
Context.firstRun = false
config.BindHost = newSettings.Web.IP
config.BindPort = newSettings.Web.Port
config.DNS.BindHost = newSettings.DNS.IP
config.DNS.Port = newSettings.DNS.Port
err = StartMods()
if err != nil {
Context.firstRun = true
copyInstallSettings(&config, &curConfig)
httpError(w, http.StatusInternalServerError, "%s", err)
return
}
u := User{}
u.Name = newSettings.Username
Context.auth.UserAdd(&u, newSettings.Password)
err = config.write()
if err != nil {
Context.firstRun = true
copyInstallSettings(&config, &curConfig)
httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err)
return
}
web.conf.firstRun = false
web.conf.BindHost = newSettings.Web.IP
web.conf.BindPort = newSettings.Web.Port
registerControlHandlers()
returnOK(w)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
if restartHTTP {
go func() {
_ = Context.web.httpServer.Shutdown(context.TODO())
}()
}
}
func (web *Web) registerInstallHandlers() {
http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
http.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
}

View File

@@ -0,0 +1,77 @@
package home
import (
"testing"
"time"
)
/* Tests performed:
. Bad certificate
. Bad private key
. Valid certificate & private key */
func TestValidateCertificates(t *testing.T) {
var data tlsConfigStatus
// bad cert
data = validateCertificates("bad cert", "", "")
if !(data.WarningValidation != "" &&
!data.ValidCert &&
!data.ValidChain) {
t.Fatalf("bad cert: validateCertificates(): %v", data)
}
// bad priv key
data = validateCertificates("", "bad priv key", "")
if !(data.WarningValidation != "" &&
!data.ValidKey) {
t.Fatalf("bad priv key: validateCertificates(): %v", data)
}
// valid cert & priv key
CertificateChain := `-----BEGIN CERTIFICATE-----
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
MBMGA1UEAwwMQWRHdWFyZCBIb21lMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB
gQCwvwUnPJiOvLcOaWmGu6Y68ksFr13nrXBcsDlhxlXy8PaohVi3XxEmt2OrVjKW
QFw/bdV4fZ9tdWFAVRRkgeGbIZzP7YBD1Ore/O5SQ+DbCCEafvjJCcXQIrTeKFE6
i9G3aSMHs0Pwq2LgV8U5mYotLrvyFiE8QPInJbDDMpaFYwIDAQABo1MwUTAdBgNV
HQ4EFgQUdLUmQpEqrhn4eKO029jYd2AAZEQwHwYDVR0jBBgwFoAUdLUmQpEqrhn4
eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8
LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
-----END CERTIFICATE-----`
PrivateKey := `-----BEGIN PRIVATE KEY-----
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
xTmZii0uu/IWITxA8iclsMMyloVjAgMBAAECgYEAmjzoG1h27UDkIlB9BVWl95TP
QVPLB81D267xNFDnWk1Lgr5zL/pnNjkdYjyjgpkBp1yKyE4gHV4skv5sAFWTcOCU
QCgfPfUn/rDFcxVzAdJVWAa/CpJNaZgjTPR8NTGU+Ztod+wfBESNCP5tbnuw0GbL
MuwdLQJGbzeJYpsNysECQQDfFHYoRNfgxHwMbX24GCoNZIgk12uDmGTA9CS5E+72
9t3V1y4CfXxSkfhqNbd5RWrUBRLEw9BKofBS7L9NMDKDAkEAytQoIueE1vqEAaRg
a3A1YDUekKesU5wKfKfKlXvNgB7Hwh4HuvoQS9RCvVhf/60Dvq8KSu6hSjkFRquj
FQ5roQJBAMwKwyiCD5MfJPeZDmzcbVpiocRQ5Z4wPbffl9dRTDnIA5AciZDthlFg
An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp
O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
kXS9jgARhhiWXJrk
-----END PRIVATE KEY-----`
data = validateCertificates(CertificateChain, PrivateKey, "")
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z")
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z")
if !(data.WarningValidation != "" /* self signed */ &&
data.ValidCert &&
!data.ValidChain &&
data.ValidKey &&
data.KeyType == "RSA" &&
data.Subject == "CN=AdGuard Home,O=AdGuard Ltd" &&
data.Issuer == "CN=AdGuard Home,O=AdGuard Ltd" &&
data.NotBefore == notBefore &&
data.NotAfter == notAfter &&
// data.DNSNames[0] == &&
data.ValidPair) {
t.Fatalf("valid cert & priv key: validateCertificates(): %v", data)
}
}

View File

@@ -0,0 +1,164 @@
package home
import (
"encoding/json"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"syscall"
"github.com/AdguardTeam/AdGuardHome/internal/update"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
)
type getVersionJSONRequest struct {
RecheckNow bool `json:"recheck_now"`
}
// Get the latest available version from the Internet
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
if Context.disableUpdate {
resp := make(map[string]interface{})
resp["disabled"] = true
d, _ := json.Marshal(resp)
_, _ = w.Write(d)
return
}
req := getVersionJSONRequest{}
var err error
if r.ContentLength != 0 {
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
return
}
}
var info update.VersionInfo
for i := 0; i != 3; i++ {
Context.controlLock.Lock()
info, err = Context.updater.GetVersionResponse(req.RecheckNow)
Context.controlLock.Unlock()
if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") {
// This case may happen while we're restarting DNS server
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/934
continue
}
break
}
if err != nil {
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(getVersionResp(info))
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
// Perform an update procedure to the latest available version
func handleUpdate(w http.ResponseWriter, _ *http.Request) {
if len(Context.updater.NewVersion) == 0 {
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
return
}
err := Context.updater.DoUpdate()
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return
}
returnOK(w)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
go finishUpdate()
}
// Convert version.json data to our JSON response
func getVersionResp(info update.VersionInfo) []byte {
ret := make(map[string]interface{})
ret["can_autoupdate"] = false
ret["new_version"] = info.NewVersion
ret["announcement"] = info.Announcement
ret["announcement_url"] = info.AnnouncementURL
if info.CanAutoUpdate {
canUpdate := true
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
if runtime.GOOS != "windows" &&
((tlsConf.Enabled && (tlsConf.PortHTTPS < 1024 ||
tlsConf.PortDNSOverTLS < 1024 ||
tlsConf.PortDNSOverQUIC < 1024)) ||
config.BindPort < 1024 ||
config.DNS.Port < 1024) {
// On UNIX, if we're running under a regular user,
// but with CAP_NET_BIND_SERVICE set on a binary file,
// and we're listening on ports <1024,
// we won't be able to restart after we replace the binary file,
// because we'll lose CAP_NET_BIND_SERVICE capability.
canUpdate, _ = util.HaveAdminRights()
}
ret["can_autoupdate"] = canUpdate
}
d, _ := json.Marshal(ret)
return d
}
// Complete an update procedure
func finishUpdate() {
log.Info("Stopping all tasks")
cleanup()
cleanupAlways()
exeName := "AdGuardHome"
if runtime.GOOS == "windows" {
exeName = "AdGuardHome.exe"
}
curBinName := filepath.Join(Context.workDir, exeName)
if runtime.GOOS == "windows" {
if Context.runningAsService {
// Note:
// we can't restart the service via "kardianos/service" package - it kills the process first
// we can't start a new instance - Windows doesn't allow it
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
err := cmd.Start()
if err != nil {
log.Fatalf("exec.Command() failed: %s", err)
}
os.Exit(0)
}
cmd := exec.Command(curBinName, os.Args[1:]...)
log.Info("Restarting: %v", cmd.Args)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
if err != nil {
log.Fatalf("exec.Command() failed: %s", err)
}
os.Exit(0)
} else {
log.Info("Restarting: %v", os.Args)
err := syscall.Exec(curBinName, os.Args, os.Environ())
if err != nil {
log.Fatalf("syscall.Exec() failed: %s", err)
}
// Unreachable code
}
}

View File

@@ -0,0 +1,102 @@
// +build ignore
package home
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
)
func TestDoUpdate(t *testing.T) {
config.DNS.Port = 0
Context.workDir = "..." // set absolute path
newver := "v0.96"
data := `{
"version": "v0.96",
"announcement": "AdGuard Home v0.96 is now available!",
"announcement_url": "",
"download_windows_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_amd64.zip",
"download_windows_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_386.zip",
"download_darwin_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_amd64.zip",
"download_darwin_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_386.zip",
"download_linux_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz",
"download_linux_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_386.tar.gz",
"download_linux_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
"download_linux_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz",
"download_linux_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
"download_linux_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
"download_linux_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz",
"download_linux_mips": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz",
"download_linux_mipsle": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz",
"download_linux_mips64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz",
"download_linux_mips64le": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz",
"download_freebsd_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz",
"download_freebsd_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz",
"download_freebsd_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
"download_freebsd_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz",
"download_freebsd_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
"download_freebsd_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz",
"download_freebsd_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz"
}`
uu, err := getUpdateInfo([]byte(data))
if err != nil {
t.Fatalf("getUpdateInfo: %s", err)
}
u := updateInfo{
pkgURL: "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
pkgName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz",
newVer: newver,
updateDir: Context.workDir + "/agh-update-" + newver,
backupDir: Context.workDir + "/agh-backup",
configName: Context.workDir + "/AdGuardHome.yaml",
updateConfigName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/internal/AdGuardHome.yaml",
curBinName: Context.workDir + "/AdGuardHome",
bkpBinName: Context.workDir + "/agh-backup/AdGuardHome",
newBinName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/internal/AdGuardHome",
}
assert.Equal(t, uu.pkgURL, u.pkgURL)
assert.Equal(t, uu.pkgName, u.pkgName)
assert.Equal(t, uu.newVer, u.newVer)
assert.Equal(t, uu.updateDir, u.updateDir)
assert.Equal(t, uu.backupDir, u.backupDir)
assert.Equal(t, uu.configName, u.configName)
assert.Equal(t, uu.updateConfigName, u.updateConfigName)
assert.Equal(t, uu.curBinName, u.curBinName)
assert.Equal(t, uu.bkpBinName, u.bkpBinName)
assert.Equal(t, uu.newBinName, u.newBinName)
e := doUpdate(&u)
if e != nil {
t.Fatalf("FAILED: %s", e)
}
os.RemoveAll(u.backupDir)
}
func TestTargzFileUnpack(t *testing.T) {
fn := "../dist/AdGuardHome_linux_amd64.tar.gz"
outdir := "../test-unpack"
defer os.RemoveAll(outdir)
_ = os.Mkdir(outdir, 0755)
files, e := targzFileUnpack(fn, outdir)
if e != nil {
t.Fatalf("FAILED: %s", e)
}
t.Logf("%v", files)
}
func TestZipFileUnpack(t *testing.T) {
fn := "../dist/AdGuardHome_windows_amd64.zip"
outdir := "../test-unpack"
_ = os.Mkdir(outdir, 0755)
files, e := zipFileUnpack(fn, outdir)
if e != nil {
t.Fatalf("FAILED: %s", e)
}
t.Logf("%v", files)
os.RemoveAll(outdir)
}

388
internal/home/dns.go Normal file
View File

@@ -0,0 +1,388 @@
package home
import (
"fmt"
"net"
"path/filepath"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
// Called by other modules when configuration is changed
func onConfigModified() {
_ = config.write()
}
// initDNSServer creates an instance of the dnsforward.Server
// Please note that we must do it even if we don't start it
// so that we had access to the query log and the stats
func initDNSServer() error {
var err error
baseDir := Context.getDataDir()
statsConf := stats.Config{
Filename: filepath.Join(baseDir, "stats.db"),
LimitDays: config.DNS.StatsInterval,
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
}
Context.stats, err = stats.New(statsConf)
if err != nil {
return fmt.Errorf("couldn't initialize statistics module")
}
conf := querylog.Config{
Enabled: config.DNS.QueryLogEnabled,
FileEnabled: config.DNS.QueryLogFileEnabled,
BaseDir: baseDir,
Interval: config.DNS.QueryLogInterval,
MemSize: config.DNS.QueryLogMemSize,
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
}
Context.queryLog = querylog.New(conf)
filterConf := config.DNS.DnsfilterConf
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
bindhost = "127.0.0.1"
}
filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
filterConf.AutoHosts = &Context.autoHosts
filterConf.ConfigModified = onConfigModified
filterConf.HTTPRegister = httpRegister
Context.dnsFilter = dnsfilter.New(&filterConf, nil)
p := dnsforward.DNSCreateParams{
DNSFilter: Context.dnsFilter,
Stats: Context.stats,
QueryLog: Context.queryLog,
}
if Context.dhcpServer != nil {
p.DHCPServer = Context.dhcpServer
}
Context.dnsServer = dnsforward.NewServer(p)
Context.clients.dnsServer = Context.dnsServer
dnsConfig := generateServerConfig()
err = Context.dnsServer.Prepare(&dnsConfig)
if err != nil {
closeDNSServer()
return fmt.Errorf("dnsServer.Prepare: %s", err)
}
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
Context.whois = initWhois(&Context.clients)
Context.filters.Init()
return nil
}
func isRunning() bool {
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
}
// nolint (gocyclo)
// Return TRUE if IP is within public Internet IP range
func isPublicIP(ip net.IP) bool {
ip4 := ip.To4()
if ip4 != nil {
switch ip4[0] {
case 0:
return false //software
case 10:
return false //private network
case 127:
return false //loopback
case 169:
if ip4[1] == 254 {
return false //link-local
}
case 172:
if ip4[1] >= 16 && ip4[1] <= 31 {
return false //private network
}
case 192:
if (ip4[1] == 0 && ip4[2] == 0) || //private network
(ip4[1] == 0 && ip4[2] == 2) || //documentation
(ip4[1] == 88 && ip4[2] == 99) || //reserved
(ip4[1] == 168) { //private network
return false
}
case 198:
if (ip4[1] == 18 || ip4[2] == 19) || //private network
(ip4[1] == 51 || ip4[2] == 100) { //documentation
return false
}
case 203:
if ip4[1] == 0 && ip4[2] == 113 { //documentation
return false
}
case 224:
if ip4[1] == 0 && ip4[2] == 0 { //multicast
return false
}
case 255:
if ip4[1] == 255 && ip4[2] == 255 && ip4[3] == 255 { //subnet
return false
}
}
} else {
if ip.IsLoopback() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() {
return false
}
}
return true
}
func onDNSRequest(d *proxy.DNSContext) {
ip := dnsforward.GetIPString(d.Addr)
if ip == "" {
// This would be quite weird if we get here
return
}
ipAddr := net.ParseIP(ip)
if !ipAddr.IsLoopback() {
Context.rdns.Begin(ip)
}
if isPublicIP(ipAddr) {
Context.whois.Begin(ip)
}
}
func generateServerConfig() dnsforward.ServerConfig {
newconfig := dnsforward.ServerConfig{
UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
FilteringConfig: config.DNS.FilteringConfig,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
OnDNSRequest: onDNSRequest,
}
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
if tlsConf.Enabled {
newconfig.TLSConfig = tlsConf.TLSConfig
if tlsConf.PortDNSOverTLS != 0 {
newconfig.TLSListenAddr = &net.TCPAddr{
IP: net.ParseIP(config.DNS.BindHost),
Port: tlsConf.PortDNSOverTLS,
}
}
if tlsConf.PortDNSOverQUIC != 0 {
newconfig.QUICListenAddr = &net.UDPAddr{
IP: net.ParseIP(config.DNS.BindHost),
Port: int(tlsConf.PortDNSOverQUIC),
}
}
}
newconfig.TLSv12Roots = Context.tlsRoots
newconfig.TLSCiphers = Context.tlsCiphers
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
newconfig.FilterHandler = applyAdditionalFiltering
newconfig.GetCustomUpstreamByClient = Context.clients.FindUpstreams
return newconfig
}
type DNSEncryption struct {
https string
tls string
quic string
}
func getDNSEncryption() DNSEncryption {
dnsEncryption := DNSEncryption{}
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
if tlsConf.Enabled && len(tlsConf.ServerName) != 0 {
if tlsConf.PortHTTPS != 0 {
addr := tlsConf.ServerName
if tlsConf.PortHTTPS != 443 {
addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS)
}
addr = fmt.Sprintf("https://%s/dns-query", addr)
dnsEncryption.https = addr
}
if tlsConf.PortDNSOverTLS != 0 {
addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS)
dnsEncryption.tls = addr
}
if tlsConf.PortDNSOverQUIC != 0 {
addr := fmt.Sprintf("quic://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverQUIC)
dnsEncryption.quic = addr
}
}
return dnsEncryption
}
// Get the list of DNS addresses the server is listening on
func getDNSAddresses() []string {
dnsAddresses := []string{}
if config.DNS.BindHost == "0.0.0.0" {
ifaces, e := util.GetValidNetInterfacesForWeb()
if e != nil {
log.Error("Couldn't get network interfaces: %v", e)
return []string{}
}
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
addDNSAddress(&dnsAddresses, addr)
}
}
} else {
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
}
dnsEncryption := getDNSEncryption()
if dnsEncryption.https != "" {
dnsAddresses = append(dnsAddresses, dnsEncryption.https)
}
if dnsEncryption.tls != "" {
dnsAddresses = append(dnsAddresses, dnsEncryption.tls)
}
if dnsEncryption.quic != "" {
dnsAddresses = append(dnsAddresses, dnsEncryption.quic)
}
return dnsAddresses
}
// If a client has his own settings, apply them
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
if len(clientAddr) == 0 {
return
}
setts.ClientIP = clientAddr
c, ok := Context.clients.Find(clientAddr)
if !ok {
return
}
log.Debug("Using settings for client %s with IP %s", c.Name, clientAddr)
if c.UseOwnBlockedServices {
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
}
setts.ClientName = c.Name
setts.ClientTags = c.Tags
if !c.UseOwnSettings {
return
}
setts.FilteringEnabled = c.FilteringEnabled
setts.SafeSearchEnabled = c.SafeSearchEnabled
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
setts.ParentalEnabled = c.ParentalEnabled
}
func startDNSServer() error {
if isRunning() {
return fmt.Errorf("unable to start forwarding DNS server: Already running")
}
enableFilters(false)
Context.clients.Start()
err := Context.dnsServer.Start()
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
Context.dnsFilter.Start()
Context.filters.Start()
Context.stats.Start()
Context.queryLog.Start()
const topClientsNumber = 100 // the number of clients to get
topClients := Context.stats.GetTopClientsIP(topClientsNumber)
for _, ip := range topClients {
ipAddr := net.ParseIP(ip)
if !ipAddr.IsLoopback() {
Context.rdns.Begin(ip)
}
if isPublicIP(ipAddr) {
Context.whois.Begin(ip)
}
}
return nil
}
func reconfigureDNSServer() error {
newconfig := generateServerConfig()
err := Context.dnsServer.Reconfigure(&newconfig)
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
return nil
}
func stopDNSServer() error {
if !isRunning() {
return nil
}
err := Context.dnsServer.Stop()
if err != nil {
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
}
closeDNSServer()
return nil
}
func closeDNSServer() {
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
if Context.dnsServer != nil {
Context.dnsServer.Close()
Context.dnsServer = nil
}
if Context.dnsFilter != nil {
Context.dnsFilter.Close()
Context.dnsFilter = nil
}
if Context.stats != nil {
Context.stats.Close()
Context.stats = nil
}
if Context.queryLog != nil {
Context.queryLog.Close()
Context.queryLog = nil
}
Context.filters.Close()
log.Debug("Closed all DNS modules")
}

713
internal/home/filter.go Normal file
View File

@@ -0,0 +1,713 @@
package home
import (
"bufio"
"fmt"
"hash/crc32"
"io"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
)
var (
nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
)
// Filtering - module object
type Filtering struct {
// conf FilteringConf
refreshStatus uint32 // 0:none; 1:in progress
refreshLock sync.Mutex
filterTitleRegexp *regexp.Regexp
}
// Init - initialize the module
func (f *Filtering) Init() {
f.filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
_ = os.MkdirAll(filepath.Join(Context.getDataDir(), filterDir), 0755)
f.loadFilters(config.Filters)
f.loadFilters(config.WhitelistFilters)
deduplicateFilters()
updateUniqueFilterID(config.Filters)
updateUniqueFilterID(config.WhitelistFilters)
}
// Start - start the module
func (f *Filtering) Start() {
f.RegisterFilteringHandlers()
// Here we should start updating filters,
// but currently we can't wake up the periodic task to do so.
// So for now we just start this periodic task from here.
go f.periodicallyRefreshFilters()
}
// Close - close the module
func (f *Filtering) Close() {
}
func defaultFilters() []filter {
return []filter{
{Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard DNS filter"},
{Filter: dnsfilter.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway Default Blocklist"},
{Filter: dnsfilter.Filter{ID: 4}, Enabled: false, URL: "https://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"},
}
}
// field ordering is important -- yaml fields will mirror ordering from here
type filter struct {
Enabled bool
URL string // URL or a file path
Name string `yaml:"name"`
RulesCount int `yaml:"-"`
LastUpdated time.Time `yaml:"-"`
checksum uint32 // checksum of the file data
white bool
dnsfilter.Filter `yaml:",inline"`
}
// Creates a helper object for working with the user rules
func userFilter() filter {
f := filter{
// User filter always has constant ID=0
Enabled: true,
}
f.Filter.Data = []byte(strings.Join(config.UserRules, "\n"))
return f
}
const (
statusFound = 1
statusEnabledChanged = 2
statusURLChanged = 4
statusURLExists = 8
statusUpdateRequired = 0x10
)
// Update properties for a filter specified by its URL
// Return status* flags.
func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool) int {
r := 0
config.Lock()
defer config.Unlock()
filters := &config.Filters
if whitelist {
filters = &config.WhitelistFilters
}
for i := range *filters {
filt := &(*filters)[i]
if filt.URL != url {
continue
}
log.Debug("filter: set properties: %s: {%s %s %v}",
filt.URL, newf.Name, newf.URL, newf.Enabled)
filt.Name = newf.Name
if filt.URL != newf.URL {
r |= statusURLChanged | statusUpdateRequired
if filterExistsNoLock(newf.URL) {
return statusURLExists
}
filt.URL = newf.URL
filt.unload()
filt.LastUpdated = time.Time{}
filt.checksum = 0
filt.RulesCount = 0
}
if filt.Enabled != newf.Enabled {
r |= statusEnabledChanged
filt.Enabled = newf.Enabled
if filt.Enabled {
if (r & statusURLChanged) == 0 {
e := f.load(filt)
if e != nil {
// This isn't a fatal error,
// because it may occur when someone removes the file from disk.
filt.LastUpdated = time.Time{}
filt.checksum = 0
filt.RulesCount = 0
r |= statusUpdateRequired
}
}
} else {
filt.unload()
}
}
return r | statusFound
}
return 0
}
// Return TRUE if a filter with this URL exists
func filterExists(url string) bool {
config.RLock()
r := filterExistsNoLock(url)
config.RUnlock()
return r
}
func filterExistsNoLock(url string) bool {
for _, f := range config.Filters {
if f.URL == url {
return true
}
}
for _, f := range config.WhitelistFilters {
if f.URL == url {
return true
}
}
return false
}
// Add a filter
// Return FALSE if a filter with this URL exists
func filterAdd(f filter) bool {
config.Lock()
defer config.Unlock()
// Check for duplicates
if filterExistsNoLock(f.URL) {
return false
}
if f.white {
config.WhitelistFilters = append(config.WhitelistFilters, f)
} else {
config.Filters = append(config.Filters, f)
}
return true
}
// Load filters from the disk
// And if any filter has zero ID, assign a new one
func (f *Filtering) loadFilters(array []filter) {
for i := range array {
filter := &array[i] // otherwise we're operating on a copy
if filter.ID == 0 {
filter.ID = assignUniqueFilterID()
}
if !filter.Enabled {
// No need to load a filter that is not enabled
continue
}
err := f.load(filter)
if err != nil {
log.Error("Couldn't load filter %d contents due to %s", filter.ID, err)
}
}
}
func deduplicateFilters() {
// Deduplicate filters
i := 0 // output index, used for deletion later
urls := map[string]bool{}
for _, filter := range config.Filters {
if _, ok := urls[filter.URL]; !ok {
// we didn't see it before, keep it
urls[filter.URL] = true // remember the URL
config.Filters[i] = filter
i++
}
}
// all entries we want to keep are at front, delete the rest
config.Filters = config.Filters[:i]
}
// Set the next filter ID to max(filter.ID) + 1
func updateUniqueFilterID(filters []filter) {
for _, filter := range filters {
if nextFilterID < filter.ID {
nextFilterID = filter.ID + 1
}
}
}
func assignUniqueFilterID() int64 {
value := nextFilterID
nextFilterID++
return value
}
// Sets up a timer that will be checking for filters updates periodically
func (f *Filtering) periodicallyRefreshFilters() {
const maxInterval = 1 * 60 * 60
intval := 5 // use a dynamically increasing time interval
for {
isNetworkErr := false
if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) {
f.refreshLock.Lock()
_, isNetworkErr = f.refreshFiltersIfNecessary(FilterRefreshBlocklists | FilterRefreshAllowlists)
f.refreshLock.Unlock()
f.refreshStatus = 0
if !isNetworkErr {
intval = maxInterval
}
}
if isNetworkErr {
intval *= 2
if intval > maxInterval {
intval = maxInterval
}
}
time.Sleep(time.Duration(intval) * time.Second)
}
}
// Refresh filters
// flags: FilterRefresh*
// important:
// TRUE: ignore the fact that we're currently updating the filters
func (f *Filtering) refreshFilters(flags int, important bool) (int, error) {
set := atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1)
if !important && !set {
return 0, fmt.Errorf("filters update procedure is already running")
}
f.refreshLock.Lock()
nUpdated, _ := f.refreshFiltersIfNecessary(flags)
f.refreshLock.Unlock()
f.refreshStatus = 0
return nUpdated, nil
}
func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) {
var updateFilters []filter
var updateFlags []bool // 'true' if filter data has changed
now := time.Now()
config.RLock()
for i := range *filters {
f := &(*filters)[i] // otherwise we will be operating on a copy
if !f.Enabled {
continue
}
expireTime := f.LastUpdated.Unix() + int64(config.DNS.FiltersUpdateIntervalHours)*60*60
if !force && expireTime > now.Unix() {
continue
}
var uf filter
uf.ID = f.ID
uf.URL = f.URL
uf.Name = f.Name
uf.checksum = f.checksum
updateFilters = append(updateFilters, uf)
}
config.RUnlock()
if len(updateFilters) == 0 {
return 0, nil, nil, false
}
nfail := 0
for i := range updateFilters {
uf := &updateFilters[i]
updated, err := f.update(uf)
updateFlags = append(updateFlags, updated)
if err != nil {
nfail++
log.Printf("Failed to update filter %s: %s\n", uf.URL, err)
continue
}
}
if nfail == len(updateFilters) {
return 0, nil, nil, true
}
updateCount := 0
for i := range updateFilters {
uf := &updateFilters[i]
updated := updateFlags[i]
config.Lock()
for k := range *filters {
f := &(*filters)[k]
if f.ID != uf.ID || f.URL != uf.URL {
continue
}
f.LastUpdated = uf.LastUpdated
if !updated {
continue
}
log.Info("Updated filter #%d. Rules: %d -> %d",
f.ID, f.RulesCount, uf.RulesCount)
f.Name = uf.Name
f.RulesCount = uf.RulesCount
f.checksum = uf.checksum
updateCount++
}
config.Unlock()
}
return updateCount, updateFilters, updateFlags, false
}
const (
FilterRefreshForce = 1 // ignore last file modification date
FilterRefreshAllowlists = 2 // update allow-lists
FilterRefreshBlocklists = 4 // update block-lists
)
// Checks filters updates if necessary
// If force is true, it ignores the filter.LastUpdated field value
// flags: FilterRefresh*
//
// Algorithm:
// . Get the list of filters to be updated
// . For each filter run the download and checksum check operation
// . Store downloaded data in a temporary file inside data/filters directory
// . For each filter:
// . If filter data hasn't changed, just set new update time on file
// . If filter data has changed:
// . rename the temporary file (<temp> -> 1.txt)
// Note that this method works only on UNIX.
// On Windows we don't pass files to dnsfilter - we pass the whole data.
// . Pass new filters to dnsfilter object - it analyzes new data while the old filters are still active
// . dnsfilter activates new filters
//
// Return the number of updated filters
// Return TRUE - there was a network error and nothing could be updated
func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) {
log.Debug("Filters: updating...")
updateCount := 0
var updateFilters []filter
var updateFlags []bool
netError := false
netErrorW := false
force := false
if (flags & FilterRefreshForce) != 0 {
force = true
}
if (flags & FilterRefreshBlocklists) != 0 {
updateCount, updateFilters, updateFlags, netError = f.refreshFiltersArray(&config.Filters, force)
}
if (flags & FilterRefreshAllowlists) != 0 {
updateCountW := 0
var updateFiltersW []filter
var updateFlagsW []bool
updateCountW, updateFiltersW, updateFlagsW, netErrorW = f.refreshFiltersArray(&config.WhitelistFilters, force)
updateCount += updateCountW
updateFilters = append(updateFilters, updateFiltersW...)
updateFlags = append(updateFlags, updateFlagsW...)
}
if netError && netErrorW {
return 0, true
}
if updateCount != 0 {
enableFilters(false)
for i := range updateFilters {
uf := &updateFilters[i]
updated := updateFlags[i]
if !updated {
continue
}
_ = os.Remove(uf.Path() + ".old")
}
}
log.Debug("Filters: update finished")
return updateCount, false
}
// Allows printable UTF-8 text with CR, LF, TAB characters
func isPrintableText(data []byte, len int) bool {
for i := 0; i < len; i++ {
c := data[i]
if (c >= ' ' && c != 0x7f) || c == '\n' || c == '\r' || c == '\t' {
continue
}
return false
}
return true
}
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
rulesCount := 0
name := ""
seenTitle := false
r := bufio.NewReader(file)
checksum := uint32(0)
for {
line, err := r.ReadString('\n')
checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line))
line = strings.TrimSpace(line)
if len(line) == 0 {
//
} else if line[0] == '!' {
m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1)
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
name = m[0][1]
seenTitle = true
}
} else if line[0] == '#' {
//
} else {
rulesCount++
}
if err != nil {
break
}
}
return rulesCount, checksum, name
}
// Perform upgrade on a filter and update LastUpdated value
func (f *Filtering) update(filter *filter) (bool, error) {
b, err := f.updateIntl(filter)
filter.LastUpdated = time.Now()
if !b {
e := os.Chtimes(filter.Path(), filter.LastUpdated, filter.LastUpdated)
if e != nil {
log.Error("os.Chtimes(): %v", e)
}
}
return b, err
}
// nolint(gocyclo)
func (f *Filtering) updateIntl(filter *filter) (bool, error) {
log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
tmpFile, err := ioutil.TempFile(filepath.Join(Context.getDataDir(), filterDir), "")
if err != nil {
return false, err
}
defer func() {
if tmpFile != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpFile.Name())
}
}()
var reader io.Reader
if filepath.IsAbs(filter.URL) {
f, err := os.Open(filter.URL)
if err != nil {
return false, fmt.Errorf("open file: %s", err)
}
defer f.Close()
reader = f
} else {
resp, err := Context.client.Get(filter.URL)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil {
log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err)
return false, err
}
if resp.StatusCode != 200 {
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
}
reader = resp.Body
}
htmlTest := true
firstChunk := make([]byte, 4*1024)
firstChunkLen := 0
buf := make([]byte, 64*1024)
total := 0
for {
n, err := reader.Read(buf)
total += n
if htmlTest {
// gather full buffer firstChunk and perform its data tests
num := util.MinInt(n, len(firstChunk)-firstChunkLen)
copied := copy(firstChunk[firstChunkLen:], buf[:num])
firstChunkLen += copied
if firstChunkLen == len(firstChunk) || err == io.EOF {
if !isPrintableText(firstChunk, firstChunkLen) {
return false, fmt.Errorf("data contains non-printable characters")
}
s := strings.ToLower(string(firstChunk))
if strings.Index(s, "<html") >= 0 ||
strings.Index(s, "<!doctype") >= 0 {
return false, fmt.Errorf("data is HTML, not plain text")
}
htmlTest = false
firstChunk = nil
}
}
_, err2 := tmpFile.Write(buf[:n])
if err2 != nil {
return false, err2
}
if err == io.EOF {
break
}
if err != nil {
log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
return false, err
}
}
// Extract filter name and count number of rules
_, _ = tmpFile.Seek(0, io.SeekStart)
rulesCount, checksum, filterName := f.parseFilterContents(tmpFile)
// Check if the filter has been really changed
if filter.checksum == checksum {
log.Tracef("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL)
return false, nil
}
log.Printf("Filter %d has been updated: %d bytes, %d rules",
filter.ID, total, rulesCount)
if len(filter.Name) == 0 {
filter.Name = filterName
}
filter.RulesCount = rulesCount
filter.checksum = checksum
filterFilePath := filter.Path()
log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
// Closing the file before renaming it is necessary on Windows
_ = tmpFile.Close()
err = os.Rename(tmpFile.Name(), filterFilePath)
if err != nil {
return false, err
}
tmpFile = nil
return true, nil
}
// loads filter contents from the file in dataDir
func (f *Filtering) load(filter *filter) error {
filterFilePath := filter.Path()
log.Tracef("Loading filter %d contents to: %s", filter.ID, filterFilePath)
if _, err := os.Stat(filterFilePath); os.IsNotExist(err) {
// do nothing, file doesn't exist
return err
}
file, err := os.Open(filterFilePath)
if err != nil {
return err
}
defer file.Close()
st, _ := file.Stat()
log.Tracef("File %s, id %d, length %d",
filterFilePath, filter.ID, st.Size())
rulesCount, checksum, _ := f.parseFilterContents(file)
filter.RulesCount = rulesCount
filter.checksum = checksum
filter.LastUpdated = filter.LastTimeUpdated()
return nil
}
// Clear filter rules
func (filter *filter) unload() {
filter.RulesCount = 0
filter.checksum = 0
}
// Path to the filter contents
func (filter *filter) Path() string {
return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
}
// LastTimeUpdated returns the time when the filter was last time updated
func (filter *filter) LastTimeUpdated() time.Time {
filterFilePath := filter.Path()
s, err := os.Stat(filterFilePath)
if os.IsNotExist(err) {
// if the filter file does not exist, return 0001-01-01
return time.Time{}
}
if err != nil {
// if the filter file does not exist, return 0001-01-01
return time.Time{}
}
// filter file modified time
return s.ModTime()
}
func enableFilters(async bool) {
var filters []dnsfilter.Filter
var whiteFilters []dnsfilter.Filter
if config.DNS.FilteringEnabled {
// convert array of filters
userFilter := userFilter()
f := dnsfilter.Filter{
ID: userFilter.ID,
Data: userFilter.Data,
}
filters = append(filters, f)
for _, filter := range config.Filters {
if !filter.Enabled {
continue
}
f := dnsfilter.Filter{
ID: filter.ID,
FilePath: filter.Path(),
}
filters = append(filters, f)
}
for _, filter := range config.WhitelistFilters {
if !filter.Enabled {
continue
}
f := dnsfilter.Filter{
ID: filter.ID,
FilePath: filter.Path(),
}
whiteFilters = append(whiteFilters, f)
}
}
_ = Context.dnsFilter.SetFilters(filters, whiteFilters, async)
}

View File

@@ -0,0 +1,65 @@
package home
import (
"fmt"
"net"
"net/http"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func testStartFilterListener() net.Listener {
http.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
content := `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
`
_, _ = w.Write([]byte(content))
})
listener, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
go func() { _ = http.Serve(listener, nil) }()
return listener
}
func TestFilters(t *testing.T) {
l := testStartFilterListener()
defer func() { _ = l.Close() }()
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
Context = homeContext{}
Context.workDir = dir
Context.client = &http.Client{
Timeout: 5 * time.Second,
}
Context.filters.Init()
f := filter{
URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port),
}
// download
ok, err := Context.filters.update(&f)
assert.Equal(t, nil, err)
assert.True(t, ok)
assert.Equal(t, 3, f.RulesCount)
// refresh
ok, err = Context.filters.update(&f)
assert.True(t, !ok && err == nil)
err = Context.filters.load(&f)
assert.True(t, err == nil)
f.unload()
_ = os.Remove(f.Path())
}

668
internal/home/home.go Normal file
View File

@@ -0,0 +1,668 @@
// Package home contains AdGuard Home's HTTP API methods.
package home
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
"runtime"
"strconv"
"sync"
"syscall"
"time"
"gopkg.in/natefinch/lumberjack.v2"
"github.com/AdguardTeam/AdGuardHome/internal/update"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/joomcode/errorx"
"github.com/AdguardTeam/AdGuardHome/internal/isdelve"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/golibs/log"
)
const (
// Used in config to indicate that syslog or eventlog (win) should be used for logger output
configSyslog = "syslog"
)
// Update-related variables
var (
versionString = "dev"
updateChannel = "none"
versionCheckURL = ""
ARMVersion = ""
)
// Global context
type homeContext struct {
// Modules
// --
clients clientsContainer // per-client-settings module
stats stats.Stats // statistics module
queryLog querylog.QueryLog // query log module
dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module
whois *Whois // WHOIS module
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
dhcpServer *dhcpd.Server // DHCP module
auth *Auth // HTTP authentication module
filters Filtering // DNS filtering module
web *Web // Web (HTTP, HTTPS) module
tls *TLSMod // TLS module
autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files
updater *update.Updater
// Runtime properties
// --
configFilename string // Config filename (can be overridden via the command line arguments)
workDir string // Location of our directory, used to protect against CWD being somewhere else
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
pidFileName string // PID file name. Empty if no PID file was created.
disableUpdate bool // If set, don't check for updates
controlLock sync.Mutex
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
tlsCiphers []uint16 // list of TLS ciphers to use
transport *http.Transport
client *http.Client
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
}
// getDataDir returns path to the directory where we store databases and filters
func (c *homeContext) getDataDir() string {
return filepath.Join(c.workDir, dataDir)
}
// Context - a global context object
var Context homeContext
// Main is the entry point
func Main(version string, channel string, armVer string) {
// Init update-related global variables
versionString = version
updateChannel = channel
ARMVersion = armVer
versionCheckURL = "https://static.adguard.com/adguardhome/" + updateChannel + "/version.json"
// config can be specified, which reads options from there, but other command line flags have to override config values
// therefore, we must do it manually instead of using a lib
args := loadOptions()
Context.appSignalChannel = make(chan os.Signal)
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
for {
sig := <-Context.appSignalChannel
log.Info("Received signal '%s'", sig)
switch sig {
case syscall.SIGHUP:
Context.clients.Reload()
Context.tls.Reload()
default:
cleanup()
cleanupAlways()
os.Exit(0)
}
}
}()
if args.serviceControlAction != "" {
handleServiceControlAction(args)
return
}
// run the protection
run(args)
}
// version - returns the current version string
func version() string {
msg := "AdGuard Home, version %s, channel %s, arch %s %s"
if ARMVersion != "" {
msg = msg + " v" + ARMVersion
}
return fmt.Sprintf(msg, versionString, updateChannel, runtime.GOOS, runtime.GOARCH)
}
// run initializes configuration and runs the AdGuard Home
// run is a blocking method!
// nolint
func run(args options) {
// configure config filename
initConfigFilename(args)
// configure working dir and config path
initWorkingDir(args)
// configure log level and output
configureLogger(args)
// Go memory hacks
memoryUsage(args)
// print the first message after logger is configured
log.Println(version())
log.Debug("Current working directory is %s", Context.workDir)
if args.runningAsService {
log.Info("AdGuard Home is running as a service")
}
Context.runningAsService = args.runningAsService
Context.disableUpdate = args.disableUpdate
Context.firstRun = detectFirstRun()
if Context.firstRun {
log.Info("This is the first time AdGuard Home is launched")
checkPermissions()
}
initConfig()
Context.tlsRoots = util.LoadSystemRootCAs()
Context.tlsCiphers = util.InitTLSCiphers()
Context.transport = &http.Transport{
DialContext: customDialContext,
Proxy: getHTTPProxy,
TLSClientConfig: &tls.Config{
RootCAs: Context.tlsRoots,
},
}
Context.client = &http.Client{
Timeout: time.Minute * 5,
Transport: Context.transport,
}
if !Context.firstRun {
// Do the upgrade if necessary
err := upgradeConfig()
if err != nil {
log.Fatal(err)
}
err = parseConfig()
if err != nil {
log.Error("Failed to parse configuration, exiting")
os.Exit(1)
}
if args.checkConfig {
log.Info("Configuration file is OK")
os.Exit(0)
}
}
// 'clients' module uses 'dnsfilter' module's static data (dnsfilter.BlockedSvcKnown()),
// so we have to initialize dnsfilter's static data first,
// but also avoid relying on automatic Go init() function
dnsfilter.InitModule()
config.DHCP.WorkDir = Context.workDir
config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified
if runtime.GOOS != "windows" {
Context.dhcpServer = dhcpd.Create(config.DHCP)
if Context.dhcpServer == nil {
log.Fatalf("Can't initialize DHCP module")
}
}
Context.autoHosts.Init("")
Context.updater = update.NewUpdater(update.Config{
Client: Context.client,
WorkDir: Context.workDir,
VersionURL: versionCheckURL,
VersionString: versionString,
OS: runtime.GOOS,
Arch: runtime.GOARCH,
ARMVersion: ARMVersion,
ConfigName: config.getConfigFilename(),
})
Context.clients.Init(config.Clients, Context.dhcpServer, &Context.autoHosts)
config.Clients = nil
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
config.RlimitNoFile != 0 {
util.SetRlimit(config.RlimitNoFile)
}
// override bind host/port from the console
if args.bindHost != "" {
config.BindHost = args.bindHost
}
if args.bindPort != 0 {
config.BindPort = args.bindPort
}
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
Context.pidFileName = args.pidFile
}
if !Context.firstRun {
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
if config.DebugPProf {
mux := http.NewServeMux()
util.PProfRegisterWebHandlers(mux)
go func() {
log.Info("pprof: listening on localhost:6060")
err := http.ListenAndServe("localhost:6060", mux)
log.Error("Error while running the pprof server: %s", err)
}()
}
}
err := os.MkdirAll(Context.getDataDir(), 0755)
if err != nil {
log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err)
}
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
GLMode = args.glinetMode
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
if Context.auth == nil {
log.Fatalf("Couldn't initialize Auth module")
}
config.Users = nil
Context.tls = tlsCreate(config.TLS)
if Context.tls == nil {
log.Fatalf("Can't initialize TLS module")
}
webConf := WebConfig{
firstRun: Context.firstRun,
BindHost: config.BindHost,
BindPort: config.BindPort,
}
Context.web = CreateWeb(&webConf)
if Context.web == nil {
log.Fatalf("Can't initialize Web module")
}
if !Context.firstRun {
err := initDNSServer()
if err != nil {
log.Fatalf("%s", err)
}
Context.tls.Start()
Context.autoHosts.Start()
go func() {
err := startDNSServer()
if err != nil {
log.Fatal(err)
}
}()
if Context.dhcpServer != nil {
_ = Context.dhcpServer.Start()
}
}
Context.web.Start()
// wait indefinitely for other go-routines to complete their job
select {}
}
// StartMods - initialize and start DNS after installation
func StartMods() error {
err := initDNSServer()
if err != nil {
return err
}
Context.tls.Start()
err = startDNSServer()
if err != nil {
closeDNSServer()
return err
}
return nil
}
// Check if the current user permissions are enough to run AdGuard Home
func checkPermissions() {
log.Info("Checking if AdGuard Home has necessary permissions")
if runtime.GOOS == "windows" {
// On Windows we need to have admin rights to run properly
admin, _ := util.HaveAdminRights()
if //noinspection ALL
admin || isdelve.Enabled {
// Don't forget that for this to work you need to add "delve" tag explicitly
// https://stackoverflow.com/questions/47879070/how-can-i-see-if-the-goland-debugger-is-running-in-the-program
return
}
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
}
// We should check if AdGuard Home is able to bind to port 53
ok, err := util.CanBindPort(53)
if ok {
log.Info("AdGuard Home can bind to port 53")
return
}
if opErr, ok := err.(*net.OpError); ok {
if sysErr, ok := opErr.Err.(*os.SyscallError); ok {
if errno, ok := sysErr.Err.(syscall.Errno); ok && errno == syscall.EACCES {
msg := `Permission check failed.
AdGuard Home is not allowed to bind to privileged ports (for instance, port 53).
Please note, that this is crucial for a server to be able to use privileged ports.
You have two options:
1. Run AdGuard Home with root privileges
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
https://github.com/AdguardTeam/AdGuardHome/internal/wiki/Getting-Started#running-without-superuser`
log.Fatal(msg)
}
}
}
msg := fmt.Sprintf(`AdGuard failed to bind to port 53 due to %v
Please note, that this is crucial for a DNS server to be able to use that port.`, err)
log.Info(msg)
}
// Write PID to a file
func writePIDFile(fn string) bool {
data := fmt.Sprintf("%d", os.Getpid())
err := ioutil.WriteFile(fn, []byte(data), 0644)
if err != nil {
log.Error("Couldn't write PID to file %s: %v", fn, err)
return false
}
return true
}
func initConfigFilename(args options) {
// config file path can be overridden by command-line arguments:
if args.configFilename != "" {
Context.configFilename = args.configFilename
} else {
// Default config file name
Context.configFilename = "AdGuardHome.yaml"
}
}
// initWorkingDir initializes the workDir
// if no command-line arguments specified, we use the directory where our binary file is located
func initWorkingDir(args options) {
execPath, err := os.Executable()
if err != nil {
panic(err)
}
if args.workDir != "" {
// If there is a custom config file, use it's directory as our working dir
Context.workDir = args.workDir
} else {
Context.workDir = filepath.Dir(execPath)
}
}
// configureLogger configures logger level and output
func configureLogger(args options) {
ls := getLogSettings()
// command-line arguments can override config settings
if args.verbose || config.Verbose {
ls.Verbose = true
}
if args.logFile != "" {
ls.LogFile = args.logFile
} else if config.LogFile != "" {
ls.LogFile = config.LogFile
}
// Handle default log settings overrides
ls.LogCompress = config.LogCompress
ls.LogLocalTime = config.LogLocalTime
ls.LogMaxBackups = config.LogMaxBackups
ls.LogMaxSize = config.LogMaxSize
ls.LogMaxAge = config.LogMaxAge
// log.SetLevel(log.INFO) - default
if ls.Verbose {
log.SetLevel(log.DEBUG)
}
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if nothing else is configured
// Otherwise, we'll simply loose the log output
ls.LogFile = configSyslog
}
// logs are written to stdout (default)
if ls.LogFile == "" {
return
}
if ls.LogFile == configSyslog {
// Use syslog where it is possible and eventlog on Windows
err := util.ConfigureSyslog(serviceName)
if err != nil {
log.Fatalf("cannot initialize syslog: %s", err)
}
} else {
logFilePath := filepath.Join(Context.workDir, ls.LogFile)
if filepath.IsAbs(ls.LogFile) {
logFilePath = ls.LogFile
}
_, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil {
log.Fatalf("cannot create a log file: %s", err)
}
log.SetOutput(&lumberjack.Logger{
Filename: logFilePath,
Compress: ls.LogCompress, // disabled by default
LocalTime: ls.LogLocalTime,
MaxBackups: ls.LogMaxBackups,
MaxSize: ls.LogMaxSize, // megabytes
MaxAge: ls.LogMaxAge, //days
})
}
}
func cleanup() {
log.Info("Stopping AdGuard Home")
if Context.web != nil {
Context.web.Close()
Context.web = nil
}
if Context.auth != nil {
Context.auth.Close()
Context.auth = nil
}
err := stopDNSServer()
if err != nil {
log.Error("Couldn't stop DNS server: %s", err)
}
if Context.dhcpServer != nil {
Context.dhcpServer.Stop()
}
Context.autoHosts.Close()
if Context.tls != nil {
Context.tls.Close()
Context.tls = nil
}
}
// This function is called before application exits
func cleanupAlways() {
if len(Context.pidFileName) != 0 {
_ = os.Remove(Context.pidFileName)
}
log.Info("Stopped")
}
func exitWithError() {
os.Exit(64)
}
// loadOptions reads command line arguments and initializes configuration
func loadOptions() options {
o, f, err := parse(os.Args[0], os.Args[1:])
if err != nil {
log.Error(err.Error())
_ = printHelp(os.Args[0])
exitWithError()
} else if f != nil {
err = f()
if err != nil {
log.Error(err.Error())
exitWithError()
} else {
os.Exit(0)
}
}
return o
}
// prints IP addresses which user can use to open the admin interface
// proto is either "http" or "https"
func printHTTPAddresses(proto string) {
var address string
tlsConf := tlsConfigSettings{}
if Context.tls != nil {
Context.tls.WriteDiskConfig(&tlsConf)
}
port := strconv.Itoa(config.BindPort)
if proto == "https" {
port = strconv.Itoa(tlsConf.PortHTTPS)
}
if proto == "https" && tlsConf.ServerName != "" {
if tlsConf.PortHTTPS == 443 {
log.Printf("Go to https://%s", tlsConf.ServerName)
} else {
log.Printf("Go to https://%s:%s", tlsConf.ServerName, port)
}
} else if config.BindHost == "0.0.0.0" {
log.Println("AdGuard Home is available on the following addresses:")
ifaces, err := util.GetValidNetInterfacesForWeb()
if err != nil {
// That's weird, but we'll ignore it
address = net.JoinHostPort(config.BindHost, port)
log.Printf("Go to %s://%s", proto, address)
return
}
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
address = net.JoinHostPort(addr, strconv.Itoa(config.BindPort))
log.Printf("Go to %s://%s", proto, address)
}
}
} else {
address = net.JoinHostPort(config.BindHost, port)
log.Printf("Go to %s://%s", proto, address)
}
}
// -------------------
// first run / install
// -------------------
func detectFirstRun() bool {
configfile := Context.configFilename
if !filepath.IsAbs(configfile) {
configfile = filepath.Join(Context.workDir, Context.configFilename)
}
_, err := os.Stat(configfile)
if !os.IsNotExist(err) {
// do nothing, file exists
return false
}
return true
}
// Connect to a remote server resolving hostname using our own DNS server
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
log.Tracef("network:%v addr:%v", network, addr)
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
dialer := &net.Dialer{
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
con, err := dialer.DialContext(ctx, network, addr)
return con, err
}
addrs, e := Context.dnsServer.Resolve(host)
log.Debug("dnsServer.Resolve: %s: %v", host, addrs)
if e != nil {
return nil, e
}
if len(addrs) == 0 {
return nil, fmt.Errorf("couldn't lookup host: %s", host)
}
var dialErrs []error
for _, a := range addrs {
addr = net.JoinHostPort(a.String(), port)
con, err := dialer.DialContext(ctx, network, addr)
if err != nil {
dialErrs = append(dialErrs, err)
continue
}
return con, err
}
return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
}
func getHTTPProxy(req *http.Request) (*url.URL, error) {
if len(config.ProxyURL) == 0 {
return nil, nil
}
return url.Parse(config.ProxyURL)
}

189
internal/home/home_test.go Normal file
View File

@@ -0,0 +1,189 @@
// +build !race
package home
import (
"context"
"encoding/base64"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
const yamlConf = `bind_host: 127.0.0.1
bind_port: 3000
users: []
language: en
rlimit_nofile: 0
web_session_ttl: 720
dns:
bind_host: 127.0.0.1
port: 5354
statistics_interval: 90
querylog_enabled: true
querylog_interval: 90
querylog_size_memory: 0
protection_enabled: true
blocking_mode: null_ip
blocked_response_ttl: 0
ratelimit: 100
ratelimit_whitelist: []
refuse_any: false
bootstrap_dns:
- 1.1.1.1:53
all_servers: false
allowed_clients: []
disallowed_clients: []
blocked_hosts: []
parental_block_host: family-block.dns.adguard.com
safebrowsing_block_host: standard-block.dns.adguard.com
cache_size: 0
upstream_dns:
- https://1.1.1.1/dns-query
filtering_enabled: true
filters_update_interval: 168
parental_sensitivity: 13
parental_enabled: true
safesearch_enabled: false
safebrowsing_enabled: false
safebrowsing_cache_size: 1048576
safesearch_cache_size: 1048576
parental_cache_size: 1048576
cache_time: 30
rewrites: []
blocked_services: []
tls:
enabled: false
server_name: www.example.com
force_https: false
port_https: 443
port_dns_over_tls: 853
allow_unencrypted_doh: true
certificate_chain: ""
private_key: ""
certificate_path: ""
private_key_path: ""
filters:
- enabled: true
url: https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt
name: AdGuard Simplified Domain Names filter
id: 1
- enabled: false
url: https://hosts-file.net/ad_servers.txt
name: hpHosts - Ad and Tracking servers only
id: 2
- enabled: false
url: https://adaway.org/hosts.txt
name: adaway
id: 3
user_rules:
- ""
dhcp:
enabled: false
interface_name: ""
gateway_ip: ""
subnet_mask: ""
range_start: ""
range_end: ""
lease_duration: 86400
icmp_timeout_msec: 1000
clients: []
log_file: ""
verbose: false
schema_version: 5
`
// . Create a configuration file
// . Start AGH instance
// . Check Web server
// . Check DNS server
// . Check DNS server with DOH
// . Wait until the filters are downloaded
// . Stop and cleanup
func TestHome(t *testing.T) {
// Init new context
Context = homeContext{}
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
fn := filepath.Join(dir, "AdGuardHome.yaml")
// Prepare the test config
assert.True(t, ioutil.WriteFile(fn, []byte(yamlConf), 0644) == nil)
fn, _ = filepath.Abs(fn)
config = configuration{} // the global variable is dirty because of the previous tests run
args := options{}
args.configFilename = fn
args.workDir = dir
go run(args)
var err error
var resp *http.Response
h := http.Client{}
for i := 0; i != 50; i++ {
resp, err = h.Get("http://127.0.0.1:3000/")
if err == nil && resp.StatusCode != 404 {
break
}
time.Sleep(100 * time.Millisecond)
}
assert.Truef(t, err == nil, "%s", err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
resp, err = h.Get("http://127.0.0.1:3000/control/status")
assert.Truef(t, err == nil, "%s", err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// test DNS over UDP
r, err := upstream.NewResolver("127.0.0.1:5354", 3*time.Second)
assert.Nil(t, err)
addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com")
assert.Nil(t, err)
haveIP := len(addrs) != 0
assert.True(t, haveIP)
// test DNS over HTTP without encryption
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{{Name: "static.adguard.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}
buf, err := req.Pack()
assert.True(t, err == nil, "%s", err)
requestURL := "http://127.0.0.1:3000/dns-query?dns=" + base64.RawURLEncoding.EncodeToString(buf)
resp, err = http.DefaultClient.Get(requestURL)
assert.True(t, err == nil, "%s", err)
body, err := ioutil.ReadAll(resp.Body)
assert.True(t, err == nil, "%s", err)
assert.True(t, resp.StatusCode == http.StatusOK)
response := dns.Msg{}
err = response.Unpack(body)
assert.True(t, err == nil, "%s", err)
addrs = nil
proxyutil.AppendIPAddrs(&addrs, response.Answer)
haveIP = len(addrs) != 0
assert.True(t, haveIP)
for i := 1; ; i++ {
st, err := os.Stat(filepath.Join(dir, "data", "filters", "1.txt"))
if err == nil && st.Size() != 0 {
break
}
if i == 5 {
assert.True(t, false)
break
}
time.Sleep(1 * time.Second)
}
cleanup()
cleanupAlways()
}

94
internal/home/i18n.go Normal file
View File

@@ -0,0 +1,94 @@
package home
import (
"fmt"
"io/ioutil"
"net/http"
"strings"
"github.com/AdguardTeam/golibs/log"
)
// --------------------
// internationalization
// --------------------
var allowedLanguages = map[string]bool{
"be": true,
"bg": true,
"cs": true,
"da": true,
"de": true,
"en": true,
"es": true,
"fa": true,
"fr": true,
"hr": true,
"hu": true,
"id": true,
"it": true,
"ja": true,
"ko": true,
"nl": true,
"no": true,
"pl": true,
"pt-br": true,
"pt-pt": true,
"ro": true,
"ru": true,
"si-lk": true,
"sk": true,
"sl": true,
"sr-cs": true,
"sv": true,
"th": true,
"tr": true,
"vi": true,
"zh-cn": true,
"zh-hk": true,
"zh-tw": true,
}
func isLanguageAllowed(language string) bool {
l := strings.ToLower(language)
return allowedLanguages[l]
}
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
log.Printf("config.Language is %s", config.Language)
_, err := fmt.Fprintf(w, "%s\n", config.Language)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
}
func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errorText := fmt.Sprintf("failed to read request body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
language := strings.TrimSpace(string(body))
if language == "" {
errorText := fmt.Sprintf("empty language specified")
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
if !isLanguageAllowed(language) {
errorText := fmt.Sprintf("unknown language specified: %s", language)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
config.Language = language
onConfigModified()
returnOK(w)
}

44
internal/home/memory.go Normal file
View File

@@ -0,0 +1,44 @@
package home
import (
"os"
"runtime/debug"
"time"
"github.com/AdguardTeam/golibs/log"
)
// memoryUsage implements a couple of not really beautiful hacks which purpose is to
// make OS reclaim the memory freed by AdGuard Home as soon as possible.
// See this for the details on the performance hits & gains:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/2044#issuecomment-687042211
func memoryUsage(args options) {
if args.disableMemoryOptimization {
log.Info("Memory optimization is disabled")
return
}
// Makes Go allocate heap at a slower pace
// By default we keep it at 50%
debug.SetGCPercent(50)
// madvdontneed: setting madvdontneed=1 will use MADV_DONTNEED
// instead of MADV_FREE on Linux when returning memory to the
// kernel. This is less efficient, but causes RSS numbers to drop
// more quickly.
_ = os.Setenv("GODEBUG", "madvdontneed=1")
// periodically call "debug.FreeOSMemory" so
// that the OS could reclaim the free memory
go func() {
ticker := time.NewTicker(5 * time.Minute)
for {
select {
case t := <-ticker.C:
t.Second()
log.Debug("Free OS memory")
debug.FreeOSMemory()
}
}
}()
}

View File

@@ -0,0 +1,113 @@
package home
import (
"fmt"
"net"
"net/http"
uuid "github.com/satori/go.uuid"
"howett.net/plist"
)
type DNSSettings struct {
DNSProtocol string
ServerURL string `plist:",omitempty"`
ServerName string `plist:",omitempty"`
}
type PayloadContent struct {
Name string
PayloadDescription string
PayloadDisplayName string
PayloadIdentifier string
PayloadType string
PayloadUUID string
PayloadVersion int
DNSSettings DNSSettings
}
type MobileConfig struct {
PayloadContent []PayloadContent
PayloadDescription string
PayloadDisplayName string
PayloadIdentifier string
PayloadRemovalDisallowed bool
PayloadType string
PayloadUUID string
PayloadVersion int
}
func genUUIDv4() string {
return uuid.NewV4().String()
}
const (
dnsProtoHTTPS = "HTTPS"
dnsProtoTLS = "TLS"
)
func getMobileConfig(d DNSSettings) ([]byte, error) {
var name string
switch d.DNSProtocol {
case dnsProtoHTTPS:
name = fmt.Sprintf("%s DoH", d.ServerName)
case dnsProtoTLS:
name = fmt.Sprintf("%s DoT", d.ServerName)
default:
return nil, fmt.Errorf("bad dns protocol %q", d.DNSProtocol)
}
data := MobileConfig{
PayloadContent: []PayloadContent{{
Name: name,
PayloadDescription: "Configures device to use AdGuard Home",
PayloadDisplayName: name,
PayloadIdentifier: fmt.Sprintf("com.apple.dnsSettings.managed.%s", genUUIDv4()),
PayloadType: "com.apple.dnsSettings.managed",
PayloadUUID: genUUIDv4(),
PayloadVersion: 1,
DNSSettings: d,
}},
PayloadDescription: "Adds AdGuard Home to Big Sur and iOS 14 or newer systems",
PayloadDisplayName: name,
PayloadIdentifier: genUUIDv4(),
PayloadRemovalDisallowed: false,
PayloadType: "Configuration",
PayloadUUID: genUUIDv4(),
PayloadVersion: 1,
}
return plist.MarshalIndent(data, plist.XMLFormat, "\t")
}
func handleMobileConfig(w http.ResponseWriter, d DNSSettings) {
mobileconfig, err := getMobileConfig(d)
if err != nil {
httpError(w, http.StatusInternalServerError, "plist.MarshalIndent: %s", err)
}
w.Header().Set("Content-Type", "application/xml")
_, _ = w.Write(mobileconfig)
}
func handleMobileConfigDoh(w http.ResponseWriter, r *http.Request) {
handleMobileConfig(w, DNSSettings{
DNSProtocol: dnsProtoHTTPS,
ServerURL: fmt.Sprintf("https://%s/dns-query", r.Host),
})
}
func handleMobileConfigDot(w http.ResponseWriter, r *http.Request) {
var err error
var host string
host, _, err = net.SplitHostPort(r.Host)
if err != nil {
httpError(w, http.StatusBadRequest, "getting host: %s", err)
}
handleMobileConfig(w, DNSSettings{
DNSProtocol: dnsProtoTLS,
ServerName: host,
})
}

View File

@@ -0,0 +1,33 @@
package home
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"howett.net/plist"
)
func TestHandleMobileConfigDot(t *testing.T) {
var err error
var r *http.Request
r, err = http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
assert.Nil(t, err)
w := httptest.NewRecorder()
handleMobileConfigDot(w, r)
assert.Equal(t, http.StatusOK, w.Code)
var mc MobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) {
assert.Equal(t, "example.com DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.com DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.com", mc.PayloadContent[0].DNSSettings.ServerName)
}
}

325
internal/home/options.go Normal file
View File

@@ -0,0 +1,325 @@
package home
import (
"fmt"
"os"
"strconv"
)
// options passed from command-line arguments
type options struct {
verbose bool // is verbose logging enabled
configFilename string // path to the config file
workDir string // path to the working directory where we will store the filters data and the querylog
bindHost string // host address to bind HTTP server on
bindPort int // port to serve HTTP pages on
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
pidFile string // File name to save PID to
checkConfig bool // Check configuration and exit
disableUpdate bool // If set, don't check for updates
// service control action (see service.ControlAction array + "status" command)
serviceControlAction string
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
// disableMemoryOptimization - disables memory optimization hacks
// see memoryUsage() function for the details
disableMemoryOptimization bool
glinetMode bool // Activate GL-Inet compatibility mode
}
// functions used for their side-effects
type effect func() error
type arg struct {
description string // a short, English description of the argument
longName string // the name of the argument used after '--'
shortName string // the name of the argument used after '-'
// only one of updateWithValue, updateNoValue, and effect should be present
updateWithValue func(o options, v string) (options, error) // the mutator for arguments with parameters
updateNoValue func(o options) (options, error) // the mutator for arguments without parameters
effect func(o options, exec string) (f effect, err error) // the side-effect closure generator
serialize func(o options) []string // the re-serialization function back to arguments (return nil for omit)
}
// {type}SliceOrNil functions check their parameter of type {type}
// against its zero value and return nil if the parameter value is
// zero otherwise they return a string slice of the parameter
func stringSliceOrNil(s string) []string {
if s == "" {
return nil
}
return []string{s}
}
func intSliceOrNil(i int) []string {
if i == 0 {
return nil
}
return []string{strconv.Itoa(i)}
}
func boolSliceOrNil(b bool) []string {
if b {
return []string{}
}
return nil
}
var args []arg
var configArg = arg{
"Path to the config file",
"config", "c",
func(o options, v string) (options, error) { o.configFilename = v; return o, nil },
nil,
nil,
func(o options) []string { return stringSliceOrNil(o.configFilename) },
}
var workDirArg = arg{
"Path to the working directory",
"work-dir", "w",
func(o options, v string) (options, error) { o.workDir = v; return o, nil }, nil, nil,
func(o options) []string { return stringSliceOrNil(o.workDir) },
}
var hostArg = arg{
"Host address to bind HTTP server on",
"host", "h",
func(o options, v string) (options, error) { o.bindHost = v; return o, nil }, nil, nil,
func(o options) []string { return stringSliceOrNil(o.bindHost) },
}
var portArg = arg{
"Port to serve HTTP pages on",
"port", "p",
func(o options, v string) (options, error) {
var err error
var p int
minPort, maxPort := 0, 1<<16-1
if p, err = strconv.Atoi(v); err != nil {
err = fmt.Errorf("port '%s' is not a number", v)
} else if p < minPort || p > maxPort {
err = fmt.Errorf("port %d not in range %d - %d", p, minPort, maxPort)
} else {
o.bindPort = p
}
return o, err
}, nil, nil,
func(o options) []string { return intSliceOrNil(o.bindPort) },
}
var serviceArg = arg{
"Service control action: status, install, uninstall, start, stop, restart, reload (configuration)",
"service", "s",
func(o options, v string) (options, error) {
o.serviceControlAction = v
return o, nil
}, nil, nil,
func(o options) []string { return stringSliceOrNil(o.serviceControlAction) },
}
var logfileArg = arg{
"Path to log file. If empty: write to stdout; if 'syslog': write to system log",
"logfile", "l",
func(o options, v string) (options, error) { o.logFile = v; return o, nil }, nil, nil,
func(o options) []string { return stringSliceOrNil(o.logFile) },
}
var pidfileArg = arg{
"Path to a file where PID is stored",
"pidfile", "",
func(o options, v string) (options, error) { o.pidFile = v; return o, nil }, nil, nil,
func(o options) []string { return stringSliceOrNil(o.pidFile) },
}
var checkConfigArg = arg{
"Check configuration and exit",
"check-config", "",
nil, func(o options) (options, error) { o.checkConfig = true; return o, nil }, nil,
func(o options) []string { return boolSliceOrNil(o.checkConfig) },
}
var noCheckUpdateArg = arg{
"Don't check for updates",
"no-check-update", "",
nil, func(o options) (options, error) { o.disableUpdate = true; return o, nil }, nil,
func(o options) []string { return boolSliceOrNil(o.disableUpdate) },
}
var disableMemoryOptimizationArg = arg{
"Disable memory optimization",
"no-mem-optimization", "",
nil, func(o options) (options, error) { o.disableMemoryOptimization = true; return o, nil }, nil,
func(o options) []string { return boolSliceOrNil(o.disableMemoryOptimization) },
}
var verboseArg = arg{
"Enable verbose output",
"verbose", "v",
nil, func(o options) (options, error) { o.verbose = true; return o, nil }, nil,
func(o options) []string { return boolSliceOrNil(o.verbose) },
}
var glinetArg = arg{
"Run in GL-Inet compatibility mode",
"glinet", "",
nil, func(o options) (options, error) { o.glinetMode = true; return o, nil }, nil,
func(o options) []string { return boolSliceOrNil(o.glinetMode) },
}
var versionArg = arg{
"Show the version and exit",
"version", "",
nil, nil, func(o options, exec string) (effect, error) {
return func() error { fmt.Println(version()); os.Exit(0); return nil }, nil
},
func(o options) []string { return nil },
}
var helpArg = arg{
"Print this help",
"help", "",
nil, nil, func(o options, exec string) (effect, error) {
return func() error { _ = printHelp(exec); os.Exit(64); return nil }, nil
},
func(o options) []string { return nil },
}
func init() {
args = []arg{
configArg,
workDirArg,
hostArg,
portArg,
serviceArg,
logfileArg,
pidfileArg,
checkConfigArg,
noCheckUpdateArg,
disableMemoryOptimizationArg,
verboseArg,
glinetArg,
versionArg,
helpArg,
}
}
func getUsageLines(exec string, args []arg) []string {
usage := []string{
"Usage:",
"",
fmt.Sprintf("%s [options]", exec),
"",
"Options:",
}
for _, arg := range args {
val := ""
if arg.updateWithValue != nil {
val = " VALUE"
}
if arg.shortName != "" {
usage = append(usage, fmt.Sprintf(" -%s, %-30s %s",
arg.shortName,
"--"+arg.longName+val,
arg.description))
} else {
usage = append(usage, fmt.Sprintf(" %-34s %s",
"--"+arg.longName+val,
arg.description))
}
}
return usage
}
func printHelp(exec string) error {
for _, line := range getUsageLines(exec, args) {
_, err := fmt.Println(line)
if err != nil {
return err
}
}
return nil
}
func argMatches(a arg, v string) bool {
return v == "--"+a.longName || (a.shortName != "" && v == "-"+a.shortName)
}
func parse(exec string, ss []string) (o options, f effect, err error) {
for i := 0; i < len(ss); i++ {
v := ss[i]
knownParam := false
for _, arg := range args {
if argMatches(arg, v) {
if arg.updateWithValue != nil {
if i+1 >= len(ss) {
return o, f, fmt.Errorf("got %s without argument", v)
}
i++
o, err = arg.updateWithValue(o, ss[i])
if err != nil {
return
}
} else if arg.updateNoValue != nil {
o, err = arg.updateNoValue(o)
if err != nil {
return
}
} else if arg.effect != nil {
var eff effect
eff, err = arg.effect(o, exec)
if err != nil {
return
}
if eff != nil {
prevf := f
f = func() error {
var err error
if prevf != nil {
err = prevf()
}
if err == nil {
err = eff()
}
return err
}
}
}
knownParam = true
break
}
}
if !knownParam {
return o, f, fmt.Errorf("unknown option %v", v)
}
}
return
}
func shortestFlag(a arg) string {
if a.shortName != "" {
return "-" + a.shortName
}
return "--" + a.longName
}
func serialize(o options) []string {
ss := []string{}
for _, arg := range args {
s := arg.serialize(o)
if s != nil {
ss = append(ss, append([]string{shortestFlag(arg)}, s...)...)
}
}
return ss
}

View File

@@ -0,0 +1,251 @@
package home
import (
"fmt"
"testing"
)
func testParseOk(t *testing.T, ss ...string) options {
o, _, err := parse("", ss)
if err != nil {
t.Fatal(err.Error())
}
return o
}
func testParseErr(t *testing.T, descr string, ss ...string) {
_, _, err := parse("", ss)
if err == nil {
t.Fatalf("expected an error because %s but no error returned", descr)
}
}
func testParseParamMissing(t *testing.T, param string) {
testParseErr(t, fmt.Sprintf("%s parameter missing", param), param)
}
func TestParseVerbose(t *testing.T) {
if testParseOk(t).verbose {
t.Fatal("empty is not verbose")
}
if !testParseOk(t, "-v").verbose {
t.Fatal("-v is verbose")
}
if !testParseOk(t, "--verbose").verbose {
t.Fatal("--verbose is verbose")
}
}
func TestParseConfigFilename(t *testing.T) {
if testParseOk(t).configFilename != "" {
t.Fatal("empty is no config filename")
}
if testParseOk(t, "-c", "path").configFilename != "path" {
t.Fatal("-c is config filename")
}
testParseParamMissing(t, "-c")
if testParseOk(t, "--config", "path").configFilename != "path" {
t.Fatal("--configFilename is config filename")
}
testParseParamMissing(t, "--config")
}
func TestParseWorkDir(t *testing.T) {
if testParseOk(t).workDir != "" {
t.Fatal("empty is no work dir")
}
if testParseOk(t, "-w", "path").workDir != "path" {
t.Fatal("-w is work dir")
}
testParseParamMissing(t, "-w")
if testParseOk(t, "--work-dir", "path").workDir != "path" {
t.Fatal("--work-dir is work dir")
}
testParseParamMissing(t, "--work-dir")
}
func TestParseBindHost(t *testing.T) {
if testParseOk(t).bindHost != "" {
t.Fatal("empty is no host")
}
if testParseOk(t, "-h", "addr").bindHost != "addr" {
t.Fatal("-h is host")
}
testParseParamMissing(t, "-h")
if testParseOk(t, "--host", "addr").bindHost != "addr" {
t.Fatal("--host is host")
}
testParseParamMissing(t, "--host")
}
func TestParseBindPort(t *testing.T) {
if testParseOk(t).bindPort != 0 {
t.Fatal("empty is port 0")
}
if testParseOk(t, "-p", "65535").bindPort != 65535 {
t.Fatal("-p is port")
}
testParseParamMissing(t, "-p")
if testParseOk(t, "--port", "65535").bindPort != 65535 {
t.Fatal("--port is port")
}
testParseParamMissing(t, "--port")
}
func TestParseBindPortBad(t *testing.T) {
testParseErr(t, "not an int", "-p", "x")
testParseErr(t, "hex not supported", "-p", "0x100")
testParseErr(t, "port negative", "-p", "-1")
testParseErr(t, "port too high", "-p", "65536")
testParseErr(t, "port too high", "-p", "4294967297") // 2^32 + 1
testParseErr(t, "port too high", "-p", "18446744073709551617") // 2^64 + 1
}
func TestParseLogfile(t *testing.T) {
if testParseOk(t).logFile != "" {
t.Fatal("empty is no log file")
}
if testParseOk(t, "-l", "path").logFile != "path" {
t.Fatal("-l is log file")
}
if testParseOk(t, "--logfile", "path").logFile != "path" {
t.Fatal("--logfile is log file")
}
}
func TestParsePidfile(t *testing.T) {
if testParseOk(t).pidFile != "" {
t.Fatal("empty is no pid file")
}
if testParseOk(t, "--pidfile", "path").pidFile != "path" {
t.Fatal("--pidfile is pid file")
}
}
func TestParseCheckConfig(t *testing.T) {
if testParseOk(t).checkConfig {
t.Fatal("empty is not check config")
}
if !testParseOk(t, "--check-config").checkConfig {
t.Fatal("--check-config is check config")
}
}
func TestParseDisableUpdate(t *testing.T) {
if testParseOk(t).disableUpdate {
t.Fatal("empty is not disable update")
}
if !testParseOk(t, "--no-check-update").disableUpdate {
t.Fatal("--no-check-update is disable update")
}
}
func TestParseDisableMemoryOptimization(t *testing.T) {
if testParseOk(t).disableMemoryOptimization {
t.Fatal("empty is not disable update")
}
if !testParseOk(t, "--no-mem-optimization").disableMemoryOptimization {
t.Fatal("--no-mem-optimization is disable update")
}
}
func TestParseService(t *testing.T) {
if testParseOk(t).serviceControlAction != "" {
t.Fatal("empty is no service command")
}
if testParseOk(t, "-s", "command").serviceControlAction != "command" {
t.Fatal("-s is service command")
}
if testParseOk(t, "--service", "command").serviceControlAction != "command" {
t.Fatal("--service is service command")
}
}
func TestParseGLInet(t *testing.T) {
if testParseOk(t).glinetMode {
t.Fatal("empty is not GL-Inet mode")
}
if !testParseOk(t, "--glinet").glinetMode {
t.Fatal("--glinet is GL-Inet mode")
}
}
func TestParseUnknown(t *testing.T) {
testParseErr(t, "unknown word", "x")
testParseErr(t, "unknown short", "-x")
testParseErr(t, "unknown long", "--x")
testParseErr(t, "unknown triple", "---x")
testParseErr(t, "unknown plus", "+x")
testParseErr(t, "unknown dash", "-")
}
func testSerialize(t *testing.T, o options, ss ...string) {
result := serialize(o)
if len(result) != len(ss) {
t.Fatalf("expected %s but got %s", ss, result)
}
for i, r := range result {
if r != ss[i] {
t.Fatalf("expected %s but got %s", ss, result)
}
}
}
func TestSerializeEmpty(t *testing.T) {
testSerialize(t, options{})
}
func TestSerializeConfigFilename(t *testing.T) {
testSerialize(t, options{configFilename: "path"}, "-c", "path")
}
func TestSerializeWorkDir(t *testing.T) {
testSerialize(t, options{workDir: "path"}, "-w", "path")
}
func TestSerializeBindHost(t *testing.T) {
testSerialize(t, options{bindHost: "addr"}, "-h", "addr")
}
func TestSerializeBindPort(t *testing.T) {
testSerialize(t, options{bindPort: 666}, "-p", "666")
}
func TestSerializeLogfile(t *testing.T) {
testSerialize(t, options{logFile: "path"}, "-l", "path")
}
func TestSerializePidfile(t *testing.T) {
testSerialize(t, options{pidFile: "path"}, "--pidfile", "path")
}
func TestSerializeCheckConfig(t *testing.T) {
testSerialize(t, options{checkConfig: true}, "--check-config")
}
func TestSerializeDisableUpdate(t *testing.T) {
testSerialize(t, options{disableUpdate: true}, "--no-check-update")
}
func TestSerializeService(t *testing.T) {
testSerialize(t, options{serviceControlAction: "run"}, "-s", "run")
}
func TestSerializeGLInet(t *testing.T) {
testSerialize(t, options{glinetMode: true}, "--glinet")
}
func TestSerializeDisableMemoryOptimization(t *testing.T) {
testSerialize(t, options{disableMemoryOptimization: true}, "--no-mem-optimization")
}
func TestSerializeMultiple(t *testing.T) {
testSerialize(t, options{
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
disableMemoryOptimization: true,
}, "-c", "config", "-w", "work", "-s", "run", "--pidfile", "pid", "--no-check-update", "--no-mem-optimization")
}

129
internal/home/rdns.go Normal file
View File

@@ -0,0 +1,129 @@
package home
import (
"encoding/binary"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// RDNS - module context
type RDNS struct {
dnsServer *dnsforward.Server
clients *clientsContainer
ipChannel chan string // pass data from DNS request handling thread to rDNS thread
// Contains IP addresses of clients to be resolved by rDNS
// If IP address is resolved, it stays here while it's inside Clients.
// If it's removed from Clients, this IP address will be resolved once again.
// If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP.
ipAddrs cache.Cache
}
// InitRDNS - create module context
func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS {
r := RDNS{}
r.dnsServer = dnsServer
r.clients = clients
cconf := cache.Config{}
cconf.EnableLRU = true
cconf.MaxCount = 10000
r.ipAddrs = cache.New(cconf)
r.ipChannel = make(chan string, 256)
go r.workerLoop()
return &r
}
// Begin - add IP address to rDNS queue
func (r *RDNS) Begin(ip string) {
now := uint64(time.Now().Unix())
expire := r.ipAddrs.Get([]byte(ip))
if len(expire) != 0 {
exp := binary.BigEndian.Uint64(expire)
if exp > now {
return
}
// TTL expired
}
expire = make([]byte, 8)
const ttl = 1 * 60 * 60
binary.BigEndian.PutUint64(expire, now+ttl)
_ = r.ipAddrs.Set([]byte(ip), expire)
if r.clients.Exists(ip, ClientSourceRDNS) {
return
}
log.Tracef("rDNS: adding %s", ip)
select {
case r.ipChannel <- ip:
//
default:
log.Tracef("rDNS: queue is full")
}
}
// Use rDNS to get hostname by IP address
func (r *RDNS) resolve(ip string) string {
log.Tracef("Resolving host for %s", ip)
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
}
var err error
req.Question[0].Name, err = dns.ReverseAddr(ip)
if err != nil {
log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err)
return ""
}
resp, err := r.dnsServer.Exchange(&req)
if err != nil {
log.Debug("Error while making an rDNS lookup for %s: %s", ip, err)
return ""
}
if len(resp.Answer) == 0 {
log.Debug("No answer for rDNS lookup of %s", ip)
return ""
}
ptr, ok := resp.Answer[0].(*dns.PTR)
if !ok {
log.Debug("not a PTR response for %s", ip)
return ""
}
log.Tracef("PTR response for %s: %s", ip, ptr.String())
if strings.HasSuffix(ptr.Ptr, ".") {
ptr.Ptr = ptr.Ptr[:len(ptr.Ptr)-1]
}
return ptr.Ptr
}
// Wait for a signal and then synchronously resolve hostname by IP address
// Add the hostname:IP pair to "Clients" array
func (r *RDNS) workerLoop() {
for {
var ip string
ip = <-r.ipChannel
host := r.resolve(ip)
if len(host) == 0 {
continue
}
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
}
}

View File

@@ -0,0 +1,21 @@
package home
import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/stretchr/testify/assert"
)
func TestResolveRDNS(t *testing.T) {
dns := &dnsforward.Server{}
conf := &dnsforward.ServerConfig{}
conf.UpstreamDNS = []string{"8.8.8.8"}
err := dns.Prepare(conf)
assert.True(t, err == nil, "%s", err)
clients := &clientsContainer{}
rdns := InitRDNS(dns, clients)
r := rdns.resolve("1.1.1.1")
assert.True(t, r == "one.one.one.one", "%s", r)
}

517
internal/home/service.go Normal file
View File

@@ -0,0 +1,517 @@
package home
import (
"fmt"
"io/ioutil"
"os"
"runtime"
"strconv"
"strings"
"syscall"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
"github.com/kardianos/service"
)
const (
launchdStdoutPath = "/var/log/AdGuardHome.stdout.log"
launchdStderrPath = "/var/log/AdGuardHome.stderr.log"
serviceName = "AdGuardHome"
serviceDisplayName = "AdGuard Home service"
serviceDescription = "AdGuard Home: Network-level blocker"
)
// Represents the program that will be launched by a service or daemon
type program struct {
opts options
}
// Start should quickly start the program
func (p *program) Start(s service.Service) error {
// Start should not block. Do the actual work async.
args := p.opts
args.runningAsService = true
go run(args)
return nil
}
// Stop stops the program
func (p *program) Stop(s service.Service) error {
// Stop should not block. Return with a few seconds.
if Context.appSignalChannel == nil {
os.Exit(0)
}
Context.appSignalChannel <- syscall.SIGINT
return nil
}
// Check the service's status
// Note: on OpenWrt 'service' utility may not exist - we use our service script directly in this case.
func svcStatus(s service.Service) (service.Status, error) {
status, err := s.Status()
if err != nil && service.Platform() == "unix-systemv" {
code, err := runInitdCommand("status")
if err != nil {
return service.StatusStopped, nil
}
if code != 0 {
return service.StatusStopped, nil
}
return service.StatusRunning, nil
}
return status, err
}
// Perform an action on the service
// Note: on OpenWrt 'service' utility may not exist - we use our service script directly in this case.
func svcAction(s service.Service, action string) error {
err := service.Control(s, action)
if err != nil && service.Platform() == "unix-systemv" &&
(action == "start" || action == "stop" || action == "restart") {
_, err := runInitdCommand(action)
return err
}
return err
}
// Send SIGHUP to a process with ID taken from our pid-file
// If pid-file doesn't exist, find our PID using 'ps' command
func sendSigReload() {
if runtime.GOOS == "windows" {
log.Error("Not implemented on Windows")
return
}
pidfile := fmt.Sprintf("/var/run/%s.pid", serviceName)
data, err := ioutil.ReadFile(pidfile)
if os.IsNotExist(err) {
code, psdata, err := util.RunCommand("ps", "-C", serviceName, "-o", "pid=")
if err != nil || code != 0 {
log.Error("Can't find AdGuardHome process: %s code:%d", err, code)
return
}
data = []byte(psdata)
} else if err != nil {
log.Error("Can't read PID file %s: %s", pidfile, err)
return
}
parts := strings.SplitN(string(data), "\n", 2)
if len(parts) == 0 {
log.Error("Can't read PID file %s: bad value", pidfile)
return
}
pid, err := strconv.Atoi(parts[0])
if err != nil {
log.Error("Can't read PID file %s: %s", pidfile, err)
return
}
err = util.SendProcessSignal(pid, syscall.SIGHUP)
if err != nil {
log.Error("Can't send signal to PID %d: %s", pid, err)
return
}
log.Debug("Sent signal to PID %d", pid)
}
// handleServiceControlAction one of the possible control actions:
// install -- installs a service/daemon
// uninstall -- uninstalls it
// status -- prints the service status
// start -- starts the previously installed service
// stop -- stops the previously installed service
// restart - restarts the previously installed service
// run - this is a special command that is not supposed to be used directly
// it is specified when we register a service, and it indicates to the app
// that it is being run as a service/daemon.
func handleServiceControlAction(opts options) {
action := opts.serviceControlAction
log.Printf("Service control action: %s", action)
if action == "reload" {
sendSigReload()
return
}
pwd, err := os.Getwd()
if err != nil {
log.Fatal("Unable to find the path to the current directory")
}
runOpts := opts
runOpts.serviceControlAction = "run"
svcConfig := &service.Config{
Name: serviceName,
DisplayName: serviceDisplayName,
Description: serviceDescription,
WorkingDirectory: pwd,
Arguments: serialize(runOpts),
}
configureService(svcConfig)
prg := &program{runOpts}
s, err := service.New(prg, svcConfig)
if err != nil {
log.Fatal(err)
}
if action == "status" {
handleServiceStatusCommand(s)
} else if action == "run" {
err = s.Run()
if err != nil {
log.Fatalf("Failed to run service: %s", err)
}
} else if action == "install" {
initConfigFilename(opts)
initWorkingDir(opts)
handleServiceInstallCommand(s)
} else if action == "uninstall" {
handleServiceUninstallCommand(s)
} else {
err = svcAction(s, action)
if err != nil {
log.Fatal(err)
}
}
log.Printf("Action %s has been done successfully on %s", action, service.ChosenSystem().String())
}
// handleServiceStatusCommand handles service "status" command
func handleServiceStatusCommand(s service.Service) {
status, errSt := svcStatus(s)
if errSt != nil {
log.Fatalf("failed to get service status: %s", errSt)
}
switch status {
case service.StatusUnknown:
log.Printf("Service status is unknown")
case service.StatusStopped:
log.Printf("Service is stopped")
case service.StatusRunning:
log.Printf("Service is running")
}
}
// handleServiceStatusCommand handles service "install" command
func handleServiceInstallCommand(s service.Service) {
err := svcAction(s, "install")
if err != nil {
log.Fatal(err)
}
if util.IsOpenWrt() {
// On OpenWrt it is important to run enable after the service installation
// Otherwise, the service won't start on the system startup
_, err := runInitdCommand("enable")
if err != nil {
log.Fatal(err)
}
}
// Start automatically after install
err = svcAction(s, "start")
if err != nil {
log.Fatalf("Failed to start the service: %s", err)
}
log.Printf("Service has been started")
if detectFirstRun() {
log.Printf(`Almost ready!
AdGuard Home is successfully installed and will automatically start on boot.
There are a few more things that must be configured before you can use it.
Click on the link below and follow the Installation Wizard steps to finish setup.`)
printHTTPAddresses("http")
}
}
// handleServiceStatusCommand handles service "uninstall" command
func handleServiceUninstallCommand(s service.Service) {
if util.IsOpenWrt() {
// On OpenWrt it is important to run disable command first
// as it will remove the symlink
_, err := runInitdCommand("disable")
if err != nil {
log.Fatal(err)
}
}
err := svcAction(s, "uninstall")
if err != nil {
log.Fatal(err)
}
if runtime.GOOS == "darwin" {
// Removing log files on cleanup and ignore errors
err := os.Remove(launchdStdoutPath)
if err != nil && !os.IsNotExist(err) {
log.Printf("cannot remove %s", launchdStdoutPath)
}
err = os.Remove(launchdStderrPath)
if err != nil && !os.IsNotExist(err) {
log.Printf("cannot remove %s", launchdStderrPath)
}
}
}
// configureService defines additional settings of the service
func configureService(c *service.Config) {
c.Option = service.KeyValue{}
// OS X
// Redefines the launchd config file template
// The purpose is to enable stdout/stderr redirect by default
c.Option["LaunchdConfig"] = launchdConfig
// This key is used to start the job as soon as it has been loaded. For daemons this means execution at boot time, for agents execution at login.
c.Option["RunAtLoad"] = true
// POSIX
// Redirect StdErr & StdOut to files.
c.Option["LogOutput"] = true
// Use modified service file templates
c.Option["SystemdScript"] = systemdScript
c.Option["SysvScript"] = sysvScript
// On OpenWrt we're using a different type of sysvScript
if util.IsOpenWrt() {
c.Option["SysvScript"] = openWrtScript
} else if util.IsFreeBSD() {
c.Option["SysvScript"] = freeBSDScript
}
}
// runInitdCommand runs init.d service command
// returns command code or error if any
func runInitdCommand(action string) (int, error) {
confPath := "/etc/init.d/" + serviceName
code, _, err := util.RunCommand("sh", "-c", confPath+" "+action)
return code, err
}
// Basically the same template as the one defined in github.com/kardianos/service
// but with two additional keys - StandardOutPath and StandardErrorPath
var launchdConfig = `<?xml version='1.0' encoding='UTF-8'?>
<!DOCTYPE plist PUBLIC "-//Apple Computer//DTD PLIST 1.0//EN"
"http://www.apple.com/DTDs/PropertyList-1.0.dtd" >
<plist version='1.0'>
<dict>
<key>Label</key><string>{{html .Name}}</string>
<key>ProgramArguments</key>
<array>
<string>{{html .Path}}</string>
{{range .Config.Arguments}}
<string>{{html .}}</string>
{{end}}
</array>
{{if .UserName}}<key>UserName</key><string>{{html .UserName}}</string>{{end}}
{{if .ChRoot}}<key>RootDirectory</key><string>{{html .ChRoot}}</string>{{end}}
{{if .WorkingDirectory}}<key>WorkingDirectory</key><string>{{html .WorkingDirectory}}</string>{{end}}
<key>SessionCreate</key><{{bool .SessionCreate}}/>
<key>KeepAlive</key><{{bool .KeepAlive}}/>
<key>RunAtLoad</key><{{bool .RunAtLoad}}/>
<key>Disabled</key><false/>
<key>StandardOutPath</key>
<string>` + launchdStdoutPath + `</string>
<key>StandardErrorPath</key>
<string>` + launchdStderrPath + `</string>
</dict>
</plist>
`
// Note: we should keep it in sync with the template from service_systemd_linux.go file
// Add "After=" setting for systemd service file, because we must be started only after network is online
// Set "RestartSec" to 10
const systemdScript = `[Unit]
Description={{.Description}}
ConditionFileIsExecutable={{.Path|cmdEscape}}
After=syslog.target network-online.target
[Service]
StartLimitInterval=5
StartLimitBurst=10
ExecStart={{.Path|cmdEscape}}{{range .Arguments}} {{.|cmd}}{{end}}
{{if .ChRoot}}RootDirectory={{.ChRoot|cmd}}{{end}}
{{if .WorkingDirectory}}WorkingDirectory={{.WorkingDirectory|cmdEscape}}{{end}}
{{if .UserName}}User={{.UserName}}{{end}}
{{if .ReloadSignal}}ExecReload=/bin/kill -{{.ReloadSignal}} "$MAINPID"{{end}}
{{if .PIDFile}}PIDFile={{.PIDFile|cmd}}{{end}}
{{if and .LogOutput .HasOutputFileSupport -}}
StandardOutput=file:/var/log/{{.Name}}.out
StandardError=file:/var/log/{{.Name}}.err
{{- end}}
Restart=always
RestartSec=10
EnvironmentFile=-/etc/sysconfig/{{.Name}}
[Install]
WantedBy=multi-user.target
`
// Note: we should keep it in sync with the template from service_sysv_linux.go file
// Use "ps | grep -v grep | grep $(get_pid)" because "ps PID" may not work on OpenWrt
const sysvScript = `#!/bin/sh
# For RedHat and cousins:
# chkconfig: - 99 01
# description: {{.Description}}
# processname: {{.Path}}
### BEGIN INIT INFO
# Provides: {{.Path}}
# Required-Start:
# Required-Stop:
# Default-Start: 2 3 4 5
# Default-Stop: 0 1 6
# Short-Description: {{.DisplayName}}
# Description: {{.Description}}
### END INIT INFO
cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}"
name=$(basename $(readlink -f $0))
pid_file="/var/run/$name.pid"
stdout_log="/var/log/$name.log"
stderr_log="/var/log/$name.err"
[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
get_pid() {
cat "$pid_file"
}
is_running() {
[ -f "$pid_file" ] && ps | grep -v grep | grep $(get_pid) > /dev/null 2>&1
}
case "$1" in
start)
if is_running; then
echo "Already started"
else
echo "Starting $name"
{{if .WorkingDirectory}}cd '{{.WorkingDirectory}}'{{end}}
$cmd >> "$stdout_log" 2>> "$stderr_log" &
echo $! > "$pid_file"
if ! is_running; then
echo "Unable to start, see $stdout_log and $stderr_log"
exit 1
fi
fi
;;
stop)
if is_running; then
echo -n "Stopping $name.."
kill $(get_pid)
for i in $(seq 1 10)
do
if ! is_running; then
break
fi
echo -n "."
sleep 1
done
echo
if is_running; then
echo "Not stopped; may still be shutting down or shutdown may have failed"
exit 1
else
echo "Stopped"
if [ -f "$pid_file" ]; then
rm "$pid_file"
fi
fi
else
echo "Not running"
fi
;;
restart)
$0 stop
if is_running; then
echo "Unable to stop, will not attempt to start"
exit 1
fi
$0 start
;;
status)
if is_running; then
echo "Running"
else
echo "Stopped"
exit 1
fi
;;
*)
echo "Usage: $0 {start|stop|restart|status}"
exit 1
;;
esac
exit 0
`
// OpenWrt procd init script
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1386
const openWrtScript = `#!/bin/sh /etc/rc.common
USE_PROCD=1
START=95
STOP=01
cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}"
name="{{.Name}}"
pid_file="/var/run/${name}.pid"
start_service() {
echo "Starting ${name}"
procd_open_instance
procd_set_param command ${cmd}
procd_set_param respawn # respawn automatically if something died
procd_set_param stdout 1 # forward stdout of the command to logd
procd_set_param stderr 1 # same for stderr
procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop
procd_close_instance
echo "${name} has been started"
}
stop_service() {
echo "Stopping ${name}"
}
EXTRA_COMMANDS="status"
EXTRA_HELP=" status Print the service status"
get_pid() {
cat "${pid_file}"
}
is_running() {
[ -f "${pid_file}" ] && ps | grep -v grep | grep $(get_pid) >/dev/null 2>&1
}
status() {
if is_running; then
echo "Running"
else
echo "Stopped"
exit 1
fi
}
`
const freeBSDScript = `#!/bin/sh
# PROVIDE: {{.Name}}
# REQUIRE: networking
# KEYWORD: shutdown
. /etc/rc.subr
name="{{.Name}}"
{{.Name}}_env="IS_DAEMON=1"
{{.Name}}_user="root"
pidfile="/var/run/${name}.pid"
command="/usr/sbin/daemon"
command_args="-P ${pidfile} -r -f {{.WorkingDirectory}}/{{.Name}}"
run_rc_command "$1"
`

547
internal/home/tls.go Normal file
View File

@@ -0,0 +1,547 @@
package home
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"reflect"
"strings"
"sync"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
var tlsWebHandlersRegistered = false
// TLSMod - TLS module object
type TLSMod struct {
certLastMod time.Time // last modification time of the certificate file
conf tlsConfigSettings
confLock sync.Mutex
status tlsConfigStatus
}
// Create TLS module
func tlsCreate(conf tlsConfigSettings) *TLSMod {
t := &TLSMod{}
t.conf = conf
if t.conf.Enabled {
if !t.load() {
// Something is not valid - return an empty TLS config
return &TLSMod{conf: tlsConfigSettings{
Enabled: conf.Enabled,
ServerName: conf.ServerName,
PortHTTPS: conf.PortHTTPS,
PortDNSOverTLS: conf.PortDNSOverTLS,
PortDNSOverQUIC: conf.PortDNSOverQUIC,
AllowUnencryptedDOH: conf.AllowUnencryptedDOH,
}}
}
t.setCertFileTime()
}
return t
}
func (t *TLSMod) load() bool {
if !tlsLoadConfig(&t.conf, &t.status) {
log.Error("failed to load TLS config: %s", t.status.WarningValidation)
return false
}
// validate current TLS config and update warnings (it could have been loaded from file)
data := validateCertificates(string(t.conf.CertificateChainData), string(t.conf.PrivateKeyData), t.conf.ServerName)
if !data.ValidPair {
log.Error("failed to validate certificate: %s", data.WarningValidation)
return false
}
t.status = data
return true
}
// Close - close module
func (t *TLSMod) Close() {
}
// WriteDiskConfig - write config
func (t *TLSMod) WriteDiskConfig(conf *tlsConfigSettings) {
t.confLock.Lock()
*conf = t.conf
t.confLock.Unlock()
}
func (t *TLSMod) setCertFileTime() {
if len(t.conf.CertificatePath) == 0 {
return
}
fi, err := os.Stat(t.conf.CertificatePath)
if err != nil {
log.Error("TLS: %s", err)
return
}
t.certLastMod = fi.ModTime().UTC()
}
// Start - start the module
func (t *TLSMod) Start() {
if !tlsWebHandlersRegistered {
tlsWebHandlersRegistered = true
t.registerWebHandlers()
}
t.confLock.Lock()
tlsConf := t.conf
t.confLock.Unlock()
Context.web.TLSConfigChanged(tlsConf)
}
// Reload - reload certificate file
func (t *TLSMod) Reload() {
t.confLock.Lock()
tlsConf := t.conf
t.confLock.Unlock()
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
return
}
fi, err := os.Stat(tlsConf.CertificatePath)
if err != nil {
log.Error("TLS: %s", err)
return
}
if fi.ModTime().UTC().Equal(t.certLastMod) {
log.Debug("TLS: certificate file isn't modified")
return
}
log.Debug("TLS: certificate file is modified")
t.confLock.Lock()
r := t.load()
t.confLock.Unlock()
if !r {
return
}
t.certLastMod = fi.ModTime().UTC()
_ = reconfigureDNSServer()
Context.web.TLSConfigChanged(tlsConf)
}
// Set certificate and private key data
func tlsLoadConfig(tls *tlsConfigSettings, status *tlsConfigStatus) bool {
tls.CertificateChainData = []byte(tls.CertificateChain)
tls.PrivateKeyData = []byte(tls.PrivateKey)
var err error
if tls.CertificatePath != "" {
if tls.CertificateChain != "" {
status.WarningValidation = "certificate data and file can't be set together"
return false
}
tls.CertificateChainData, err = ioutil.ReadFile(tls.CertificatePath)
if err != nil {
status.WarningValidation = err.Error()
return false
}
status.ValidCert = true
}
if tls.PrivateKeyPath != "" {
if tls.PrivateKey != "" {
status.WarningValidation = "private key data and file can't be set together"
return false
}
tls.PrivateKeyData, err = ioutil.ReadFile(tls.PrivateKeyPath)
if err != nil {
status.WarningValidation = err.Error()
return false
}
status.ValidKey = true
}
return true
}
type tlsConfigStatus struct {
ValidCert bool `json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates
ValidChain bool `json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA
Subject string `json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
Issuer string `json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
NotBefore time.Time `json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
NotAfter time.Time `json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
DNSNames []string `json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
// key status
ValidKey bool `json:"valid_key"` // ValidKey is true if the key is a valid private key
KeyType string `json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
// is usable? set by validator
ValidPair bool `json:"valid_pair"` // ValidPair is true if both certificate and private key are correct
// warnings
WarningValidation string `json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description
}
// field ordering is important -- yaml fields will mirror ordering from here
type tlsConfig struct {
tlsConfigSettings `json:",inline"`
tlsConfigStatus `json:",inline"`
}
func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, _ *http.Request) {
t.confLock.Lock()
data := tlsConfig{
tlsConfigSettings: t.conf,
tlsConfigStatus: t.status,
}
t.confLock.Unlock()
marshalTLS(w, data)
}
func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
setts, err := unmarshalTLS(r)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return
}
if !WebCheckPortAvailable(setts.PortHTTPS) {
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS)
return
}
status := tlsConfigStatus{}
if tlsLoadConfig(&setts, &status) {
status = validateCertificates(string(setts.CertificateChainData), string(setts.PrivateKeyData), setts.ServerName)
}
data := tlsConfig{
tlsConfigSettings: setts,
tlsConfigStatus: status,
}
marshalTLS(w, data)
}
func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
data, err := unmarshalTLS(r)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return
}
if !WebCheckPortAvailable(data.PortHTTPS) {
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
return
}
status := tlsConfigStatus{}
if !tlsLoadConfig(&data, &status) {
data2 := tlsConfig{
tlsConfigSettings: data,
tlsConfigStatus: t.status,
}
marshalTLS(w, data2)
return
}
status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName)
restartHTTPS := false
t.confLock.Lock()
if !reflect.DeepEqual(t.conf, data) {
log.Printf("tls config settings have changed, will restart HTTPS server")
restartHTTPS = true
}
// Note: don't do just `t.conf = data` because we must preserve all other members of t.conf
t.conf.Enabled = data.Enabled
t.conf.ServerName = data.ServerName
t.conf.ForceHTTPS = data.ForceHTTPS
t.conf.PortHTTPS = data.PortHTTPS
t.conf.PortDNSOverTLS = data.PortDNSOverTLS
t.conf.PortDNSOverQUIC = data.PortDNSOverQUIC
t.conf.CertificateChain = data.CertificateChain
t.conf.CertificatePath = data.CertificatePath
t.conf.CertificateChainData = data.CertificateChainData
t.conf.PrivateKey = data.PrivateKey
t.conf.PrivateKeyPath = data.PrivateKeyPath
t.conf.PrivateKeyData = data.PrivateKeyData
t.status = status
t.confLock.Unlock()
t.setCertFileTime()
onConfigModified()
err = reconfigureDNSServer()
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return
}
data2 := tlsConfig{
tlsConfigSettings: data,
tlsConfigStatus: t.status,
}
marshalTLS(w, data2)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
if restartHTTPS {
go func() {
Context.web.TLSConfigChanged(data)
}()
}
}
func verifyCertChain(data *tlsConfigStatus, certChain string, serverName string) error {
log.Tracef("TLS: got certificate: %d bytes", len(certChain))
// now do a more extended validation
var certs []*pem.Block // PEM-encoded certificates
var skippedBytes []string // skipped bytes
pemblock := []byte(certChain)
for {
var decoded *pem.Block
decoded, pemblock = pem.Decode(pemblock)
if decoded == nil {
break
}
if decoded.Type == "CERTIFICATE" {
certs = append(certs, decoded)
} else {
// ignore "this result of append is never used" warning
// nolint
skippedBytes = append(skippedBytes, decoded.Type)
}
}
var parsedCerts []*x509.Certificate
for _, cert := range certs {
parsed, err := x509.ParseCertificate(cert.Bytes)
if err != nil {
data.WarningValidation = fmt.Sprintf("Failed to parse certificate: %s", err)
return errors.New(data.WarningValidation)
}
parsedCerts = append(parsedCerts, parsed)
}
if len(parsedCerts) == 0 {
data.WarningValidation = fmt.Sprintf("You have specified an empty certificate")
return errors.New(data.WarningValidation)
}
data.ValidCert = true
// spew.Dump(parsedCerts)
opts := x509.VerifyOptions{
DNSName: serverName,
Roots: Context.tlsRoots,
}
log.Printf("number of certs - %d", len(parsedCerts))
if len(parsedCerts) > 1 {
// set up an intermediate
pool := x509.NewCertPool()
for _, cert := range parsedCerts[1:] {
log.Printf("got an intermediate cert")
pool.AddCert(cert)
}
opts.Intermediates = pool
}
// TODO: save it as a warning rather than error it out -- shouldn't be a big problem
mainCert := parsedCerts[0]
_, err := mainCert.Verify(opts)
if err != nil {
// let self-signed certs through
data.WarningValidation = fmt.Sprintf("Your certificate does not verify: %s", err)
} else {
data.ValidChain = true
}
// spew.Dump(chains)
// update status
if mainCert != nil {
notAfter := mainCert.NotAfter
data.Subject = mainCert.Subject.String()
data.Issuer = mainCert.Issuer.String()
data.NotAfter = notAfter
data.NotBefore = mainCert.NotBefore
data.DNSNames = mainCert.DNSNames
}
return nil
}
func validatePkey(data *tlsConfigStatus, pkey string) error {
// now do a more extended validation
var key *pem.Block // PEM-encoded certificates
var skippedBytes []string // skipped bytes
// go through all pem blocks, but take first valid pem block and drop the rest
pemblock := []byte(pkey)
for {
var decoded *pem.Block
decoded, pemblock = pem.Decode(pemblock)
if decoded == nil {
break
}
if decoded.Type == "PRIVATE KEY" || strings.HasSuffix(decoded.Type, " PRIVATE KEY") {
key = decoded
break
} else {
// ignore "this result of append is never used"
// nolint
skippedBytes = append(skippedBytes, decoded.Type)
}
}
if key == nil {
data.WarningValidation = "No valid keys were found"
return errors.New(data.WarningValidation)
}
// parse the decoded key
_, keytype, err := parsePrivateKey(key.Bytes)
if err != nil {
data.WarningValidation = fmt.Sprintf("Failed to parse private key: %s", err)
return errors.New(data.WarningValidation)
}
data.ValidKey = true
data.KeyType = keytype
return nil
}
// Process certificate data and its private key.
// All parameters are optional.
// On error, return partially set object
// with 'WarningValidation' field containing error description.
func validateCertificates(certChain, pkey, serverName string) tlsConfigStatus {
var data tlsConfigStatus
// check only public certificate separately from the key
if certChain != "" {
if verifyCertChain(&data, certChain, serverName) != nil {
return data
}
}
// validate private key (right now the only validation possible is just parsing it)
if pkey != "" {
if validatePkey(&data, pkey) != nil {
return data
}
}
// if both are set, validate both in unison
if pkey != "" && certChain != "" {
_, err := tls.X509KeyPair([]byte(certChain), []byte(pkey))
if err != nil {
data.WarningValidation = fmt.Sprintf("Invalid certificate or key: %s", err)
return data
}
data.ValidPair = true
}
return data
}
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys.
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
return key, "RSA", nil
}
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
switch key := key.(type) {
case *rsa.PrivateKey:
return key, "RSA", nil
case *ecdsa.PrivateKey:
return key, "ECDSA", nil
default:
return nil, "", errors.New("tls: found unknown private key type in PKCS#8 wrapping")
}
}
if key, err := x509.ParseECPrivateKey(der); err == nil {
return key, "ECDSA", nil
}
return nil, "", errors.New("tls: failed to parse private key")
}
// unmarshalTLS handles base64-encoded certificates transparently
func unmarshalTLS(r *http.Request) (tlsConfigSettings, error) {
data := tlsConfigSettings{}
err := json.NewDecoder(r.Body).Decode(&data)
if err != nil {
return data, errorx.Decorate(err, "Failed to parse new TLS config json")
}
if data.CertificateChain != "" {
certPEM, err := base64.StdEncoding.DecodeString(data.CertificateChain)
if err != nil {
return data, errorx.Decorate(err, "Failed to base64-decode certificate chain")
}
data.CertificateChain = string(certPEM)
if data.CertificatePath != "" {
return data, fmt.Errorf("certificate data and file can't be set together")
}
}
if data.PrivateKey != "" {
keyPEM, err := base64.StdEncoding.DecodeString(data.PrivateKey)
if err != nil {
return data, errorx.Decorate(err, "Failed to base64-decode private key")
}
data.PrivateKey = string(keyPEM)
if data.PrivateKeyPath != "" {
return data, fmt.Errorf("private key data and file can't be set together")
}
}
return data, nil
}
func marshalTLS(w http.ResponseWriter, data tlsConfig) {
w.Header().Set("Content-Type", "application/json")
if data.CertificateChain != "" {
encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain))
data.CertificateChain = encoded
}
if data.PrivateKey != "" {
encoded := base64.StdEncoding.EncodeToString([]byte(data.PrivateKey))
data.PrivateKey = encoded
}
err := json.NewEncoder(w).Encode(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Failed to marshal json with TLS status: %s", err)
return
}
}
// registerWebHandlers registers HTTP handlers for TLS configuration
func (t *TLSMod) registerWebHandlers() {
httpRegister("GET", "/control/tls/status", t.handleTLSStatus)
httpRegister("POST", "/control/tls/configure", t.handleTLSConfigure)
httpRegister("POST", "/control/tls/validate", t.handleTLSValidate)
}

437
internal/home/upgrade.go Normal file
View File

@@ -0,0 +1,437 @@
package home
import (
"fmt"
"os"
"path/filepath"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/crypto/bcrypt"
yaml "gopkg.in/yaml.v2"
)
const currentSchemaVersion = 7 // used for upgrading from old configs to new config
// Performs necessary upgrade operations if needed
func upgradeConfig() error {
// read a config file into an interface map, so we can manipulate values without losing any
diskConfig := map[string]interface{}{}
body, err := readConfigFile()
if err != nil {
return err
}
err = yaml.Unmarshal(body, &diskConfig)
if err != nil {
log.Printf("Couldn't parse config file: %s", err)
return err
}
schemaVersionInterface, ok := diskConfig["schema_version"]
log.Tracef("got schema version %v", schemaVersionInterface)
if !ok {
// no schema version, set it to 0
schemaVersionInterface = 0
}
schemaVersion, ok := schemaVersionInterface.(int)
if !ok {
err = fmt.Errorf("configuration file contains non-integer schema_version, abort")
log.Println(err)
return err
}
if schemaVersion == currentSchemaVersion {
// do nothing
return nil
}
return upgradeConfigSchema(schemaVersion, &diskConfig)
}
// Upgrade from oldVersion to newVersion
func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) error {
switch oldVersion {
case 0:
err := upgradeSchema0to1(diskConfig)
if err != nil {
return err
}
fallthrough
case 1:
err := upgradeSchema1to2(diskConfig)
if err != nil {
return err
}
fallthrough
case 2:
err := upgradeSchema2to3(diskConfig)
if err != nil {
return err
}
fallthrough
case 3:
err := upgradeSchema3to4(diskConfig)
if err != nil {
return err
}
fallthrough
case 4:
err := upgradeSchema4to5(diskConfig)
if err != nil {
return err
}
fallthrough
case 5:
err := upgradeSchema5to6(diskConfig)
if err != nil {
return err
}
fallthrough
case 6:
err := upgradeSchema6to7(diskConfig)
if err != nil {
return err
}
default:
err := fmt.Errorf("configuration file contains unknown schema_version, abort")
log.Println(err)
return err
}
configFile := config.getConfigFilename()
body, err := yaml.Marshal(diskConfig)
if err != nil {
log.Printf("Couldn't generate YAML file: %s", err)
return err
}
config.fileData = body
err = file.SafeWrite(configFile, body)
if err != nil {
log.Printf("Couldn't save YAML config: %s", err)
return err
}
return nil
}
// The first schema upgrade:
// No more "dnsfilter.txt", filters are now kept in data/filters/
func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", util.FuncName())
dnsFilterPath := filepath.Join(Context.workDir, "dnsfilter.txt")
if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) {
log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath)
err = os.Remove(dnsFilterPath)
if err != nil {
log.Printf("Cannot remove %s due to %s", dnsFilterPath, err)
// not fatal, move on
}
}
(*diskConfig)["schema_version"] = 1
return nil
}
// Second schema upgrade:
// coredns is now dns in config
// delete 'Corefile', since we don't use that anymore
func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", util.FuncName())
coreFilePath := filepath.Join(Context.workDir, "Corefile")
if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) {
log.Printf("Deleting %s as we don't need it anymore", coreFilePath)
err = os.Remove(coreFilePath)
if err != nil {
log.Printf("Cannot remove %s due to %s", coreFilePath, err)
// not fatal, move on
}
}
if _, ok := (*diskConfig)["dns"]; !ok {
(*diskConfig)["dns"] = (*diskConfig)["coredns"]
delete((*diskConfig), "coredns")
}
(*diskConfig)["schema_version"] = 2
return nil
}
// Third schema upgrade:
// Bootstrap DNS becomes an array
func upgradeSchema2to3(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", util.FuncName())
// Let's read dns configuration from diskConfig
dnsConfig, ok := (*diskConfig)["dns"]
if !ok {
return fmt.Errorf("no DNS configuration in config file")
}
// Convert interface{} to map[string]interface{}
newDNSConfig := make(map[string]interface{})
switch v := dnsConfig.(type) {
case map[interface{}]interface{}:
for k, v := range v {
newDNSConfig[fmt.Sprint(k)] = v
}
default:
return fmt.Errorf("DNS configuration is not a map")
}
// Replace bootstrap_dns value filed with new array contains old bootstrap_dns inside
if bootstrapDNS, ok := (newDNSConfig)["bootstrap_dns"]; ok {
newBootstrapConfig := []string{fmt.Sprint(bootstrapDNS)}
(newDNSConfig)["bootstrap_dns"] = newBootstrapConfig
(*diskConfig)["dns"] = newDNSConfig
} else {
return fmt.Errorf("no bootstrap DNS in DNS config")
}
// Bump schema version
(*diskConfig)["schema_version"] = 3
return nil
}
// Add use_global_blocked_services=true setting for existing "clients" array
func upgradeSchema3to4(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", util.FuncName())
(*diskConfig)["schema_version"] = 4
clients, ok := (*diskConfig)["clients"]
if !ok {
return nil
}
switch arr := clients.(type) {
case []interface{}:
for i := range arr {
switch c := arr[i].(type) {
case map[interface{}]interface{}:
c["use_global_blocked_services"] = true
default:
continue
}
}
default:
return nil
}
return nil
}
// Replace "auth_name", "auth_pass" string settings with an array:
// users:
// - name: "..."
// password: "..."
// ...
func upgradeSchema4to5(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", util.FuncName())
(*diskConfig)["schema_version"] = 5
name, ok := (*diskConfig)["auth_name"]
if !ok {
return nil
}
nameStr, ok := name.(string)
if !ok {
log.Fatal("Please use double quotes in your user name in \"auth_name\" and restart AdGuardHome")
return nil
}
pass, ok := (*diskConfig)["auth_pass"]
if !ok {
return nil
}
passStr, ok := pass.(string)
if !ok {
log.Fatal("Please use double quotes in your password in \"auth_pass\" and restart AdGuardHome")
return nil
}
if len(nameStr) == 0 {
return nil
}
hash, err := bcrypt.GenerateFromPassword([]byte(passStr), bcrypt.DefaultCost)
if err != nil {
log.Fatalf("Can't use password \"%s\": bcrypt.GenerateFromPassword: %s", passStr, err)
return nil
}
u := User{
Name: nameStr,
PasswordHash: string(hash),
}
users := []User{u}
(*diskConfig)["users"] = users
return nil
}
// clients:
// ...
// ip: 127.0.0.1
// mac: ...
//
// ->
//
// clients:
// ...
// ids:
// - 127.0.0.1
// - ...
func upgradeSchema5to6(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", util.FuncName())
(*diskConfig)["schema_version"] = 6
clients, ok := (*diskConfig)["clients"]
if !ok {
return nil
}
switch arr := clients.(type) {
case []interface{}:
for i := range arr {
switch c := arr[i].(type) {
case map[interface{}]interface{}:
_ip, ok := c["ip"]
ids := []string{}
if ok {
ip, ok := _ip.(string)
if !ok {
log.Fatalf("client.ip is not a string: %v", _ip)
return nil
}
if len(ip) != 0 {
ids = append(ids, ip)
}
}
_mac, ok := c["mac"]
if ok {
mac, ok := _mac.(string)
if !ok {
log.Fatalf("client.mac is not a string: %v", _mac)
return nil
}
if len(mac) != 0 {
ids = append(ids, mac)
}
}
c["ids"] = ids
default:
continue
}
}
default:
return nil
}
return nil
}
// dhcp:
// enabled: false
// interface_name: vboxnet0
// gateway_ip: 192.168.56.1
// ...
//
// ->
//
// dhcp:
// enabled: false
// interface_name: vboxnet0
// dhcpv4:
// gateway_ip: 192.168.56.1
// ...
func upgradeSchema6to7(diskConfig *map[string]interface{}) error {
log.Printf("Upgrade yaml: 6 to 7")
(*diskConfig)["schema_version"] = 7
_dhcp, ok := (*diskConfig)["dhcp"]
if !ok {
return nil
}
switch dhcp := _dhcp.(type) {
case map[interface{}]interface{}:
dhcpv4 := map[string]interface{}{}
val, ok := dhcp["gateway_ip"].(string)
if !ok {
log.Fatalf("expecting dhcp.%s to be a string", "gateway_ip")
return nil
}
dhcpv4["gateway_ip"] = val
delete(dhcp, "gateway_ip")
val, ok = dhcp["subnet_mask"].(string)
if !ok {
log.Fatalf("expecting dhcp.%s to be a string", "subnet_mask")
return nil
}
dhcpv4["subnet_mask"] = val
delete(dhcp, "subnet_mask")
val, ok = dhcp["range_start"].(string)
if !ok {
log.Fatalf("expecting dhcp.%s to be a string", "range_start")
return nil
}
dhcpv4["range_start"] = val
delete(dhcp, "range_start")
val, ok = dhcp["range_end"].(string)
if !ok {
log.Fatalf("expecting dhcp.%s to be a string", "range_end")
return nil
}
dhcpv4["range_end"] = val
delete(dhcp, "range_end")
intVal, ok := dhcp["lease_duration"].(int)
if !ok {
log.Fatalf("expecting dhcp.%s to be an integer", "lease_duration")
return nil
}
dhcpv4["lease_duration"] = intVal
delete(dhcp, "lease_duration")
intVal, ok = dhcp["icmp_timeout_msec"].(int)
if !ok {
log.Fatalf("expecting dhcp.%s to be an integer", "icmp_timeout_msec")
return nil
}
dhcpv4["icmp_timeout_msec"] = intVal
delete(dhcp, "icmp_timeout_msec")
dhcp["dhcpv4"] = dhcpv4
default:
return nil
}
return nil
}

View File

@@ -0,0 +1,230 @@
package home
import (
"fmt"
"testing"
)
func TestUpgrade1to2(t *testing.T) {
// let's create test config for 1 schema version
diskConfig := createTestDiskConfig(1)
// update config
err := upgradeSchema1to2(&diskConfig)
if err != nil {
t.Fatalf("Can't upgrade schema version from 1 to 2")
}
// ensure that schema version was bumped
compareSchemaVersion(t, diskConfig["schema_version"], 2)
// old coredns entry should be removed
_, ok := diskConfig["coredns"]
if ok {
t.Fatalf("Core DNS config was not removed after upgrade schema version from 1 to 2")
}
// pull out new dns config
dnsMap, ok := diskConfig["dns"]
if !ok {
t.Fatalf("No DNS config after upgrade schema version from 1 to 2")
}
// cast dns configurations to maps and compare them
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(1))
newDNSConfig := castInterfaceToMap(t, dnsMap)
compareConfigs(t, &oldDNSConfig, &newDNSConfig)
// exclude dns config and schema version from disk config comparison
oldExcludedEntries := []string{"coredns", "schema_version"}
newExcludedEntries := []string{"dns", "schema_version"}
oldDiskConfig := createTestDiskConfig(1)
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, oldExcludedEntries, newExcludedEntries)
}
func TestUpgrade2to3(t *testing.T) {
// let's create test config
diskConfig := createTestDiskConfig(2)
// upgrade schema from 2 to 3
err := upgradeSchema2to3(&diskConfig)
if err != nil {
t.Fatalf("Can't update schema version from 2 to 3: %s", err)
}
// check new schema version
compareSchemaVersion(t, diskConfig["schema_version"], 3)
// pull out new dns configuration
dnsMap, ok := diskConfig["dns"]
if !ok {
t.Fatalf("No dns config in new configuration")
}
// cast dns configuration to map
newDNSConfig := castInterfaceToMap(t, dnsMap)
// check if bootstrap DNS becomes an array
bootstrapDNS := newDNSConfig["bootstrap_dns"]
switch v := bootstrapDNS.(type) {
case []string:
if len(v) != 1 {
t.Fatalf("Wrong count of bootsrap DNS servers: %d", len(v))
}
if v[0] != "8.8.8.8:53" {
t.Fatalf("Bootsrap DNS server is not 8.8.8.8:53 : %s", v[0])
}
default:
t.Fatalf("Wrong type for bootsrap DNS: %T", v)
}
// exclude bootstrap DNS from DNS configs comparison
excludedEntries := []string{"bootstrap_dns"}
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(2))
compareConfigsWithoutEntries(t, &oldDNSConfig, &newDNSConfig, excludedEntries, excludedEntries)
// excluded dns config and schema version from disk config comparison
excludedEntries = []string{"dns", "schema_version"}
oldDiskConfig := createTestDiskConfig(2)
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, excludedEntries, excludedEntries)
}
func castInterfaceToMap(t *testing.T, oldConfig interface{}) (newConfig map[string]interface{}) {
newConfig = make(map[string]interface{})
switch v := oldConfig.(type) {
case map[interface{}]interface{}:
for key, value := range v {
newConfig[fmt.Sprint(key)] = value
}
case map[string]interface{}:
for key, value := range v {
newConfig[key] = value
}
default:
t.Fatalf("DNS configuration is not a map")
}
return
}
// compareConfigsWithoutEntry removes entries from configs and returns result of compareConfigs
func compareConfigsWithoutEntries(t *testing.T, oldConfig, newConfig *map[string]interface{}, oldKey, newKey []string) {
for _, k := range oldKey {
delete(*oldConfig, k)
}
for _, k := range newKey {
delete(*newConfig, k)
}
compareConfigs(t, oldConfig, newConfig)
}
// compares configs before and after schema upgrade
func compareConfigs(t *testing.T, oldConfig, newConfig *map[string]interface{}) {
if len(*oldConfig) != len(*newConfig) {
t.Fatalf("wrong config entries count! Before upgrade: %d; After upgrade: %d", len(*oldConfig), len(*oldConfig))
}
// Check old and new entries
for k, v := range *newConfig {
switch value := v.(type) {
case string:
if value != (*oldConfig)[k] {
t.Fatalf("wrong value for string %s. Before update: %s; After update: %s", k, (*oldConfig)[k], value)
}
case int:
if value != (*oldConfig)[k] {
t.Fatalf("wrong value for int %s. Before update: %d; After update: %d", k, (*oldConfig)[k], value)
}
case []string:
for i, line := range value {
if len((*oldConfig)[k].([]string)) != len(value) {
t.Fatalf("wrong array length for %s. Before update: %d; After update: %d", k, len((*oldConfig)[k].([]string)), len(value))
}
if (*oldConfig)[k].([]string)[i] != line {
t.Fatalf("wrong data for string array %s. Before update: %s; After update: %s", k, (*oldConfig)[k].([]string)[i], line)
}
}
case bool:
if v != (*oldConfig)[k].(bool) {
t.Fatalf("wrong boolean value for %s", k)
}
case []filter:
if len((*oldConfig)[k].([]filter)) != len(value) {
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filter)), len(value))
}
for i, newFilter := range value {
oldFilter := (*oldConfig)[k].([]filter)[i]
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RulesCount != newFilter.RulesCount {
t.Fatalf("old filter %s not equals new filter %s", oldFilter.Name, newFilter.Name)
}
}
default:
t.Fatalf("uknown data type for %s: %T", k, value)
}
}
}
// compareSchemaVersion check if newSchemaVersion equals schemaVersion
func compareSchemaVersion(t *testing.T, newSchemaVersion interface{}, schemaVersion int) {
switch v := newSchemaVersion.(type) {
case int:
if v != schemaVersion {
t.Fatalf("Wrong schema version in new config file")
}
default:
t.Fatalf("Schema version is not an integer after update")
}
}
func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{}) {
diskConfig = make(map[string]interface{})
diskConfig["language"] = "en"
diskConfig["filters"] = []filter{
{
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
Name: "Latvian filter",
RulesCount: 100,
},
{
URL: "https://easylist.to/easylistgermany/easylistgermany.txt",
Name: "Germany filter",
RulesCount: 200,
},
}
diskConfig["user_rules"] = []string{}
diskConfig["schema_version"] = schemaVersion
diskConfig["bind_host"] = "0.0.0.0"
diskConfig["bind_port"] = 80
diskConfig["auth_name"] = "name"
diskConfig["auth_pass"] = "pass"
dnsConfig := createTestDNSConfig(schemaVersion)
if schemaVersion > 1 {
diskConfig["dns"] = dnsConfig
} else {
diskConfig["coredns"] = dnsConfig
}
return diskConfig
}
func createTestDNSConfig(schemaVersion int) map[interface{}]interface{} {
dnsConfig := make(map[interface{}]interface{})
dnsConfig["port"] = 53
dnsConfig["blocked_response_ttl"] = 10
dnsConfig["querylog_enabled"] = true
dnsConfig["ratelimit"] = 20
dnsConfig["bootstrap_dns"] = "8.8.8.8:53"
if schemaVersion > 2 {
dnsConfig["bootstrap_dns"] = []string{"8.8.8.8:53"}
}
dnsConfig["parental_sensitivity"] = 13
dnsConfig["ratelimit_whitelist"] = []string{}
dnsConfig["upstream_dns"] = []string{"tls://1.1.1.1", "tls://1.0.0.1", "8.8.8.8"}
dnsConfig["filtering_enabled"] = true
dnsConfig["refuse_any"] = true
dnsConfig["parental_enabled"] = true
dnsConfig["bind_host"] = "0.0.0.0"
dnsConfig["protection_enabled"] = true
dnsConfig["safesearch_enabled"] = true
dnsConfig["safebrowsing_enabled"] = true
return dnsConfig
}

209
internal/home/web.go Normal file
View File

@@ -0,0 +1,209 @@
package home
import (
"context"
"crypto/tls"
golog "log"
"net"
"net/http"
"strconv"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
"github.com/NYTimes/gziphandler"
"github.com/gobuffalo/packr"
)
type WebConfig struct {
firstRun bool
BindHost string
BindPort int
PortHTTPS int
}
// HTTPSServer - HTTPS Server
type HTTPSServer struct {
server *http.Server
cond *sync.Cond
condLock sync.Mutex
shutdown bool // if TRUE, don't restart the server
enabled bool
cert tls.Certificate
}
// Web - module object
type Web struct {
conf *WebConfig
forceHTTPS bool
portHTTPS int
httpServer *http.Server // HTTP module
httpsServer HTTPSServer // HTTPS module
errLogger *golog.Logger
}
// Proxy between Go's "log" and "golibs/log"
type logWriter struct {
}
// HTTP server calls this function to log an error
func (w *logWriter) Write(p []byte) (int, error) {
log.Debug("Web: %s", string(p))
return 0, nil
}
// CreateWeb - create module
func CreateWeb(conf *WebConfig) *Web {
log.Info("Initialize web module")
w := Web{}
w.conf = conf
lw := logWriter{}
w.errLogger = golog.New(&lw, "", 0)
// Initialize and run the admin Web interface
box := packr.NewBox("../../build/static")
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
// add handlers for /install paths, we only need them when we're not configured yet
if conf.firstRun {
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
w.registerInstallHandlers()
} else {
registerControlHandlers()
}
w.httpsServer.cond = sync.NewCond(&w.httpsServer.condLock)
return &w
}
// WebCheckPortAvailable - check if port is available
// BUT: if we are already using this port, no need
func WebCheckPortAvailable(port int) bool {
alreadyRunning := false
if Context.web.httpsServer.server != nil {
alreadyRunning = true
}
if !alreadyRunning {
err := util.CheckPortAvailable(config.BindHost, port)
if err != nil {
return false
}
}
return true
}
// TLSConfigChanged - called when TLS configuration has changed
func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
log.Debug("Web: applying new TLS configuration")
web.conf.PortHTTPS = tlsConf.PortHTTPS
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
web.portHTTPS = tlsConf.PortHTTPS
enabled := tlsConf.Enabled &&
tlsConf.PortHTTPS != 0 &&
len(tlsConf.PrivateKeyData) != 0 &&
len(tlsConf.CertificateChainData) != 0
var cert tls.Certificate
var err error
if enabled {
cert, err = tls.X509KeyPair(tlsConf.CertificateChainData, tlsConf.PrivateKeyData)
if err != nil {
log.Fatal(err)
}
}
web.httpsServer.cond.L.Lock()
if web.httpsServer.server != nil {
_ = web.httpsServer.server.Shutdown(context.TODO())
}
web.httpsServer.enabled = enabled
web.httpsServer.cert = cert
web.httpsServer.cond.Broadcast()
web.httpsServer.cond.L.Unlock()
}
// Start - start serving HTTP requests
func (web *Web) Start() {
// for https, we have a separate goroutine loop
go web.tlsServerLoop()
// this loop is used as an ability to change listening host and/or port
for !web.httpsServer.shutdown {
printHTTPAddresses("http")
// we need to have new instance, because after Shutdown() the Server is not usable
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BindPort))
web.httpServer = &http.Server{
ErrorLog: web.errLogger,
Addr: address,
}
err := web.httpServer.ListenAndServe()
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
}
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
}
}
// Close - stop HTTP server, possibly waiting for all active connections to be closed
func (web *Web) Close() {
log.Info("Stopping HTTP server...")
web.httpsServer.cond.L.Lock()
web.httpsServer.shutdown = true
web.httpsServer.cond.L.Unlock()
if web.httpsServer.server != nil {
_ = web.httpsServer.server.Shutdown(context.TODO())
}
if web.httpServer != nil {
_ = web.httpServer.Shutdown(context.TODO())
}
log.Info("Stopped HTTP server")
}
func (web *Web) tlsServerLoop() {
for {
web.httpsServer.cond.L.Lock()
if web.httpsServer.shutdown {
web.httpsServer.cond.L.Unlock()
break
}
// this mechanism doesn't let us through until all conditions are met
for !web.httpsServer.enabled { // sleep until necessary data is supplied
web.httpsServer.cond.Wait()
if web.httpsServer.shutdown {
web.httpsServer.cond.L.Unlock()
return
}
}
web.httpsServer.cond.L.Unlock()
// prepare HTTPS server
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.PortHTTPS))
web.httpsServer.server = &http.Server{
ErrorLog: web.errLogger,
Addr: address,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert},
MinVersion: tls.VersionTLS12,
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCiphers,
},
}
printHTTPAddresses("https")
err := web.httpsServer.server.ListenAndServeTLS("", "")
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
}
}
}

237
internal/home/whois.go Normal file
View File

@@ -0,0 +1,237 @@
package home
import (
"context"
"encoding/binary"
"fmt"
"io/ioutil"
"net"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
)
const (
defaultServer = "whois.arin.net"
defaultPort = "43"
maxValueLength = 250
whoisTTL = 1 * 60 * 60 // 1 hour
)
// Whois - module context
type Whois struct {
clients *clientsContainer
ipChan chan string
timeoutMsec uint
// Contains IP addresses of clients
// An active IP address is resolved once again after it expires.
// If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP.
ipAddrs cache.Cache
}
// Create module context
func initWhois(clients *clientsContainer) *Whois {
w := Whois{}
w.timeoutMsec = 5000
w.clients = clients
cconf := cache.Config{}
cconf.EnableLRU = true
cconf.MaxCount = 10000
w.ipAddrs = cache.New(cconf)
w.ipChan = make(chan string, 255)
go w.workerLoop()
return &w
}
// If the value is too large - cut it and append "..."
func trimValue(s string) string {
if len(s) <= maxValueLength {
return s
}
return s[:maxValueLength-3] + "..."
}
// Parse plain-text data from the response
func whoisParse(data string) map[string]string {
m := map[string]string{}
descr := ""
netname := ""
for len(data) != 0 {
ln := util.SplitNext(&data, '\n')
if len(ln) == 0 || ln[0] == '#' || ln[0] == '%' {
continue
}
kv := strings.SplitN(ln, ":", 2)
if len(kv) != 2 {
continue
}
k := strings.TrimSpace(kv[0])
k = strings.ToLower(k)
v := strings.TrimSpace(kv[1])
switch k {
case "org-name":
m["orgname"] = trimValue(v)
case "orgname":
fallthrough
case "city":
fallthrough
case "country":
m[k] = trimValue(v)
case "descr":
if len(descr) == 0 {
descr = v
}
case "netname":
netname = v
case "whois": // "whois: whois.arin.net"
m["whois"] = v
case "referralserver": // "ReferralServer: whois://whois.ripe.net"
if strings.HasPrefix(v, "whois://") {
m["whois"] = v[len("whois://"):]
}
}
}
// descr or netname -> orgname
_, ok := m["orgname"]
if !ok && len(descr) != 0 {
m["orgname"] = trimValue(descr)
} else if !ok && len(netname) != 0 {
m["orgname"] = trimValue(netname)
}
return m
}
// Send request to a server and receive the response
func (w *Whois) query(target string, serverAddr string) (string, error) {
addr, _, _ := net.SplitHostPort(serverAddr)
if addr == "whois.arin.net" {
target = "n + " + target
}
conn, err := customDialContext(context.TODO(), "tcp", serverAddr)
if err != nil {
return "", err
}
defer conn.Close()
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
_, err = conn.Write([]byte(target + "\r\n"))
if err != nil {
return "", err
}
data, err := ioutil.ReadAll(conn)
if err != nil {
return "", err
}
return string(data), nil
}
// Query WHOIS servers (handle redirects)
func (w *Whois) queryAll(target string) (string, error) {
server := net.JoinHostPort(defaultServer, defaultPort)
const maxRedirects = 5
for i := 0; i != maxRedirects; i++ {
resp, err := w.query(target, server)
if err != nil {
return "", err
}
log.Debug("Whois: received response (%d bytes) from %s IP:%s", len(resp), server, target)
m := whoisParse(resp)
redir, ok := m["whois"]
if !ok {
return resp, nil
}
redir = strings.ToLower(redir)
_, _, err = net.SplitHostPort(redir)
if err != nil {
server = net.JoinHostPort(redir, defaultPort)
} else {
server = redir
}
log.Debug("Whois: redirected to %s IP:%s", redir, target)
}
return "", fmt.Errorf("Whois: redirect loop")
}
// Request WHOIS information
func (w *Whois) process(ip string) [][]string {
data := [][]string{}
resp, err := w.queryAll(ip)
if err != nil {
log.Debug("Whois: error: %s IP:%s", err, ip)
return data
}
log.Debug("Whois: IP:%s response: %d bytes", ip, len(resp))
m := whoisParse(resp)
keys := []string{"orgname", "country", "city"}
for _, k := range keys {
v, found := m[k]
if !found {
continue
}
pair := []string{k, v}
data = append(data, pair)
}
return data
}
// Begin - begin requesting WHOIS info
func (w *Whois) Begin(ip string) {
now := uint64(time.Now().Unix())
expire := w.ipAddrs.Get([]byte(ip))
if len(expire) != 0 {
exp := binary.BigEndian.Uint64(expire)
if exp > now {
return
}
// TTL expired
}
expire = make([]byte, 8)
binary.BigEndian.PutUint64(expire, now+whoisTTL)
_ = w.ipAddrs.Set([]byte(ip), expire)
log.Debug("Whois: adding %s", ip)
select {
case w.ipChan <- ip:
//
default:
log.Debug("Whois: queue is full")
}
}
// Get IP address from channel; get WHOIS info; associate info with a client
func (w *Whois) workerLoop() {
for {
var ip string
ip = <-w.ipChan
info := w.process(ip)
if len(info) == 0 {
continue
}
w.clients.SetWhoisInfo(ip, info)
}
}

View File

@@ -0,0 +1,28 @@
package home
import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/stretchr/testify/assert"
)
func prepareTestDNSServer() error {
config.DNS.Port = 1234
Context.dnsServer = dnsforward.NewServer(dnsforward.DNSCreateParams{})
conf := &dnsforward.ServerConfig{}
conf.UpstreamDNS = []string{"8.8.8.8"}
return Context.dnsServer.Prepare(conf)
}
func TestWhois(t *testing.T) {
assert.Nil(t, prepareTestDNSServer())
w := Whois{timeoutMsec: 5000}
resp, err := w.queryAll("8.8.8.8")
assert.Nil(t, err)
m := whoisParse(resp)
assert.Equal(t, "Google LLC", m["orgname"])
assert.Equal(t, "US", m["country"])
assert.Equal(t, "Mountain View", m["city"])
}