Pull request: * all: move internal Go packages to internal/
Merge in DNS/adguard-home from 2234-move-to-internal to master Squashed commit of the following: commit d26a288cabeac86f9483fab307677b1027c78524 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Oct 30 12:44:18 2020 +0300 * all: move internal Go packages to internal/ Closes #2234.
This commit is contained in:
536
internal/home/auth.go
Normal file
536
internal/home/auth.go
Normal file
@@ -0,0 +1,536 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"go.etcd.io/bbolt"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const cookieTTL = 365 * 24 // in hours
|
||||
const sessionCookieName = "agh_session"
|
||||
|
||||
type session struct {
|
||||
userName string
|
||||
expire uint32 // expiration time (in seconds)
|
||||
}
|
||||
|
||||
/*
|
||||
expire byte[4]
|
||||
name_len byte[2]
|
||||
name byte[]
|
||||
*/
|
||||
func (s *session) serialize() []byte {
|
||||
var data []byte
|
||||
data = make([]byte, 4+2+len(s.userName))
|
||||
binary.BigEndian.PutUint32(data[0:4], s.expire)
|
||||
binary.BigEndian.PutUint16(data[4:6], uint16(len(s.userName)))
|
||||
copy(data[6:], []byte(s.userName))
|
||||
return data
|
||||
}
|
||||
|
||||
func (s *session) deserialize(data []byte) bool {
|
||||
if len(data) < 4+2 {
|
||||
return false
|
||||
}
|
||||
s.expire = binary.BigEndian.Uint32(data[0:4])
|
||||
nameLen := binary.BigEndian.Uint16(data[4:6])
|
||||
data = data[6:]
|
||||
|
||||
if len(data) < int(nameLen) {
|
||||
return false
|
||||
}
|
||||
s.userName = string(data)
|
||||
return true
|
||||
}
|
||||
|
||||
// Auth - global object
|
||||
type Auth struct {
|
||||
db *bbolt.DB
|
||||
sessions map[string]*session // session name -> session data
|
||||
lock sync.Mutex
|
||||
users []User
|
||||
sessionTTL uint32 // in seconds
|
||||
}
|
||||
|
||||
// User object
|
||||
type User struct {
|
||||
Name string `yaml:"name"`
|
||||
PasswordHash string `yaml:"password"` // bcrypt hash
|
||||
}
|
||||
|
||||
// InitAuth - create a global object
|
||||
func InitAuth(dbFilename string, users []User, sessionTTL uint32) *Auth {
|
||||
log.Info("Initializing auth module: %s", dbFilename)
|
||||
|
||||
a := Auth{}
|
||||
a.sessionTTL = sessionTTL
|
||||
a.sessions = make(map[string]*session)
|
||||
rand.Seed(time.Now().UTC().Unix())
|
||||
var err error
|
||||
a.db, err = bbolt.Open(dbFilename, 0644, nil)
|
||||
if err != nil {
|
||||
log.Error("Auth: open DB: %s: %s", dbFilename, err)
|
||||
if err.Error() == "invalid argument" {
|
||||
log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/internal/wiki/Getting-Started#limitations")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
a.loadSessions()
|
||||
a.users = users
|
||||
log.Info("Auth: initialized. users:%d sessions:%d", len(a.users), len(a.sessions))
|
||||
return &a
|
||||
}
|
||||
|
||||
// Close - close module
|
||||
func (a *Auth) Close() {
|
||||
_ = a.db.Close()
|
||||
}
|
||||
|
||||
func bucketName() []byte {
|
||||
return []byte("sessions-2")
|
||||
}
|
||||
|
||||
// load sessions from file, remove expired sessions
|
||||
func (a *Auth) loadSessions() {
|
||||
tx, err := a.db.Begin(true)
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.Begin: %s", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt := tx.Bucket(bucketName())
|
||||
if bkt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
removed := 0
|
||||
|
||||
if tx.Bucket([]byte("sessions")) != nil {
|
||||
_ = tx.DeleteBucket([]byte("sessions"))
|
||||
removed = 1
|
||||
}
|
||||
|
||||
now := uint32(time.Now().UTC().Unix())
|
||||
forEach := func(k, v []byte) error {
|
||||
s := session{}
|
||||
if !s.deserialize(v) || s.expire <= now {
|
||||
err = bkt.Delete(k)
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.Delete: %s", err)
|
||||
} else {
|
||||
removed++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
a.sessions[hex.EncodeToString(k)] = &s
|
||||
return nil
|
||||
}
|
||||
_ = bkt.ForEach(forEach)
|
||||
if removed != 0 {
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Error("bolt.Commit(): %s", err)
|
||||
}
|
||||
}
|
||||
log.Debug("Auth: loaded %d sessions from DB (removed %d expired)", len(a.sessions), removed)
|
||||
}
|
||||
|
||||
// store session data in file
|
||||
func (a *Auth) addSession(data []byte, s *session) {
|
||||
name := hex.EncodeToString(data)
|
||||
a.lock.Lock()
|
||||
a.sessions[name] = s
|
||||
a.lock.Unlock()
|
||||
if a.storeSession(data, s) {
|
||||
log.Debug("Auth: created session %s: expire=%d", name, s.expire)
|
||||
}
|
||||
}
|
||||
|
||||
// store session data in file
|
||||
func (a *Auth) storeSession(data []byte, s *session) bool {
|
||||
tx, err := a.db.Begin(true)
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.Begin: %s", err)
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt, err := tx.CreateBucketIfNotExists(bucketName())
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.CreateBucketIfNotExists: %s", err)
|
||||
return false
|
||||
}
|
||||
err = bkt.Put(data, s.serialize())
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.Put: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.Commit: %s", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// remove session from file
|
||||
func (a *Auth) removeSession(sess []byte) {
|
||||
tx, err := a.db.Begin(true)
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.Begin: %s", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt := tx.Bucket(bucketName())
|
||||
if bkt == nil {
|
||||
log.Error("Auth: bbolt.Bucket")
|
||||
return
|
||||
}
|
||||
err = bkt.Delete(sess)
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.Put: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.Commit: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("Auth: removed session from DB")
|
||||
}
|
||||
|
||||
// CheckSession - check if session is valid
|
||||
// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired
|
||||
func (a *Auth) CheckSession(sess string) int {
|
||||
now := uint32(time.Now().UTC().Unix())
|
||||
update := false
|
||||
|
||||
a.lock.Lock()
|
||||
s, ok := a.sessions[sess]
|
||||
if !ok {
|
||||
a.lock.Unlock()
|
||||
return -1
|
||||
}
|
||||
if s.expire <= now {
|
||||
delete(a.sessions, sess)
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.removeSession(key)
|
||||
a.lock.Unlock()
|
||||
return 1
|
||||
}
|
||||
|
||||
newExpire := now + a.sessionTTL
|
||||
if s.expire/(24*60*60) != newExpire/(24*60*60) {
|
||||
// update expiration time once a day
|
||||
update = true
|
||||
s.expire = newExpire
|
||||
}
|
||||
|
||||
a.lock.Unlock()
|
||||
|
||||
if update {
|
||||
key, _ := hex.DecodeString(sess)
|
||||
if a.storeSession(key, s) {
|
||||
log.Debug("Auth: updated session %s: expire=%d", sess, s.expire)
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// RemoveSession - remove session
|
||||
func (a *Auth) RemoveSession(sess string) {
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.lock.Lock()
|
||||
delete(a.sessions, sess)
|
||||
a.lock.Unlock()
|
||||
a.removeSession(key)
|
||||
}
|
||||
|
||||
type loginJSON struct {
|
||||
Name string `json:"name"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func getSession(u *User) []byte {
|
||||
// the developers don't currently believe that using a
|
||||
// non-cryptographic RNG for the session hash salt is
|
||||
// insecure
|
||||
salt := rand.Uint32() //nolint:gosec
|
||||
d := []byte(fmt.Sprintf("%d%s%s", salt, u.Name, u.PasswordHash))
|
||||
hash := sha256.Sum256(d)
|
||||
return hash[:]
|
||||
}
|
||||
|
||||
func (a *Auth) httpCookie(req loginJSON) string {
|
||||
u := a.UserFind(req.Name, req.Password)
|
||||
if len(u.Name) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
sess := getSession(&u)
|
||||
|
||||
now := time.Now().UTC()
|
||||
expire := now.Add(cookieTTL * time.Hour)
|
||||
expstr := expire.Format(time.RFC1123)
|
||||
expstr = expstr[:len(expstr)-len("UTC")] // "UTC" -> "GMT"
|
||||
expstr += "GMT"
|
||||
|
||||
s := session{}
|
||||
s.userName = u.Name
|
||||
s.expire = uint32(now.Unix()) + a.sessionTTL
|
||||
a.addSession(sess, &s)
|
||||
|
||||
return fmt.Sprintf("%s=%s; Path=/; HttpOnly; Expires=%s",
|
||||
sessionCookieName, hex.EncodeToString(sess), expstr)
|
||||
}
|
||||
|
||||
func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
req := loginJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
cookie := Context.auth.httpCookie(req)
|
||||
if len(cookie) == 0 {
|
||||
log.Info("Auth: invalid user name or password: name='%s'", req.Name)
|
||||
time.Sleep(1 * time.Second)
|
||||
http.Error(w, "invalid user name or password", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Set-Cookie", cookie)
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, proxy-revalidate")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
w.Header().Set("Expires", "0")
|
||||
|
||||
returnOK(w)
|
||||
}
|
||||
|
||||
func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie := r.Header.Get("Cookie")
|
||||
sess := parseCookie(cookie)
|
||||
|
||||
Context.auth.RemoveSession(sess)
|
||||
|
||||
w.Header().Set("Location", "/login.html")
|
||||
|
||||
s := fmt.Sprintf("%s=; Path=/; HttpOnly; Expires=Thu, 01 Jan 1970 00:00:00 GMT",
|
||||
sessionCookieName)
|
||||
w.Header().Set("Set-Cookie", s)
|
||||
|
||||
w.WriteHeader(http.StatusFound)
|
||||
}
|
||||
|
||||
// RegisterAuthHandlers - register handlers
|
||||
func RegisterAuthHandlers() {
|
||||
http.Handle("/control/login", postInstallHandler(ensureHandler("POST", handleLogin)))
|
||||
httpRegister("GET", "/control/logout", handleLogout)
|
||||
}
|
||||
|
||||
func parseCookie(cookie string) string {
|
||||
pairs := strings.Split(cookie, ";")
|
||||
for _, pair := range pairs {
|
||||
pair = strings.TrimSpace(pair)
|
||||
kv := strings.SplitN(pair, "=", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
if kv[0] == sessionCookieName {
|
||||
return kv[1]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// nolint(gocyclo)
|
||||
func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if r.URL.Path == "/login.html" {
|
||||
// redirect to dashboard if already authenticated
|
||||
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if authRequired && err == nil {
|
||||
r := Context.auth.CheckSession(cookie.Value)
|
||||
if r == 0 {
|
||||
w.Header().Set("Location", "/")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
return
|
||||
} else if r < 0 {
|
||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||
}
|
||||
}
|
||||
|
||||
} else if strings.HasPrefix(r.URL.Path, "/assets/") ||
|
||||
strings.HasPrefix(r.URL.Path, "/login.") {
|
||||
// process as usual
|
||||
// no additional auth requirements
|
||||
} else if Context.auth != nil && Context.auth.AuthRequired() {
|
||||
// redirect to login page if not authenticated
|
||||
ok := false
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
|
||||
if glProcessCookie(r) {
|
||||
log.Debug("Auth: authentification was handled by GL-Inet submodule")
|
||||
ok = true
|
||||
|
||||
} else if err == nil {
|
||||
r := Context.auth.CheckSession(cookie.Value)
|
||||
if r == 0 {
|
||||
ok = true
|
||||
} else if r < 0 {
|
||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||
}
|
||||
} else {
|
||||
// there's no Cookie, check Basic authentication
|
||||
user, pass, ok2 := r.BasicAuth()
|
||||
if ok2 {
|
||||
u := Context.auth.UserFind(user, pass)
|
||||
if len(u.Name) != 0 {
|
||||
ok = true
|
||||
} else {
|
||||
log.Info("Auth: invalid Basic Authorization value")
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
|
||||
if glProcessRedirect(w, r) {
|
||||
log.Debug("Auth: redirected to login page by GL-Inet submodule")
|
||||
|
||||
} else {
|
||||
w.Header().Set("Location", "/login.html")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
}
|
||||
} else {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte("Forbidden"))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
handler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
type authHandler struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
optionalAuth(a.handler.ServeHTTP)(w, r)
|
||||
}
|
||||
|
||||
func optionalAuthHandler(handler http.Handler) http.Handler {
|
||||
return &authHandler{handler}
|
||||
}
|
||||
|
||||
// UserAdd - add new user
|
||||
func (a *Auth) UserAdd(u *User, password string) {
|
||||
if len(password) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
log.Error("bcrypt.GenerateFromPassword: %s", err)
|
||||
return
|
||||
}
|
||||
u.PasswordHash = string(hash)
|
||||
|
||||
a.lock.Lock()
|
||||
a.users = append(a.users, *u)
|
||||
a.lock.Unlock()
|
||||
|
||||
log.Debug("Auth: added user: %s", u.Name)
|
||||
}
|
||||
|
||||
// UserFind - find a user
|
||||
func (a *Auth) UserFind(login string, password string) User {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
for _, u := range a.users {
|
||||
if u.Name == login &&
|
||||
bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil {
|
||||
return u
|
||||
}
|
||||
}
|
||||
return User{}
|
||||
}
|
||||
|
||||
// GetCurrentUser - get the current user
|
||||
func (a *Auth) GetCurrentUser(r *http.Request) User {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// there's no Cookie, check Basic authentication
|
||||
user, pass, ok := r.BasicAuth()
|
||||
if ok {
|
||||
u := Context.auth.UserFind(user, pass)
|
||||
return u
|
||||
}
|
||||
return User{}
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
s, ok := a.sessions[cookie.Value]
|
||||
if !ok {
|
||||
a.lock.Unlock()
|
||||
return User{}
|
||||
}
|
||||
for _, u := range a.users {
|
||||
if u.Name == s.userName {
|
||||
a.lock.Unlock()
|
||||
return u
|
||||
}
|
||||
}
|
||||
a.lock.Unlock()
|
||||
return User{}
|
||||
}
|
||||
|
||||
// GetUsers - get users
|
||||
func (a *Auth) GetUsers() []User {
|
||||
a.lock.Lock()
|
||||
users := a.users
|
||||
a.lock.Unlock()
|
||||
return users
|
||||
}
|
||||
|
||||
// AuthRequired - if authentication is required
|
||||
func (a *Auth) AuthRequired() bool {
|
||||
if GLMode {
|
||||
return true
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
r := (len(a.users) != 0)
|
||||
a.lock.Unlock()
|
||||
return r
|
||||
}
|
||||
102
internal/home/auth_glinet.go
Normal file
102
internal/home/auth_glinet.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// GLMode - enable GL-Inet compatibility mode
|
||||
var GLMode bool
|
||||
|
||||
var glFilePrefix = "/tmp/gl_token_"
|
||||
|
||||
const glTokenTimeoutSeconds = 3600
|
||||
const glCookieName = "Admin-Token"
|
||||
|
||||
func glProcessRedirect(w http.ResponseWriter, r *http.Request) bool {
|
||||
if !GLMode {
|
||||
return false
|
||||
}
|
||||
// redirect to gl-inet login
|
||||
host, _, _ := net.SplitHostPort(r.Host)
|
||||
url := "http://" + host
|
||||
log.Debug("Auth: redirecting to %s", url)
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
return true
|
||||
}
|
||||
|
||||
func glProcessCookie(r *http.Request) bool {
|
||||
if !GLMode {
|
||||
return false
|
||||
}
|
||||
|
||||
glCookie, glerr := r.Cookie(glCookieName)
|
||||
if glerr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debug("Auth: GL cookie value: %s", glCookie.Value)
|
||||
if glCheckToken(glCookie.Value) {
|
||||
return true
|
||||
}
|
||||
log.Info("Auth: invalid GL cookie value: %s", glCookie)
|
||||
return false
|
||||
}
|
||||
|
||||
func glCheckToken(sess string) bool {
|
||||
tokenName := glFilePrefix + sess
|
||||
_, err := os.Stat(tokenName)
|
||||
if err != nil {
|
||||
log.Error("os.Stat: %s", err)
|
||||
return false
|
||||
}
|
||||
tokenDate := glGetTokenDate(tokenName)
|
||||
now := uint32(time.Now().UTC().Unix())
|
||||
return now <= (tokenDate + glTokenTimeoutSeconds)
|
||||
}
|
||||
|
||||
func archIsLittleEndian() bool {
|
||||
var i int32 = 0x01020304
|
||||
u := unsafe.Pointer(&i)
|
||||
pb := (*byte)(u)
|
||||
b := *pb
|
||||
return (b == 0x04)
|
||||
}
|
||||
|
||||
func glGetTokenDate(file string) uint32 {
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
log.Error("os.Open: %s", err)
|
||||
return 0
|
||||
}
|
||||
var dateToken uint32
|
||||
bs, err := ioutil.ReadAll(f)
|
||||
if err != nil {
|
||||
log.Error("ioutil.ReadAll: %s", err)
|
||||
return 0
|
||||
}
|
||||
buf := bytes.NewBuffer(bs)
|
||||
|
||||
if archIsLittleEndian() {
|
||||
err := binary.Read(buf, binary.LittleEndian, &dateToken)
|
||||
if err != nil {
|
||||
log.Error("binary.Read: %s", err)
|
||||
return 0
|
||||
}
|
||||
} else {
|
||||
err := binary.Read(buf, binary.BigEndian, &dateToken)
|
||||
if err != nil {
|
||||
log.Error("binary.Read: %s", err)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
return dateToken
|
||||
}
|
||||
43
internal/home/auth_glinet_test.go
Normal file
43
internal/home/auth_glinet_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAuthGL(t *testing.T) {
|
||||
dir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
|
||||
GLMode = true
|
||||
glFilePrefix = dir + "/gl_token_"
|
||||
|
||||
tval := uint32(1)
|
||||
data := make([]byte, 4)
|
||||
if archIsLittleEndian() {
|
||||
binary.LittleEndian.PutUint32(data, tval)
|
||||
} else {
|
||||
binary.BigEndian.PutUint32(data, tval)
|
||||
}
|
||||
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0644))
|
||||
assert.False(t, glCheckToken("test"))
|
||||
|
||||
tval = uint32(time.Now().UTC().Unix() + 60)
|
||||
data = make([]byte, 4)
|
||||
if archIsLittleEndian() {
|
||||
binary.LittleEndian.PutUint32(data, tval)
|
||||
} else {
|
||||
binary.BigEndian.PutUint32(data, tval)
|
||||
}
|
||||
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0644))
|
||||
r, _ := http.NewRequest("GET", "http://localhost/", nil)
|
||||
r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"})
|
||||
assert.True(t, glProcessCookie(r))
|
||||
GLMode = false
|
||||
}
|
||||
177
internal/home/auth_test.go
Normal file
177
internal/home/auth_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func prepareTestDir() string {
|
||||
const dir = "./agh-test"
|
||||
_ = os.RemoveAll(dir)
|
||||
_ = os.MkdirAll(dir, 0755)
|
||||
return dir
|
||||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
dir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
fn := filepath.Join(dir, "sessions.db")
|
||||
|
||||
users := []User{
|
||||
User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
|
||||
}
|
||||
a := InitAuth(fn, nil, 60)
|
||||
s := session{}
|
||||
|
||||
user := User{Name: "name"}
|
||||
a.UserAdd(&user, "password")
|
||||
|
||||
assert.True(t, a.CheckSession("notfound") == -1)
|
||||
a.RemoveSession("notfound")
|
||||
|
||||
sess := getSession(&users[0])
|
||||
sessStr := hex.EncodeToString(sess)
|
||||
|
||||
now := time.Now().UTC().Unix()
|
||||
// check expiration
|
||||
s.expire = uint32(now)
|
||||
a.addSession(sess, &s)
|
||||
assert.True(t, a.CheckSession(sessStr) == 1)
|
||||
|
||||
// add session with TTL = 2 sec
|
||||
s = session{}
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.addSession(sess, &s)
|
||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
||||
|
||||
a.Close()
|
||||
|
||||
// load saved session
|
||||
a = InitAuth(fn, users, 60)
|
||||
|
||||
// the session is still alive
|
||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
||||
// reset our expiration time because CheckSession() has just updated it
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.storeSession(sess, &s)
|
||||
a.Close()
|
||||
|
||||
u := a.UserFind("name", "password")
|
||||
assert.True(t, len(u.Name) != 0)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// load and remove expired sessions
|
||||
a = InitAuth(fn, users, 60)
|
||||
assert.True(t, a.CheckSession(sessStr) == -1)
|
||||
|
||||
a.Close()
|
||||
os.Remove(fn)
|
||||
}
|
||||
|
||||
// implements http.ResponseWriter
|
||||
type testResponseWriter struct {
|
||||
hdr http.Header
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (w *testResponseWriter) Header() http.Header {
|
||||
return w.hdr
|
||||
}
|
||||
func (w *testResponseWriter) Write([]byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (w *testResponseWriter) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
}
|
||||
|
||||
func TestAuthHTTP(t *testing.T) {
|
||||
dir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
fn := filepath.Join(dir, "sessions.db")
|
||||
|
||||
users := []User{
|
||||
User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
|
||||
}
|
||||
Context.auth = InitAuth(fn, users, 60)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
}
|
||||
handler2 := optionalAuth(handler)
|
||||
w := testResponseWriter{}
|
||||
w.hdr = make(http.Header)
|
||||
r := http.Request{}
|
||||
r.Header = make(http.Header)
|
||||
r.Method = "GET"
|
||||
|
||||
// get / - we're redirected to login page
|
||||
r.URL = &url.URL{Path: "/"}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, w.statusCode == http.StatusFound)
|
||||
assert.True(t, w.hdr.Get("Location") != "")
|
||||
assert.True(t, !handlerCalled)
|
||||
|
||||
// go to login page
|
||||
loginURL := w.hdr.Get("Location")
|
||||
r.URL = &url.URL{Path: loginURL}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, handlerCalled)
|
||||
|
||||
// perform login
|
||||
cookie := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"})
|
||||
assert.True(t, cookie != "")
|
||||
|
||||
// get /
|
||||
handler2 = optionalAuth(handler)
|
||||
w.hdr = make(http.Header)
|
||||
r.Header.Set("Cookie", cookie)
|
||||
r.URL = &url.URL{Path: "/"}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, handlerCalled)
|
||||
r.Header.Del("Cookie")
|
||||
|
||||
// get / with basic auth
|
||||
handler2 = optionalAuth(handler)
|
||||
w.hdr = make(http.Header)
|
||||
r.URL = &url.URL{Path: "/"}
|
||||
r.SetBasicAuth("name", "password")
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, handlerCalled)
|
||||
r.Header.Del("Authorization")
|
||||
|
||||
// get login page with a valid cookie - we're redirected to /
|
||||
handler2 = optionalAuth(handler)
|
||||
w.hdr = make(http.Header)
|
||||
r.Header.Set("Cookie", cookie)
|
||||
r.URL = &url.URL{Path: loginURL}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, w.hdr.Get("Location") != "")
|
||||
assert.True(t, !handlerCalled)
|
||||
r.Header.Del("Cookie")
|
||||
|
||||
// get login page with an invalid cookie
|
||||
handler2 = optionalAuth(handler)
|
||||
w.hdr = make(http.Header)
|
||||
r.Header.Set("Cookie", "bad")
|
||||
r.URL = &url.URL{Path: loginURL}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, handlerCalled)
|
||||
r.Header.Del("Cookie")
|
||||
|
||||
Context.auth.Close()
|
||||
}
|
||||
704
internal/home/clients.go
Normal file
704
internal/home/clients.go
Normal file
@@ -0,0 +1,704 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
clientsUpdatePeriod = 10 * time.Minute
|
||||
)
|
||||
|
||||
var webHandlersRegistered = false
|
||||
|
||||
// Client information
|
||||
type Client struct {
|
||||
IDs []string
|
||||
Tags []string
|
||||
Name string
|
||||
UseOwnSettings bool // false: use global settings
|
||||
FilteringEnabled bool
|
||||
SafeSearchEnabled bool
|
||||
SafeBrowsingEnabled bool
|
||||
ParentalEnabled bool
|
||||
|
||||
UseOwnBlockedServices bool // false: use global settings
|
||||
BlockedServices []string
|
||||
|
||||
Upstreams []string // list of upstream servers to be used for the client's requests
|
||||
|
||||
// Custom upstream config for this client
|
||||
// nil: not yet initialized
|
||||
// not nil, but empty: initialized, no good upstreams
|
||||
// not nil, not empty: Upstreams ready to be used
|
||||
upstreamConfig *proxy.UpstreamConfig
|
||||
}
|
||||
|
||||
type clientSource uint
|
||||
|
||||
// Client sources
|
||||
const (
|
||||
// Priority: etc/hosts > DHCP > ARP > rDNS > WHOIS
|
||||
ClientSourceWHOIS clientSource = iota // from WHOIS
|
||||
ClientSourceRDNS // from rDNS
|
||||
ClientSourceDHCP // from DHCP
|
||||
ClientSourceARP // from 'arp -a'
|
||||
ClientSourceHostsFile // from /etc/hosts
|
||||
)
|
||||
|
||||
// ClientHost information
|
||||
type ClientHost struct {
|
||||
Host string
|
||||
Source clientSource
|
||||
WhoisInfo [][]string // [[key,value], ...]
|
||||
}
|
||||
|
||||
type clientsContainer struct {
|
||||
list map[string]*Client // name -> client
|
||||
idIndex map[string]*Client // IP -> client
|
||||
ipHost map[string]*ClientHost // IP -> Hostname
|
||||
lock sync.Mutex
|
||||
|
||||
allTags map[string]bool
|
||||
|
||||
// dhcpServer is used for looking up clients IP addresses by MAC addresses
|
||||
dhcpServer *dhcpd.Server
|
||||
|
||||
// dnsServer is used for checking clients IP status access list status
|
||||
dnsServer *dnsforward.Server
|
||||
|
||||
autoHosts *util.AutoHosts // get entries from system hosts-files
|
||||
|
||||
testing bool // if TRUE, this object is used for internal tests
|
||||
}
|
||||
|
||||
// Init initializes clients container
|
||||
// dhcpServer: optional
|
||||
// Note: this function must be called only once
|
||||
func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.Server, autoHosts *util.AutoHosts) {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
}
|
||||
clients.list = make(map[string]*Client)
|
||||
clients.idIndex = make(map[string]*Client)
|
||||
clients.ipHost = make(map[string]*ClientHost)
|
||||
|
||||
clients.allTags = make(map[string]bool)
|
||||
for _, t := range clientTags {
|
||||
clients.allTags[t] = false
|
||||
}
|
||||
|
||||
clients.dhcpServer = dhcpServer
|
||||
clients.autoHosts = autoHosts
|
||||
clients.addFromConfig(objects)
|
||||
|
||||
if !clients.testing {
|
||||
clients.addFromDHCP()
|
||||
if clients.dhcpServer != nil {
|
||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||
}
|
||||
clients.autoHosts.SetOnChanged(clients.onHostsChanged)
|
||||
}
|
||||
}
|
||||
|
||||
// Start - start the module
|
||||
func (clients *clientsContainer) Start() {
|
||||
if !clients.testing {
|
||||
if !webHandlersRegistered {
|
||||
webHandlersRegistered = true
|
||||
clients.registerWebHandlers()
|
||||
}
|
||||
go clients.periodicUpdate()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Reload - reload auto-clients
|
||||
func (clients *clientsContainer) Reload() {
|
||||
clients.addFromSystemARP()
|
||||
}
|
||||
|
||||
type clientObject struct {
|
||||
Name string `yaml:"name"`
|
||||
Tags []string `yaml:"tags"`
|
||||
IDs []string `yaml:"ids"`
|
||||
UseGlobalSettings bool `yaml:"use_global_settings"`
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||
ParentalEnabled bool `yaml:"parental_enabled"`
|
||||
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
||||
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
||||
|
||||
UseGlobalBlockedServices bool `yaml:"use_global_blocked_services"`
|
||||
BlockedServices []string `yaml:"blocked_services"`
|
||||
|
||||
Upstreams []string `yaml:"upstreams"`
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) tagKnown(tag string) bool {
|
||||
_, ok := clients.allTags[tag]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) addFromConfig(objects []clientObject) {
|
||||
for _, cy := range objects {
|
||||
cli := Client{
|
||||
Name: cy.Name,
|
||||
IDs: cy.IDs,
|
||||
UseOwnSettings: !cy.UseGlobalSettings,
|
||||
FilteringEnabled: cy.FilteringEnabled,
|
||||
ParentalEnabled: cy.ParentalEnabled,
|
||||
SafeSearchEnabled: cy.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
|
||||
|
||||
UseOwnBlockedServices: !cy.UseGlobalBlockedServices,
|
||||
|
||||
Upstreams: cy.Upstreams,
|
||||
}
|
||||
|
||||
for _, s := range cy.BlockedServices {
|
||||
if !dnsfilter.BlockedSvcKnown(s) {
|
||||
log.Debug("Clients: skipping unknown blocked-service '%s'", s)
|
||||
continue
|
||||
}
|
||||
cli.BlockedServices = append(cli.BlockedServices, s)
|
||||
}
|
||||
|
||||
for _, t := range cy.Tags {
|
||||
if !clients.tagKnown(t) {
|
||||
log.Debug("Clients: skipping unknown tag '%s'", t)
|
||||
continue
|
||||
}
|
||||
cli.Tags = append(cli.Tags, t)
|
||||
}
|
||||
sort.Strings(cli.Tags)
|
||||
|
||||
_, err := clients.Add(cli)
|
||||
if err != nil {
|
||||
log.Tracef("clientAdd: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write configuration
|
||||
func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
|
||||
clients.lock.Lock()
|
||||
for _, cli := range clients.list {
|
||||
cy := clientObject{
|
||||
Name: cli.Name,
|
||||
UseGlobalSettings: !cli.UseOwnSettings,
|
||||
FilteringEnabled: cli.FilteringEnabled,
|
||||
ParentalEnabled: cli.ParentalEnabled,
|
||||
SafeSearchEnabled: cli.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
|
||||
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
|
||||
}
|
||||
|
||||
cy.Tags = stringArrayDup(cli.Tags)
|
||||
cy.IDs = stringArrayDup(cli.IDs)
|
||||
cy.BlockedServices = stringArrayDup(cli.BlockedServices)
|
||||
cy.Upstreams = stringArrayDup(cli.Upstreams)
|
||||
|
||||
*objects = append(*objects, cy)
|
||||
}
|
||||
clients.lock.Unlock()
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) periodicUpdate() {
|
||||
for {
|
||||
clients.Reload()
|
||||
time.Sleep(clientsUpdatePeriod)
|
||||
}
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
|
||||
switch flags {
|
||||
case dhcpd.LeaseChangedAdded,
|
||||
dhcpd.LeaseChangedAddedStatic,
|
||||
dhcpd.LeaseChangedRemovedStatic:
|
||||
clients.addFromDHCP()
|
||||
}
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) onHostsChanged() {
|
||||
clients.addFromHostsFile()
|
||||
}
|
||||
|
||||
// Exists checks if client with this IP already exists
|
||||
func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.findByIP(ip)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
ch, ok := clients.ipHost[ip]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if source > ch.Source {
|
||||
return false // we're going to overwrite this client's info with a stronger source
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stringArrayDup(a []string) []string {
|
||||
a2 := make([]string, len(a))
|
||||
copy(a2, a)
|
||||
return a2
|
||||
}
|
||||
|
||||
// Find searches for a client by IP
|
||||
func (clients *clientsContainer) Find(ip string) (Client, bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.findByIP(ip)
|
||||
if !ok {
|
||||
return Client{}, false
|
||||
}
|
||||
c.IDs = stringArrayDup(c.IDs)
|
||||
c.Tags = stringArrayDup(c.Tags)
|
||||
c.BlockedServices = stringArrayDup(c.BlockedServices)
|
||||
c.Upstreams = stringArrayDup(c.Upstreams)
|
||||
return c, true
|
||||
}
|
||||
|
||||
// FindUpstreams looks for upstreams configured for the client
|
||||
// If no client found for this IP, or if no custom upstreams are configured,
|
||||
// this method returns nil
|
||||
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.findByIP(ip)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(c.Upstreams) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.upstreamConfig == nil {
|
||||
config, err := proxy.ParseUpstreamsConfig(c.Upstreams, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
|
||||
if err == nil {
|
||||
c.upstreamConfig = &config
|
||||
}
|
||||
}
|
||||
|
||||
return c.upstreamConfig
|
||||
}
|
||||
|
||||
// Find searches for a client by IP (and does not lock anything)
|
||||
func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
if ipAddr == nil {
|
||||
return Client{}, false
|
||||
}
|
||||
|
||||
c, ok := clients.idIndex[ip]
|
||||
if ok {
|
||||
return *c, true
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
for _, id := range c.IDs {
|
||||
_, ipnet, err := net.ParseCIDR(id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if ipnet.Contains(ipAddr) {
|
||||
return *c, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if clients.dhcpServer == nil {
|
||||
return Client{}, false
|
||||
}
|
||||
macFound := clients.dhcpServer.FindMACbyIP(ipAddr)
|
||||
if macFound == nil {
|
||||
return Client{}, false
|
||||
}
|
||||
for _, c = range clients.list {
|
||||
for _, id := range c.IDs {
|
||||
hwAddr, err := net.ParseMAC(id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(hwAddr, macFound) {
|
||||
return *c, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Client{}, false
|
||||
}
|
||||
|
||||
// FindAutoClient - search for an auto-client by IP
|
||||
func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
if ipAddr == nil {
|
||||
return ClientHost{}, false
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
ch, ok := clients.ipHost[ip]
|
||||
if ok {
|
||||
return *ch, true
|
||||
}
|
||||
return ClientHost{}, false
|
||||
}
|
||||
|
||||
// Check if Client object's fields are correct
|
||||
func (clients *clientsContainer) check(c *Client) error {
|
||||
if len(c.Name) == 0 {
|
||||
return fmt.Errorf("invalid Name")
|
||||
}
|
||||
|
||||
if len(c.IDs) == 0 {
|
||||
return fmt.Errorf("ID required")
|
||||
}
|
||||
|
||||
for i, id := range c.IDs {
|
||||
ip := net.ParseIP(id)
|
||||
if ip != nil {
|
||||
c.IDs[i] = ip.String() // normalize IP address
|
||||
continue
|
||||
}
|
||||
|
||||
_, _, err := net.ParseCIDR(id)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = net.ParseMAC(id)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid ID: %s", id)
|
||||
}
|
||||
|
||||
for _, t := range c.Tags {
|
||||
if !clients.tagKnown(t) {
|
||||
return fmt.Errorf("invalid tag: %s", t)
|
||||
}
|
||||
}
|
||||
sort.Strings(c.Tags)
|
||||
|
||||
err := dnsforward.ValidateUpstreams(c.Upstreams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid upstream servers: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add a new client object
|
||||
// Return true: success; false: client exists.
|
||||
func (clients *clientsContainer) Add(c Client) (bool, error) {
|
||||
e := clients.check(&c)
|
||||
if e != nil {
|
||||
return false, e
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// check Name index
|
||||
_, ok := clients.list[c.Name]
|
||||
if ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// check ID index
|
||||
for _, id := range c.IDs {
|
||||
c2, ok := clients.idIndex[id]
|
||||
if ok {
|
||||
return false, fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// update Name index
|
||||
clients.list[c.Name] = &c
|
||||
|
||||
// update ID index
|
||||
for _, id := range c.IDs {
|
||||
clients.idIndex[id] = &c
|
||||
}
|
||||
|
||||
log.Debug("Clients: added '%s': ID:%v [%d]", c.Name, c.IDs, len(clients.list))
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Del removes a client
|
||||
func (clients *clientsContainer) Del(name string) bool {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.list[name]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// update Name index
|
||||
delete(clients.list, name)
|
||||
|
||||
// update ID index
|
||||
for _, id := range c.IDs {
|
||||
delete(clients.idIndex, id)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Return TRUE if arrays are equal
|
||||
func arraysEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i != len(a); i++ {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Update a client
|
||||
func (clients *clientsContainer) Update(name string, c Client) error {
|
||||
err := clients.check(&c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
old, ok := clients.list[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("client not found")
|
||||
}
|
||||
|
||||
// check Name index
|
||||
if old.Name != c.Name {
|
||||
_, ok = clients.list[c.Name]
|
||||
if ok {
|
||||
return fmt.Errorf("client already exists")
|
||||
}
|
||||
}
|
||||
|
||||
// check IP index
|
||||
if !arraysEqual(old.IDs, c.IDs) {
|
||||
for _, id := range c.IDs {
|
||||
c2, ok := clients.idIndex[id]
|
||||
if ok && c2 != old {
|
||||
return fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// update ID index
|
||||
for _, id := range old.IDs {
|
||||
delete(clients.idIndex, id)
|
||||
}
|
||||
for _, id := range c.IDs {
|
||||
clients.idIndex[id] = old
|
||||
}
|
||||
}
|
||||
|
||||
// update Name index
|
||||
if old.Name != c.Name {
|
||||
delete(clients.list, old.Name)
|
||||
clients.list[c.Name] = old
|
||||
}
|
||||
|
||||
// update upstreams cache
|
||||
c.upstreamConfig = nil
|
||||
|
||||
*old = c
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWhoisInfo - associate WHOIS information with a client
|
||||
func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.findByIP(ip)
|
||||
if ok {
|
||||
log.Debug("Clients: client for %s is already created, ignore WHOIS info", ip)
|
||||
return
|
||||
}
|
||||
|
||||
ch, ok := clients.ipHost[ip]
|
||||
if ok {
|
||||
ch.WhoisInfo = info
|
||||
log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a ClientHost implicitly so that we don't do this check again
|
||||
ch = &ClientHost{
|
||||
Source: ClientSourceWHOIS,
|
||||
}
|
||||
ch.WhoisInfo = info
|
||||
clients.ipHost[ip] = ch
|
||||
log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo)
|
||||
}
|
||||
|
||||
// AddHost adds new IP -> Host pair
|
||||
// Use priority of the source (etc/hosts > ARP > rDNS)
|
||||
// so we overwrite existing entries with an equal or higher priority
|
||||
func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
|
||||
clients.lock.Lock()
|
||||
b, e := clients.addHost(ip, host, source)
|
||||
clients.lock.Unlock()
|
||||
return b, e
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) addHost(ip, host string, source clientSource) (bool, error) {
|
||||
// check auto-clients index
|
||||
ch, ok := clients.ipHost[ip]
|
||||
if ok && ch.Source > source {
|
||||
return false, nil
|
||||
} else if ok {
|
||||
ch.Source = source
|
||||
} else {
|
||||
ch = &ClientHost{
|
||||
Host: host,
|
||||
Source: source,
|
||||
}
|
||||
clients.ipHost[ip] = ch
|
||||
}
|
||||
log.Debug("Clients: added '%s' -> '%s' [%d]", ip, host, len(clients.ipHost))
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Remove all entries that match the specified source
|
||||
func (clients *clientsContainer) rmHosts(source clientSource) int {
|
||||
n := 0
|
||||
for k, v := range clients.ipHost {
|
||||
if v.Source == source {
|
||||
delete(clients.ipHost, k)
|
||||
n++
|
||||
}
|
||||
}
|
||||
log.Debug("Clients: removed %d client aliases", n)
|
||||
return n
|
||||
}
|
||||
|
||||
// Fill clients array from system hosts-file
|
||||
func (clients *clientsContainer) addFromHostsFile() {
|
||||
hosts := clients.autoHosts.List()
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
_ = clients.rmHosts(ClientSourceHostsFile)
|
||||
|
||||
n := 0
|
||||
for ip, name := range hosts {
|
||||
ok, err := clients.addHost(ip, name, ClientSourceHostsFile)
|
||||
if err != nil {
|
||||
log.Debug("Clients: %s", err)
|
||||
}
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Clients: added %d client aliases from system hosts-file", n)
|
||||
}
|
||||
|
||||
// Add IP -> Host pairs from the system's `arp -a` command output
|
||||
// The command's output is:
|
||||
// HOST (IP) at MAC on IFACE
|
||||
func (clients *clientsContainer) addFromSystemARP() {
|
||||
if runtime.GOOS == "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("arp", "-a")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
data, err := cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
log.Debug("command %s has failed: %v code:%d",
|
||||
cmd.Path, err, cmd.ProcessState.ExitCode())
|
||||
return
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
_ = clients.rmHosts(ClientSourceARP)
|
||||
|
||||
n := 0
|
||||
lines := strings.Split(string(data), "\n")
|
||||
for _, ln := range lines {
|
||||
|
||||
open := strings.Index(ln, " (")
|
||||
close := strings.Index(ln, ") ")
|
||||
if open == -1 || close == -1 || open >= close {
|
||||
continue
|
||||
}
|
||||
|
||||
host := ln[:open]
|
||||
ip := ln[open+2 : close]
|
||||
if utils.IsValidHostname(host) != nil || net.ParseIP(ip) == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ok, e := clients.addHost(ip, host, ClientSourceARP)
|
||||
if e != nil {
|
||||
log.Tracef("%s", e)
|
||||
}
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Clients: added %d client aliases from 'arp -a' command output", n)
|
||||
}
|
||||
|
||||
// Add clients from DHCP that have non-empty Hostname property
|
||||
func (clients *clientsContainer) addFromDHCP() {
|
||||
if clients.dhcpServer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_ = clients.rmHosts(ClientSourceDHCP)
|
||||
|
||||
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
n := 0
|
||||
for _, l := range leases {
|
||||
if len(l.Hostname) == 0 {
|
||||
continue
|
||||
}
|
||||
ok, _ := clients.addHost(l.IP.String(), l.Hostname, ClientSourceDHCP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
log.Debug("Clients: added %d client aliases from DHCP", n)
|
||||
}
|
||||
296
internal/home/clients_http.go
Normal file
296
internal/home/clients_http.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type clientJSON struct {
|
||||
IDs []string `json:"ids"`
|
||||
Tags []string `json:"tags"`
|
||||
Name string `json:"name"`
|
||||
UseGlobalSettings bool `json:"use_global_settings"`
|
||||
FilteringEnabled bool `json:"filtering_enabled"`
|
||||
ParentalEnabled bool `json:"parental_enabled"`
|
||||
SafeSearchEnabled bool `json:"safesearch_enabled"`
|
||||
SafeBrowsingEnabled bool `json:"safebrowsing_enabled"`
|
||||
|
||||
UseGlobalBlockedServices bool `json:"use_global_blocked_services"`
|
||||
BlockedServices []string `json:"blocked_services"`
|
||||
|
||||
Upstreams []string `json:"upstreams"`
|
||||
|
||||
WhoisInfo map[string]interface{} `json:"whois_info"`
|
||||
|
||||
// Disallowed - if true -- client's IP is not disallowed
|
||||
// Otherwise, it is blocked.
|
||||
Disallowed bool `json:"disallowed"`
|
||||
|
||||
// DisallowedRule - the rule due to which the client is disallowed
|
||||
// If Disallowed is true, and this string is empty - it means that the client IP
|
||||
// is disallowed by the "allowed IP list", i.e. it is not included in allowed.
|
||||
DisallowedRule string `json:"disallowed_rule"`
|
||||
}
|
||||
|
||||
type clientHostJSON struct {
|
||||
IP string `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
Source string `json:"source"`
|
||||
|
||||
WhoisInfo map[string]interface{} `json:"whois_info"`
|
||||
}
|
||||
|
||||
type clientListJSON struct {
|
||||
Clients []clientJSON `json:"clients"`
|
||||
AutoClients []clientHostJSON `json:"auto_clients"`
|
||||
Tags []string `json:"supported_tags"`
|
||||
}
|
||||
|
||||
// respond with information about configured clients
|
||||
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http.Request) {
|
||||
data := clientListJSON{}
|
||||
|
||||
clients.lock.Lock()
|
||||
for _, c := range clients.list {
|
||||
cj := clientToJSON(c)
|
||||
data.Clients = append(data.Clients, cj)
|
||||
}
|
||||
for ip, ch := range clients.ipHost {
|
||||
cj := clientHostJSON{
|
||||
IP: ip,
|
||||
Name: ch.Host,
|
||||
}
|
||||
|
||||
cj.Source = "etc/hosts"
|
||||
switch ch.Source {
|
||||
case ClientSourceDHCP:
|
||||
cj.Source = "DHCP"
|
||||
case ClientSourceRDNS:
|
||||
cj.Source = "rDNS"
|
||||
case ClientSourceARP:
|
||||
cj.Source = "ARP"
|
||||
case ClientSourceWHOIS:
|
||||
cj.Source = "WHOIS"
|
||||
}
|
||||
|
||||
cj.WhoisInfo = make(map[string]interface{})
|
||||
for _, wi := range ch.WhoisInfo {
|
||||
cj.WhoisInfo[wi[0]] = wi[1]
|
||||
}
|
||||
|
||||
data.AutoClients = append(data.AutoClients, cj)
|
||||
}
|
||||
clients.lock.Unlock()
|
||||
|
||||
data.Tags = clientTags
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
e := json.NewEncoder(w).Encode(data)
|
||||
if e != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Failed to encode to json: %v", e)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Convert JSON object to Client object
|
||||
func jsonToClient(cj clientJSON) (*Client, error) {
|
||||
c := Client{
|
||||
Name: cj.Name,
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
UseOwnSettings: !cj.UseGlobalSettings,
|
||||
FilteringEnabled: cj.FilteringEnabled,
|
||||
ParentalEnabled: cj.ParentalEnabled,
|
||||
SafeSearchEnabled: cj.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
|
||||
|
||||
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
|
||||
BlockedServices: cj.BlockedServices,
|
||||
|
||||
Upstreams: cj.Upstreams,
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// Convert Client object to JSON
|
||||
func clientToJSON(c *Client) clientJSON {
|
||||
cj := clientJSON{
|
||||
Name: c.Name,
|
||||
IDs: c.IDs,
|
||||
Tags: c.Tags,
|
||||
UseGlobalSettings: !c.UseOwnSettings,
|
||||
FilteringEnabled: c.FilteringEnabled,
|
||||
ParentalEnabled: c.ParentalEnabled,
|
||||
SafeSearchEnabled: c.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
|
||||
|
||||
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
|
||||
BlockedServices: c.BlockedServices,
|
||||
|
||||
Upstreams: c.Upstreams,
|
||||
}
|
||||
return cj
|
||||
}
|
||||
|
||||
// Convert ClientHost object to JSON
|
||||
func clientHostToJSON(ip string, ch ClientHost) clientJSON {
|
||||
cj := clientJSON{
|
||||
Name: ch.Host,
|
||||
IDs: []string{ip},
|
||||
}
|
||||
|
||||
cj.WhoisInfo = make(map[string]interface{})
|
||||
for _, wi := range ch.WhoisInfo {
|
||||
cj.WhoisInfo[wi[0]] = wi[1]
|
||||
}
|
||||
return cj
|
||||
}
|
||||
|
||||
// Add a new client
|
||||
func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
cj := clientJSON{}
|
||||
err = json.Unmarshal(body, &cj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
c, err := jsonToClient(cj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
ok, err := clients.Add(*c)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
httpError(w, http.StatusBadRequest, "Client already exists")
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
// Remove client
|
||||
func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
cj := clientJSON{}
|
||||
err = json.Unmarshal(body, &cj)
|
||||
if err != nil || len(cj.Name) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !clients.Del(cj.Name) {
|
||||
httpError(w, http.StatusBadRequest, "Client not found")
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
type updateJSON struct {
|
||||
Name string `json:"name"`
|
||||
Data clientJSON `json:"data"`
|
||||
}
|
||||
|
||||
// Update client's properties
|
||||
func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var dj updateJSON
|
||||
err = json.Unmarshal(body, &dj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
return
|
||||
}
|
||||
if len(dj.Name) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "Invalid request")
|
||||
return
|
||||
}
|
||||
|
||||
c, err := jsonToClient(dj.Data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = clients.Update(dj.Name, *c)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
// Get the list of clients by IP address list
|
||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
data := []map[string]interface{}{}
|
||||
for i := 0; ; i++ {
|
||||
ip := q.Get(fmt.Sprintf("ip%d", i))
|
||||
if len(ip) == 0 {
|
||||
break
|
||||
}
|
||||
el := map[string]interface{}{}
|
||||
c, ok := clients.Find(ip)
|
||||
if !ok {
|
||||
ch, ok := clients.FindAutoClient(ip)
|
||||
if !ok {
|
||||
continue // a client with this IP isn't found
|
||||
}
|
||||
cj := clientHostToJSON(ip, ch)
|
||||
|
||||
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
|
||||
el[ip] = cj
|
||||
} else {
|
||||
cj := clientToJSON(&c)
|
||||
|
||||
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
|
||||
el[ip] = cj
|
||||
}
|
||||
|
||||
data = append(data, el)
|
||||
}
|
||||
|
||||
js, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(js)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write response: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterClientsHandlers registers HTTP handlers
|
||||
func (clients *clientsContainer) registerWebHandlers() {
|
||||
httpRegister("GET", "/control/clients", clients.handleGetClients)
|
||||
httpRegister("POST", "/control/clients/add", clients.handleAddClient)
|
||||
httpRegister("POST", "/control/clients/delete", clients.handleDelClient)
|
||||
httpRegister("POST", "/control/clients/update", clients.handleUpdateClient)
|
||||
httpRegister("GET", "/control/clients/find", clients.handleFindClient)
|
||||
}
|
||||
27
internal/home/clients_tags.go
Normal file
27
internal/home/clients_tags.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package home
|
||||
|
||||
var clientTags = []string{
|
||||
"device_audio",
|
||||
"device_camera",
|
||||
"device_gameconsole",
|
||||
"device_laptop",
|
||||
"device_nas", // Network-attached Storage
|
||||
"device_other",
|
||||
"device_pc",
|
||||
"device_phone",
|
||||
"device_printer",
|
||||
"device_securityalarm",
|
||||
"device_tablet",
|
||||
"device_tv",
|
||||
|
||||
"os_android",
|
||||
"os_ios",
|
||||
"os_linux",
|
||||
"os_macos",
|
||||
"os_other",
|
||||
"os_windows",
|
||||
|
||||
"user_admin",
|
||||
"user_child",
|
||||
"user_regular",
|
||||
}
|
||||
266
internal/home/clients_test.go
Normal file
266
internal/home/clients_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
var c Client
|
||||
var e error
|
||||
var b bool
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
|
||||
clients.Init(nil, nil, nil)
|
||||
|
||||
// add
|
||||
c = Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||
Name: "client1",
|
||||
}
|
||||
b, e = clients.Add(c)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("Add #1")
|
||||
}
|
||||
|
||||
// add #2
|
||||
c = Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client2",
|
||||
}
|
||||
b, e = clients.Add(c)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("Add #2")
|
||||
}
|
||||
|
||||
c, b = clients.Find("1.1.1.1")
|
||||
assert.True(t, b && c.Name == "client1")
|
||||
|
||||
c, b = clients.Find("1:2:3::4")
|
||||
assert.True(t, b && c.Name == "client1")
|
||||
|
||||
c, b = clients.Find("2.2.2.2")
|
||||
assert.True(t, b && c.Name == "client2")
|
||||
|
||||
// failed add - name in use
|
||||
c = Client{
|
||||
IDs: []string{"1.2.3.5"},
|
||||
Name: "client1",
|
||||
}
|
||||
b, _ = clients.Add(c)
|
||||
if b {
|
||||
t.Fatalf("Add - name in use")
|
||||
}
|
||||
|
||||
// failed add - ip in use
|
||||
c = Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client3",
|
||||
}
|
||||
b, e = clients.Add(c)
|
||||
if b || e == nil {
|
||||
t.Fatalf("Add - ip in use")
|
||||
}
|
||||
|
||||
// get
|
||||
assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
|
||||
|
||||
// failed update - no such name
|
||||
c.IDs = []string{"1.2.3.0"}
|
||||
c.Name = "client3"
|
||||
if clients.Update("client3", c) == nil {
|
||||
t.Fatalf("Update")
|
||||
}
|
||||
|
||||
// failed update - name in use
|
||||
c.IDs = []string{"1.2.3.0"}
|
||||
c.Name = "client2"
|
||||
if clients.Update("client1", c) == nil {
|
||||
t.Fatalf("Update - name in use")
|
||||
}
|
||||
|
||||
// failed update - ip in use
|
||||
c.IDs = []string{"2.2.2.2"}
|
||||
c.Name = "client1"
|
||||
if clients.Update("client1", c) == nil {
|
||||
t.Fatalf("Update - ip in use")
|
||||
}
|
||||
|
||||
// update
|
||||
c.IDs = []string{"1.1.1.2"}
|
||||
c.Name = "client1"
|
||||
if clients.Update("client1", c) != nil {
|
||||
t.Fatalf("Update")
|
||||
}
|
||||
|
||||
// get after update
|
||||
assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||
|
||||
// update - rename
|
||||
c.IDs = []string{"1.1.1.2"}
|
||||
c.Name = "client1-renamed"
|
||||
c.UseOwnSettings = true
|
||||
assert.True(t, clients.Update("client1", c) == nil)
|
||||
c = Client{}
|
||||
c, b = clients.Find("1.1.1.2")
|
||||
assert.True(t, b && c.Name == "client1-renamed" && c.IDs[0] == "1.1.1.2" && c.UseOwnSettings)
|
||||
assert.True(t, clients.list["client1"] == nil)
|
||||
|
||||
// failed remove - no such name
|
||||
if clients.Del("client3") {
|
||||
t.Fatalf("Del - no such name")
|
||||
}
|
||||
|
||||
// remove
|
||||
assert.True(t, !(!clients.Del("client1-renamed") || clients.Exists("1.1.1.2", ClientSourceHostsFile)))
|
||||
|
||||
// add host client
|
||||
b, e = clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("clientAddHost")
|
||||
}
|
||||
|
||||
// failed add - ip exists
|
||||
b, e = clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||
if b || e != nil {
|
||||
t.Fatalf("clientAddHost - ip exists")
|
||||
}
|
||||
|
||||
// overwrite with new data
|
||||
b, e = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("clientAddHost - overwrite with new data")
|
||||
}
|
||||
|
||||
// overwrite with new data (higher priority)
|
||||
b, e = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("clientAddHost - overwrite with new data (higher priority)")
|
||||
}
|
||||
|
||||
// get
|
||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
}
|
||||
|
||||
func TestClientsWhois(t *testing.T) {
|
||||
var c Client
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
clients.Init(nil, nil, nil)
|
||||
|
||||
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
|
||||
// set whois info on new client
|
||||
clients.SetWhoisInfo("1.1.1.255", whois)
|
||||
assert.True(t, clients.ipHost["1.1.1.255"].WhoisInfo[0][1] == "orgname-val")
|
||||
|
||||
// set whois info on existing auto-client
|
||||
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
|
||||
clients.SetWhoisInfo("1.1.1.1", whois)
|
||||
assert.True(t, clients.ipHost["1.1.1.1"].WhoisInfo[0][1] == "orgname-val")
|
||||
|
||||
// Check that we cannot set whois info on a manually-added client
|
||||
c = Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
Name: "client1",
|
||||
}
|
||||
_, _ = clients.Add(c)
|
||||
clients.SetWhoisInfo("1.1.1.2", whois)
|
||||
assert.True(t, clients.ipHost["1.1.1.2"] == nil)
|
||||
_ = clients.Del("client1")
|
||||
}
|
||||
|
||||
func TestClientsAddExisting(t *testing.T) {
|
||||
var c Client
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
clients.Init(nil, nil, nil)
|
||||
|
||||
// some test variables
|
||||
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
testIP := "1.2.3.4"
|
||||
|
||||
// add a client
|
||||
c = Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||
Name: "client1",
|
||||
}
|
||||
ok, err := clients.Add(c)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// add an auto-client with the same IP - it's allowed
|
||||
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// now some more complicated stuff
|
||||
// first, init a DHCP server with a single static lease
|
||||
config := dhcpd.ServerConfig{
|
||||
DBFilePath: "leases.db",
|
||||
}
|
||||
defer func() { _ = os.Remove("leases.db") }()
|
||||
clients.dhcpServer = dhcpd.Create(config)
|
||||
err = clients.dhcpServer.AddStaticLease(dhcpd.Lease{
|
||||
HWAddr: mac,
|
||||
IP: net.ParseIP(testIP).To4(),
|
||||
Hostname: "testhost",
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
// add a new client with the same IP as for a client with MAC
|
||||
c = Client{
|
||||
IDs: []string{testIP},
|
||||
Name: "client2",
|
||||
}
|
||||
ok, err = clients.Add(c)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// add a new client with the IP from the client1's IP range
|
||||
c = Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client3",
|
||||
}
|
||||
ok, err = clients.Add(c)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
|
||||
clients.Init(nil, nil, nil)
|
||||
|
||||
// add client with upstreams
|
||||
client := Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||
Name: "client1",
|
||||
Upstreams: []string{
|
||||
"1.1.1.1",
|
||||
"[/example.org/]8.8.8.8",
|
||||
},
|
||||
}
|
||||
ok, err := clients.Add(client)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
config := clients.FindUpstreams("1.2.3.4")
|
||||
assert.Nil(t, config)
|
||||
|
||||
config = clients.FindUpstreams("1.1.1.1")
|
||||
assert.NotNil(t, config)
|
||||
assert.Equal(t, 1, len(config.Upstreams))
|
||||
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
|
||||
}
|
||||
295
internal/home/config.go
Normal file
295
internal/home/config.go
Normal file
@@ -0,0 +1,295 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/golibs/file"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
dataDir = "data" // data storage
|
||||
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
|
||||
)
|
||||
|
||||
// logSettings
|
||||
type logSettings struct {
|
||||
LogCompress bool `yaml:"log_compress"` // Compress determines if the rotated log files should be compressed using gzip (default: false)
|
||||
LogLocalTime bool `yaml:"log_localtime"` // If the time used for formatting the timestamps in is the computer's local time (default: false [UTC])
|
||||
LogMaxBackups int `yaml:"log_max_backups"` // Maximum number of old log files to retain (MaxAge may still cause them to get deleted)
|
||||
LogMaxSize int `yaml:"log_max_size"` // Maximum size in megabytes of the log file before it gets rotated (default 100 MB)
|
||||
LogMaxAge int `yaml:"log_max_age"` // MaxAge is the maximum number of days to retain old log files
|
||||
LogFile string `yaml:"log_file"` // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
|
||||
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
|
||||
}
|
||||
|
||||
// configuration is loaded from YAML
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type configuration struct {
|
||||
// Raw file data to avoid re-reading of configuration file
|
||||
// It's reset after config is parsed
|
||||
fileData []byte
|
||||
|
||||
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
||||
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
||||
Users []User `yaml:"users"` // Users that can access HTTP server
|
||||
ProxyURL string `yaml:"http_proxy"` // Proxy address for our HTTP client
|
||||
Language string `yaml:"language"` // two-letter ISO 639-1 language code
|
||||
RlimitNoFile uint `yaml:"rlimit_nofile"` // Maximum number of opened fd's per process (0: default)
|
||||
DebugPProf bool `yaml:"debug_pprof"` // Enable pprof HTTP server on port 6060
|
||||
|
||||
// TTL for a web session (in hours)
|
||||
// An active session is automatically refreshed once a day.
|
||||
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
|
||||
|
||||
DNS dnsConfig `yaml:"dns"`
|
||||
TLS tlsConfigSettings `yaml:"tls"`
|
||||
|
||||
Filters []filter `yaml:"filters"`
|
||||
WhitelistFilters []filter `yaml:"whitelist_filters"`
|
||||
UserRules []string `yaml:"user_rules"`
|
||||
|
||||
DHCP dhcpd.ServerConfig `yaml:"dhcp"`
|
||||
|
||||
// Note: this array is filled only before file read/write and then it's cleared
|
||||
Clients []clientObject `yaml:"clients"`
|
||||
|
||||
logSettings `yaml:",inline"`
|
||||
|
||||
sync.RWMutex `yaml:"-"`
|
||||
|
||||
SchemaVersion int `yaml:"schema_version"` // keeping last so that users will be less tempted to change it -- used when upgrading between versions
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type dnsConfig struct {
|
||||
BindHost string `yaml:"bind_host"`
|
||||
Port int `yaml:"port"`
|
||||
|
||||
// time interval for statistics (in days)
|
||||
StatsInterval uint32 `yaml:"statistics_interval"`
|
||||
|
||||
QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled
|
||||
QueryLogFileEnabled bool `yaml:"querylog_file_enabled"` // if true, query log will be written to a file
|
||||
QueryLogInterval uint32 `yaml:"querylog_interval"` // time interval for query log (in days)
|
||||
QueryLogMemSize uint32 `yaml:"querylog_size_memory"` // number of entries kept in memory before they are flushed to disk
|
||||
AnonymizeClientIP bool `yaml:"anonymize_client_ip"` // anonymize clients' IP addresses in logs and stats
|
||||
|
||||
dnsforward.FilteringConfig `yaml:",inline"`
|
||||
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
||||
DnsfilterConf dnsfilter.Config `yaml:",inline"`
|
||||
}
|
||||
|
||||
type tlsConfigSettings struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DOT/DOH/HTTPS) status
|
||||
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
|
||||
ForceHTTPS bool `yaml:"force_https" json:"force_https,omitempty"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
|
||||
PortHTTPS int `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
|
||||
PortDNSOverTLS int `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DOT will be disabled
|
||||
PortDNSOverQUIC uint16 `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"` // DNS-over-QUIC port. If 0, DoQ will be disabled
|
||||
|
||||
// Allow DOH queries via unencrypted HTTP (e.g. for reverse proxying)
|
||||
AllowUnencryptedDOH bool `yaml:"allow_unencrypted_doh" json:"allow_unencrypted_doh"`
|
||||
|
||||
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
|
||||
}
|
||||
|
||||
// initialize to default values, will be changed later when reading config or parsing command line
|
||||
var config = configuration{
|
||||
BindPort: 3000,
|
||||
BindHost: "0.0.0.0",
|
||||
DNS: dnsConfig{
|
||||
BindHost: "0.0.0.0",
|
||||
Port: 53,
|
||||
StatsInterval: 1,
|
||||
FilteringConfig: dnsforward.FilteringConfig{
|
||||
ProtectionEnabled: true, // whether or not use any of dnsfilter features
|
||||
BlockingMode: "default", // mode how to answer filtered requests
|
||||
BlockedResponseTTL: 10, // in seconds
|
||||
Ratelimit: 20,
|
||||
RefuseAny: true,
|
||||
AllServers: false,
|
||||
},
|
||||
FilteringEnabled: true, // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours: 24,
|
||||
},
|
||||
TLS: tlsConfigSettings{
|
||||
PortHTTPS: 443,
|
||||
PortDNSOverTLS: 853, // needs to be passed through to dnsproxy
|
||||
PortDNSOverQUIC: 784,
|
||||
},
|
||||
logSettings: logSettings{
|
||||
LogCompress: false,
|
||||
LogLocalTime: false,
|
||||
LogMaxBackups: 0,
|
||||
LogMaxSize: 100,
|
||||
LogMaxAge: 3,
|
||||
},
|
||||
SchemaVersion: currentSchemaVersion,
|
||||
}
|
||||
|
||||
// initConfig initializes default configuration for the current OS&ARCH
|
||||
func initConfig() {
|
||||
config.WebSessionTTLHours = 30 * 24
|
||||
|
||||
config.DNS.QueryLogEnabled = true
|
||||
config.DNS.QueryLogFileEnabled = true
|
||||
config.DNS.QueryLogInterval = 90
|
||||
config.DNS.QueryLogMemSize = 1000
|
||||
|
||||
config.DNS.CacheSize = 4 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.CacheTime = 30
|
||||
config.Filters = defaultFilters()
|
||||
|
||||
config.DHCP.Conf4.LeaseDuration = 86400
|
||||
config.DHCP.Conf4.ICMPTimeout = 1000
|
||||
config.DHCP.Conf6.LeaseDuration = 86400
|
||||
}
|
||||
|
||||
// getConfigFilename returns path to the current config file
|
||||
func (c *configuration) getConfigFilename() string {
|
||||
configFile, err := filepath.EvalSymlinks(Context.configFilename)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
log.Error("unexpected error while config file path evaluation: %s", err)
|
||||
}
|
||||
configFile = Context.configFilename
|
||||
}
|
||||
if !filepath.IsAbs(configFile) {
|
||||
configFile = filepath.Join(Context.workDir, configFile)
|
||||
}
|
||||
return configFile
|
||||
}
|
||||
|
||||
// getLogSettings reads logging settings from the config file.
|
||||
// we do it in a separate method in order to configure logger before the actual configuration is parsed and applied.
|
||||
func getLogSettings() logSettings {
|
||||
l := logSettings{}
|
||||
yamlFile, err := readConfigFile()
|
||||
if err != nil {
|
||||
return l
|
||||
}
|
||||
err = yaml.Unmarshal(yamlFile, &l)
|
||||
if err != nil {
|
||||
log.Error("Couldn't get logging settings from the configuration: %s", err)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// parseConfig loads configuration from the YAML file
|
||||
func parseConfig() error {
|
||||
configFile := config.getConfigFilename()
|
||||
log.Debug("Reading config file: %s", configFile)
|
||||
yamlFile, err := readConfigFile()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config.fileData = nil
|
||||
err = yaml.Unmarshal(yamlFile, &config)
|
||||
if err != nil {
|
||||
log.Error("Couldn't parse config file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||
config.DNS.FiltersUpdateIntervalHours = 24
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// readConfigFile reads config file contents if it exists
|
||||
func readConfigFile() ([]byte, error) {
|
||||
if len(config.fileData) != 0 {
|
||||
return config.fileData, nil
|
||||
}
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
d, err := ioutil.ReadFile(configFile)
|
||||
if err != nil {
|
||||
log.Error("Couldn't read config file %s: %s", configFile, err)
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// Saves configuration to the YAML file and also saves the user filter contents to a file
|
||||
func (c *configuration) write() error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
Context.clients.WriteDiskConfig(&config.Clients)
|
||||
|
||||
if Context.auth != nil {
|
||||
config.Users = Context.auth.GetUsers()
|
||||
}
|
||||
if Context.tls != nil {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
config.TLS = tlsConf
|
||||
}
|
||||
|
||||
if Context.stats != nil {
|
||||
sdc := stats.DiskConfig{}
|
||||
Context.stats.WriteDiskConfig(&sdc)
|
||||
config.DNS.StatsInterval = sdc.Interval
|
||||
}
|
||||
|
||||
if Context.queryLog != nil {
|
||||
dc := querylog.Config{}
|
||||
Context.queryLog.WriteDiskConfig(&dc)
|
||||
config.DNS.QueryLogEnabled = dc.Enabled
|
||||
config.DNS.QueryLogFileEnabled = dc.FileEnabled
|
||||
config.DNS.QueryLogInterval = dc.Interval
|
||||
config.DNS.QueryLogMemSize = dc.MemSize
|
||||
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
|
||||
}
|
||||
|
||||
if Context.dnsFilter != nil {
|
||||
c := dnsfilter.Config{}
|
||||
Context.dnsFilter.WriteDiskConfig(&c)
|
||||
config.DNS.DnsfilterConf = c
|
||||
}
|
||||
|
||||
if Context.dnsServer != nil {
|
||||
c := dnsforward.FilteringConfig{}
|
||||
Context.dnsServer.WriteDiskConfig(&c)
|
||||
config.DNS.FilteringConfig = c
|
||||
}
|
||||
|
||||
if Context.dhcpServer != nil {
|
||||
c := dhcpd.ServerConfig{}
|
||||
Context.dhcpServer.WriteDiskConfig(&c)
|
||||
config.DHCP = c
|
||||
}
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
log.Debug("Writing YAML file: %s", configFile)
|
||||
yamlText, err := yaml.Marshal(&config)
|
||||
config.Clients = nil
|
||||
if err != nil {
|
||||
log.Error("Couldn't generate YAML file: %s", err)
|
||||
return err
|
||||
}
|
||||
err = file.SafeWrite(configFile, yamlText)
|
||||
if err != nil {
|
||||
log.Error("Couldn't save YAML config: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
234
internal/home/control.go
Normal file
234
internal/home/control.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
)
|
||||
|
||||
// ----------------
|
||||
// helper functions
|
||||
// ----------------
|
||||
|
||||
func returnOK(w http.ResponseWriter) {
|
||||
_, err := fmt.Fprintf(w, "OK\n")
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Info(text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
// ---------------
|
||||
// dns run control
|
||||
// ---------------
|
||||
func addDNSAddress(dnsAddresses *[]string, addr string) {
|
||||
if config.DNS.Port != 53 {
|
||||
addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port)
|
||||
}
|
||||
*dnsAddresses = append(*dnsAddresses, addr)
|
||||
}
|
||||
|
||||
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
c := dnsforward.FilteringConfig{}
|
||||
if Context.dnsServer != nil {
|
||||
Context.dnsServer.WriteDiskConfig(&c)
|
||||
}
|
||||
data := map[string]interface{}{
|
||||
"dns_addresses": getDNSAddresses(),
|
||||
"http_port": config.BindPort,
|
||||
"dns_port": config.DNS.Port,
|
||||
"running": isRunning(),
|
||||
"version": versionString,
|
||||
"language": config.Language,
|
||||
|
||||
"protection_enabled": c.ProtectionEnabled,
|
||||
}
|
||||
data["dhcp_available"] = (Context.dhcpServer != nil)
|
||||
|
||||
jsonVal, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type profileJSON struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
pj := profileJSON{}
|
||||
u := Context.auth.GetCurrentUser(r)
|
||||
pj.Name = u.Name
|
||||
|
||||
data, err := json.Marshal(pj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
|
||||
// ------------------------
|
||||
// registration of handlers
|
||||
// ------------------------
|
||||
func registerControlHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/status", handleStatus)
|
||||
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
|
||||
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
|
||||
http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
|
||||
httpRegister(http.MethodPost, "/control/update", handleUpdate)
|
||||
httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
|
||||
|
||||
// No auth is necessary for DOH/DOT configurations
|
||||
http.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoh))
|
||||
http.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDot))
|
||||
RegisterAuthHandlers()
|
||||
}
|
||||
|
||||
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
if len(method) == 0 {
|
||||
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||
http.HandleFunc(url, postInstall(handler))
|
||||
return
|
||||
}
|
||||
|
||||
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
||||
}
|
||||
|
||||
// ----------------------------------
|
||||
// helper functions for HTTP handlers
|
||||
// ----------------------------------
|
||||
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Debug("%s %v", r.Method, r.URL)
|
||||
|
||||
if r.Method != method {
|
||||
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if method == "POST" || method == "PUT" || method == "DELETE" {
|
||||
Context.controlLock.Lock()
|
||||
defer Context.controlLock.Unlock()
|
||||
}
|
||||
|
||||
handler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return ensure("POST", handler)
|
||||
}
|
||||
|
||||
func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return ensure("GET", handler)
|
||||
}
|
||||
|
||||
// Bridge between http.Handler object and Go function
|
||||
type httpHandler struct {
|
||||
handler func(http.ResponseWriter, *http.Request)
|
||||
}
|
||||
|
||||
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.handler(w, r)
|
||||
}
|
||||
|
||||
func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler {
|
||||
h := httpHandler{}
|
||||
h.handler = ensure(method, handler)
|
||||
return &h
|
||||
}
|
||||
|
||||
// preInstall lets the handler run only if firstRun is true, no redirects
|
||||
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !Context.firstRun {
|
||||
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
handler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary
|
||||
type preInstallHandlerStruct struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
preInstall(p.handler.ServeHTTP)(w, r)
|
||||
}
|
||||
|
||||
// preInstallHandler returns http.Handler interface for preInstall wrapper
|
||||
func preInstallHandler(handler http.Handler) http.Handler {
|
||||
return &preInstallHandlerStruct{handler}
|
||||
}
|
||||
|
||||
// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise
|
||||
// it also enforces HTTPS if it is enabled and configured
|
||||
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if Context.firstRun &&
|
||||
!strings.HasPrefix(r.URL.Path, "/install.") &&
|
||||
!strings.HasPrefix(r.URL.Path, "/assets/") {
|
||||
http.Redirect(w, r, "/install.html", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// enforce https?
|
||||
if r.TLS == nil && Context.web.forceHTTPS && Context.web.httpsServer.server != nil {
|
||||
// yes, and we want host from host:port
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
// no port in host
|
||||
host = r.Host
|
||||
}
|
||||
// construct new URL to redirect to
|
||||
newURL := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(host, strconv.Itoa(Context.web.portHTTPS)),
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
}
|
||||
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
handler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
type postInstallHandlerStruct struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
postInstall(p.handler.ServeHTTP)(w, r)
|
||||
}
|
||||
|
||||
func postInstallHandler(handler http.Handler) http.Handler {
|
||||
return &postInstallHandlerStruct{handler}
|
||||
}
|
||||
391
internal/home/control_filtering.go
Normal file
391
internal/home/control_filtering.go
Normal file
@@ -0,0 +1,391 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// isValidURL - return TRUE if URL or file path is valid
|
||||
func isValidURL(rawurl string) bool {
|
||||
if filepath.IsAbs(rawurl) {
|
||||
// this is a file path
|
||||
return util.FileExists(rawurl)
|
||||
}
|
||||
|
||||
url, err := url.ParseRequestURI(rawurl)
|
||||
if err != nil {
|
||||
return false //Couldn't even parse the rawurl
|
||||
}
|
||||
if len(url.Scheme) == 0 {
|
||||
return false //No Scheme found
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type filterAddJSON struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||
fj := filterAddJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !isValidURL(fj.URL) {
|
||||
http.Error(w, "Invalid URL or file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for duplicates
|
||||
if filterExists(fj.URL) {
|
||||
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
|
||||
return
|
||||
}
|
||||
|
||||
// Set necessary properties
|
||||
filt := filter{
|
||||
Enabled: true,
|
||||
URL: fj.URL,
|
||||
Name: fj.Name,
|
||||
white: fj.Whitelist,
|
||||
}
|
||||
filt.ID = assignUniqueFilterID()
|
||||
|
||||
// Download the filter contents
|
||||
ok, err := f.update(&filt)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL)
|
||||
return
|
||||
}
|
||||
|
||||
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
|
||||
if !filterAdd(filt) {
|
||||
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
|
||||
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
type request struct {
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
}
|
||||
req := request{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// go through each element and delete if url matches
|
||||
config.Lock()
|
||||
newFilters := []filter{}
|
||||
filters := &config.Filters
|
||||
if req.Whitelist {
|
||||
filters = &config.WhitelistFilters
|
||||
}
|
||||
for _, filter := range *filters {
|
||||
if filter.URL != req.URL {
|
||||
newFilters = append(newFilters, filter)
|
||||
} else {
|
||||
err := os.Rename(filter.Path(), filter.Path()+".old")
|
||||
if err != nil {
|
||||
log.Error("os.Rename: %s: %s", filter.Path(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Update the configuration after removing filter files
|
||||
*filters = newFilters
|
||||
config.Unlock()
|
||||
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
|
||||
// Note: the old files "filter.txt.old" aren't deleted - it's not really necessary,
|
||||
// but will require the additional code to run after enableFilters() is finished: i.e. complicated
|
||||
}
|
||||
|
||||
type filterURLJSON struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type filterURLReq struct {
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
Data filterURLJSON `json:"data"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
|
||||
fj := filterURLReq{}
|
||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !isValidURL(fj.Data.URL) {
|
||||
http.Error(w, "invalid URL or file path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
filt := filter{
|
||||
Enabled: fj.Data.Enabled,
|
||||
Name: fj.Data.Name,
|
||||
URL: fj.Data.URL,
|
||||
}
|
||||
status := f.filterSetProperties(fj.URL, filt, fj.Whitelist)
|
||||
if (status & statusFound) == 0 {
|
||||
http.Error(w, "URL doesn't exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if (status & statusURLExists) != 0 {
|
||||
http.Error(w, "URL already exists", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
restart := false
|
||||
if (status & statusEnabledChanged) != 0 {
|
||||
// we must add or remove filter rules
|
||||
restart = true
|
||||
}
|
||||
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
|
||||
// download new filter and apply its rules
|
||||
flags := FilterRefreshBlocklists
|
||||
if fj.Whitelist {
|
||||
flags = FilterRefreshAllowlists
|
||||
}
|
||||
nUpdated, _ := f.refreshFilters(flags, true)
|
||||
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
|
||||
// if not - we restart the filtering ourselves
|
||||
restart = false
|
||||
if nUpdated == 0 {
|
||||
restart = true
|
||||
}
|
||||
}
|
||||
if restart {
|
||||
enableFilters(true)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
config.UserRules = strings.Split(string(body), "\n")
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
type Req struct {
|
||||
White bool `json:"whitelist"`
|
||||
}
|
||||
type Resp struct {
|
||||
Updated int `json:"updated"`
|
||||
}
|
||||
resp := Resp{}
|
||||
var err error
|
||||
|
||||
req := Req{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
Context.controlLock.Unlock()
|
||||
flags := FilterRefreshBlocklists
|
||||
if req.White {
|
||||
flags = FilterRefreshAllowlists
|
||||
}
|
||||
resp.Updated, err = f.refreshFilters(flags|FilterRefreshForce, false)
|
||||
Context.controlLock.Lock()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
js, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(js)
|
||||
}
|
||||
|
||||
type filterJSON struct {
|
||||
ID int64 `json:"id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name"`
|
||||
RulesCount uint32 `json:"rules_count"`
|
||||
LastUpdated string `json:"last_updated"`
|
||||
}
|
||||
|
||||
type filteringConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Interval uint32 `json:"interval"` // in hours
|
||||
Filters []filterJSON `json:"filters"`
|
||||
WhitelistFilters []filterJSON `json:"whitelist_filters"`
|
||||
UserRules []string `json:"user_rules"`
|
||||
}
|
||||
|
||||
func filterToJSON(f filter) filterJSON {
|
||||
fj := filterJSON{
|
||||
ID: f.ID,
|
||||
Enabled: f.Enabled,
|
||||
URL: f.URL,
|
||||
Name: f.Name,
|
||||
RulesCount: uint32(f.RulesCount),
|
||||
}
|
||||
|
||||
if !f.LastUpdated.IsZero() {
|
||||
fj.LastUpdated = f.LastUpdated.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
return fj
|
||||
}
|
||||
|
||||
// Get filtering configuration
|
||||
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
|
||||
resp := filteringConfig{}
|
||||
config.RLock()
|
||||
resp.Enabled = config.DNS.FilteringEnabled
|
||||
resp.Interval = config.DNS.FiltersUpdateIntervalHours
|
||||
for _, f := range config.Filters {
|
||||
fj := filterToJSON(f)
|
||||
resp.Filters = append(resp.Filters, fj)
|
||||
}
|
||||
for _, f := range config.WhitelistFilters {
|
||||
fj := filterToJSON(f)
|
||||
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
|
||||
}
|
||||
resp.UserRules = config.UserRules
|
||||
config.RUnlock()
|
||||
|
||||
jsonVal, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "http write: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set filtering configuration
|
||||
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
|
||||
req := filteringConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(req.Interval) {
|
||||
httpError(w, http.StatusBadRequest, "Unsupported interval")
|
||||
return
|
||||
}
|
||||
|
||||
config.DNS.FilteringEnabled = req.Enabled
|
||||
config.DNS.FiltersUpdateIntervalHours = req.Interval
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
}
|
||||
|
||||
type checkHostResp struct {
|
||||
Reason string `json:"reason"`
|
||||
FilterID int64 `json:"filter_id"`
|
||||
Rule string `json:"rule"`
|
||||
|
||||
// for FilteredBlockedService:
|
||||
SvcName string `json:"service_name"`
|
||||
|
||||
// for ReasonRewrite:
|
||||
CanonName string `json:"cname"` // CNAME value
|
||||
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
|
||||
}
|
||||
|
||||
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
host := q.Get("name")
|
||||
|
||||
setts := Context.dnsFilter.GetConfig()
|
||||
setts.FilteringEnabled = true
|
||||
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
|
||||
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := checkHostResp{}
|
||||
resp.Reason = result.Reason.String()
|
||||
resp.FilterID = result.FilterID
|
||||
resp.Rule = result.Rule
|
||||
resp.SvcName = result.ServiceName
|
||||
resp.CanonName = result.CanonName
|
||||
resp.IPList = result.IPList
|
||||
js, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(js)
|
||||
}
|
||||
|
||||
// RegisterFilteringHandlers - register handlers
|
||||
func (f *Filtering) RegisterFilteringHandlers() {
|
||||
httpRegister("GET", "/control/filtering/status", f.handleFilteringStatus)
|
||||
httpRegister("POST", "/control/filtering/config", f.handleFilteringConfig)
|
||||
httpRegister("POST", "/control/filtering/add_url", f.handleFilteringAddURL)
|
||||
httpRegister("POST", "/control/filtering/remove_url", f.handleFilteringRemoveURL)
|
||||
httpRegister("POST", "/control/filtering/set_url", f.handleFilteringSetURL)
|
||||
httpRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh)
|
||||
httpRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules)
|
||||
httpRegister("GET", "/control/filtering/check_host", f.handleCheckHost)
|
||||
}
|
||||
|
||||
func checkFiltersUpdateIntervalHours(i uint32) bool {
|
||||
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
|
||||
}
|
||||
372
internal/home/control_install.go
Normal file
372
internal/home/control_install.go
Normal file
@@ -0,0 +1,372 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
type firstRunData struct {
|
||||
WebPort int `json:"web_port"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
Interfaces map[string]interface{} `json:"interfaces"`
|
||||
}
|
||||
|
||||
type netInterfaceJSON struct {
|
||||
Name string `json:"name"`
|
||||
MTU int `json:"mtu"`
|
||||
HardwareAddr string `json:"hardware_address"`
|
||||
Addresses []string `json:"ip_addresses"`
|
||||
Flags string `json:"flags"`
|
||||
}
|
||||
|
||||
// Get initial installation settings
|
||||
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
||||
data := firstRunData{}
|
||||
data.WebPort = 80
|
||||
data.DNSPort = 53
|
||||
|
||||
ifaces, err := util.GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
data.Interfaces = make(map[string]interface{})
|
||||
for _, iface := range ifaces {
|
||||
ifaceJSON := netInterfaceJSON{
|
||||
Name: iface.Name,
|
||||
MTU: iface.MTU,
|
||||
HardwareAddr: iface.HardwareAddr,
|
||||
Addresses: iface.Addresses,
|
||||
Flags: iface.Flags,
|
||||
}
|
||||
data.Interfaces[iface.Name] = ifaceJSON
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type checkConfigReqEnt struct {
|
||||
Port int `json:"port"`
|
||||
IP string `json:"ip"`
|
||||
Autofix bool `json:"autofix"`
|
||||
}
|
||||
type checkConfigReq struct {
|
||||
Web checkConfigReqEnt `json:"web"`
|
||||
DNS checkConfigReqEnt `json:"dns"`
|
||||
SetStaticIP bool `json:"set_static_ip"`
|
||||
}
|
||||
|
||||
type checkConfigRespEnt struct {
|
||||
Status string `json:"status"`
|
||||
CanAutofix bool `json:"can_autofix"`
|
||||
}
|
||||
type staticIPJSON struct {
|
||||
Static string `json:"static"`
|
||||
IP string `json:"ip"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
type checkConfigResp struct {
|
||||
Web checkConfigRespEnt `json:"web"`
|
||||
DNS checkConfigRespEnt `json:"dns"`
|
||||
StaticIP staticIPJSON `json:"static_ip"`
|
||||
}
|
||||
|
||||
// Check if ports are available, respond with results
|
||||
func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
||||
reqData := checkConfigReq{}
|
||||
respData := checkConfigResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort {
|
||||
err = util.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port)
|
||||
if err != nil {
|
||||
respData.Web.Status = fmt.Sprintf("%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqData.DNS.Port != 0 {
|
||||
err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||
|
||||
if util.ErrorIsAddrInUse(err) {
|
||||
canAutofix := checkDNSStubListener()
|
||||
if canAutofix && reqData.DNS.Autofix {
|
||||
|
||||
err = disableDNSStubListener()
|
||||
if err != nil {
|
||||
log.Error("Couldn't disable DNSStubListener: %s", err)
|
||||
}
|
||||
|
||||
err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||
canAutofix = false
|
||||
}
|
||||
|
||||
respData.DNS.CanAutofix = canAutofix
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
err = util.CheckPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
respData.DNS.Status = fmt.Sprintf("%v", err)
|
||||
} else if reqData.DNS.IP != "0.0.0.0" {
|
||||
respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP)
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(respData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to marshal JSON: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleStaticIP - handles static IP request
|
||||
// It either checks if we have a static IP
|
||||
// Or if set=true, it tries to set it
|
||||
func handleStaticIP(ip string, set bool) staticIPJSON {
|
||||
resp := staticIPJSON{}
|
||||
|
||||
interfaceName := util.GetInterfaceByIP(ip)
|
||||
resp.Static = "no"
|
||||
|
||||
if len(interfaceName) == 0 {
|
||||
resp.Static = "error"
|
||||
resp.Error = fmt.Sprintf("Couldn't find network interface by IP %s", ip)
|
||||
return resp
|
||||
}
|
||||
|
||||
if set {
|
||||
// Try to set static IP for the specified interface
|
||||
err := dhcpd.SetStaticIP(interfaceName)
|
||||
if err != nil {
|
||||
resp.Static = "error"
|
||||
resp.Error = err.Error()
|
||||
return resp
|
||||
}
|
||||
}
|
||||
|
||||
// Fallthrough here even if we set static IP
|
||||
// Check if we have a static IP and return the details
|
||||
isStaticIP, err := dhcpd.HasStaticIP(interfaceName)
|
||||
if err != nil {
|
||||
resp.Static = "error"
|
||||
resp.Error = err.Error()
|
||||
} else {
|
||||
if isStaticIP {
|
||||
resp.Static = "yes"
|
||||
}
|
||||
resp.IP = util.GetSubnet(interfaceName)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Check if DNSStubListener is active
|
||||
func checkDNSStubListener() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
return false
|
||||
}
|
||||
|
||||
cmd := exec.Command("systemctl", "is-enabled", "systemd-resolved")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
_, err := cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
log.Info("command %s has failed: %v code:%d",
|
||||
cmd.Path, err, cmd.ProcessState.ExitCode())
|
||||
return false
|
||||
}
|
||||
|
||||
cmd = exec.Command("grep", "-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
_, err = cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
log.Info("command %s has failed: %v code:%d",
|
||||
cmd.Path, err, cmd.ProcessState.ExitCode())
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
const resolvedConfPath = "/etc/systemd/resolved.conf.d/adguardhome.conf"
|
||||
const resolvedConfData = `[Resolve]
|
||||
DNS=127.0.0.1
|
||||
DNSStubListener=no
|
||||
`
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// Deactivate DNSStubListener
|
||||
func disableDNSStubListener() error {
|
||||
dir := filepath.Dir(resolvedConfPath)
|
||||
err := os.MkdirAll(dir, 0755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("os.MkdirAll: %s: %s", dir, err)
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(resolvedConfPath, []byte(resolvedConfData), 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ioutil.WriteFile: %s: %s", resolvedConfPath, err)
|
||||
}
|
||||
|
||||
_ = os.Rename(resolvConfPath, resolvConfPath+".backup")
|
||||
err = os.Symlink("/run/systemd/resolve/resolv.conf", resolvConfPath)
|
||||
if err != nil {
|
||||
_ = os.Remove(resolvedConfPath) // remove the file we've just created
|
||||
return fmt.Errorf("os.Symlink: %s: %s", resolvConfPath, err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("systemctl", "reload-or-restart", "systemd-resolved")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
_, err = cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cmd.ProcessState.ExitCode() != 0 {
|
||||
return fmt.Errorf("process %s exited with an error: %d",
|
||||
cmd.Path, cmd.ProcessState.ExitCode())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type applyConfigReqEnt struct {
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
type applyConfigReq struct {
|
||||
Web applyConfigReqEnt `json:"web"`
|
||||
DNS applyConfigReqEnt `json:"dns"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// Copy installation parameters between two configuration objects
|
||||
func copyInstallSettings(dst *configuration, src *configuration) {
|
||||
dst.BindHost = src.BindHost
|
||||
dst.BindPort = src.BindPort
|
||||
dst.DNS.BindHost = src.DNS.BindHost
|
||||
dst.DNS.Port = src.DNS.Port
|
||||
}
|
||||
|
||||
// Apply new configuration, start DNS server, restart Web server
|
||||
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
newSettings := applyConfigReq{}
|
||||
err := json.NewDecoder(r.Body).Decode(&newSettings)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse 'configure' JSON: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if newSettings.Web.Port == 0 || newSettings.DNS.Port == 0 {
|
||||
httpError(w, http.StatusBadRequest, "port value can't be 0")
|
||||
return
|
||||
}
|
||||
|
||||
restartHTTP := true
|
||||
if config.BindHost == newSettings.Web.IP && config.BindPort == newSettings.Web.Port {
|
||||
// no need to rebind
|
||||
restartHTTP = false
|
||||
}
|
||||
|
||||
// validate that hosts and ports are bindable
|
||||
if restartHTTP {
|
||||
err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s",
|
||||
net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = util.CheckPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = util.CheckPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var curConfig configuration
|
||||
copyInstallSettings(&curConfig, &config)
|
||||
|
||||
Context.firstRun = false
|
||||
config.BindHost = newSettings.Web.IP
|
||||
config.BindPort = newSettings.Web.Port
|
||||
config.DNS.BindHost = newSettings.DNS.IP
|
||||
config.DNS.Port = newSettings.DNS.Port
|
||||
|
||||
err = StartMods()
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
copyInstallSettings(&config, &curConfig)
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
u := User{}
|
||||
u.Name = newSettings.Username
|
||||
Context.auth.UserAdd(&u, newSettings.Password)
|
||||
|
||||
err = config.write()
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
copyInstallSettings(&config, &curConfig)
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
web.conf.firstRun = false
|
||||
web.conf.BindHost = newSettings.Web.IP
|
||||
web.conf.BindPort = newSettings.Web.Port
|
||||
|
||||
registerControlHandlers()
|
||||
|
||||
returnOK(w)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||
if restartHTTP {
|
||||
go func() {
|
||||
_ = Context.web.httpServer.Shutdown(context.TODO())
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (web *Web) registerInstallHandlers() {
|
||||
http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
|
||||
http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
|
||||
http.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
|
||||
}
|
||||
77
internal/home/control_test.go
Normal file
77
internal/home/control_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
/* Tests performed:
|
||||
. Bad certificate
|
||||
. Bad private key
|
||||
. Valid certificate & private key */
|
||||
func TestValidateCertificates(t *testing.T) {
|
||||
var data tlsConfigStatus
|
||||
|
||||
// bad cert
|
||||
data = validateCertificates("bad cert", "", "")
|
||||
if !(data.WarningValidation != "" &&
|
||||
!data.ValidCert &&
|
||||
!data.ValidChain) {
|
||||
t.Fatalf("bad cert: validateCertificates(): %v", data)
|
||||
}
|
||||
|
||||
// bad priv key
|
||||
data = validateCertificates("", "bad priv key", "")
|
||||
if !(data.WarningValidation != "" &&
|
||||
!data.ValidKey) {
|
||||
t.Fatalf("bad priv key: validateCertificates(): %v", data)
|
||||
}
|
||||
|
||||
// valid cert & priv key
|
||||
CertificateChain := `-----BEGIN CERTIFICATE-----
|
||||
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
|
||||
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
|
||||
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
|
||||
MBMGA1UEAwwMQWRHdWFyZCBIb21lMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB
|
||||
gQCwvwUnPJiOvLcOaWmGu6Y68ksFr13nrXBcsDlhxlXy8PaohVi3XxEmt2OrVjKW
|
||||
QFw/bdV4fZ9tdWFAVRRkgeGbIZzP7YBD1Ore/O5SQ+DbCCEafvjJCcXQIrTeKFE6
|
||||
i9G3aSMHs0Pwq2LgV8U5mYotLrvyFiE8QPInJbDDMpaFYwIDAQABo1MwUTAdBgNV
|
||||
HQ4EFgQUdLUmQpEqrhn4eKO029jYd2AAZEQwHwYDVR0jBBgwFoAUdLUmQpEqrhn4
|
||||
eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8
|
||||
LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
|
||||
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
|
||||
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
|
||||
-----END CERTIFICATE-----`
|
||||
PrivateKey := `-----BEGIN PRIVATE KEY-----
|
||||
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
|
||||
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
|
||||
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
|
||||
xTmZii0uu/IWITxA8iclsMMyloVjAgMBAAECgYEAmjzoG1h27UDkIlB9BVWl95TP
|
||||
QVPLB81D267xNFDnWk1Lgr5zL/pnNjkdYjyjgpkBp1yKyE4gHV4skv5sAFWTcOCU
|
||||
QCgfPfUn/rDFcxVzAdJVWAa/CpJNaZgjTPR8NTGU+Ztod+wfBESNCP5tbnuw0GbL
|
||||
MuwdLQJGbzeJYpsNysECQQDfFHYoRNfgxHwMbX24GCoNZIgk12uDmGTA9CS5E+72
|
||||
9t3V1y4CfXxSkfhqNbd5RWrUBRLEw9BKofBS7L9NMDKDAkEAytQoIueE1vqEAaRg
|
||||
a3A1YDUekKesU5wKfKfKlXvNgB7Hwh4HuvoQS9RCvVhf/60Dvq8KSu6hSjkFRquj
|
||||
FQ5roQJBAMwKwyiCD5MfJPeZDmzcbVpiocRQ5Z4wPbffl9dRTDnIA5AciZDthlFg
|
||||
An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp
|
||||
O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
|
||||
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
|
||||
kXS9jgARhhiWXJrk
|
||||
-----END PRIVATE KEY-----`
|
||||
data = validateCertificates(CertificateChain, PrivateKey, "")
|
||||
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z")
|
||||
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z")
|
||||
if !(data.WarningValidation != "" /* self signed */ &&
|
||||
data.ValidCert &&
|
||||
!data.ValidChain &&
|
||||
data.ValidKey &&
|
||||
data.KeyType == "RSA" &&
|
||||
data.Subject == "CN=AdGuard Home,O=AdGuard Ltd" &&
|
||||
data.Issuer == "CN=AdGuard Home,O=AdGuard Ltd" &&
|
||||
data.NotBefore == notBefore &&
|
||||
data.NotAfter == notAfter &&
|
||||
// data.DNSNames[0] == &&
|
||||
data.ValidPair) {
|
||||
t.Fatalf("valid cert & priv key: validateCertificates(): %v", data)
|
||||
}
|
||||
}
|
||||
164
internal/home/control_update.go
Normal file
164
internal/home/control_update.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/update"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
type getVersionJSONRequest struct {
|
||||
RecheckNow bool `json:"recheck_now"`
|
||||
}
|
||||
|
||||
// Get the latest available version from the Internet
|
||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
if Context.disableUpdate {
|
||||
resp := make(map[string]interface{})
|
||||
resp["disabled"] = true
|
||||
d, _ := json.Marshal(resp)
|
||||
_, _ = w.Write(d)
|
||||
return
|
||||
}
|
||||
|
||||
req := getVersionJSONRequest{}
|
||||
var err error
|
||||
if r.ContentLength != 0 {
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var info update.VersionInfo
|
||||
for i := 0; i != 3; i++ {
|
||||
Context.controlLock.Lock()
|
||||
info, err = Context.updater.GetVersionResponse(req.RecheckNow)
|
||||
Context.controlLock.Unlock()
|
||||
if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") {
|
||||
// This case may happen while we're restarting DNS server
|
||||
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/934
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(getVersionResp(info))
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Perform an update procedure to the latest available version
|
||||
func handleUpdate(w http.ResponseWriter, _ *http.Request) {
|
||||
if len(Context.updater.NewVersion) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
|
||||
return
|
||||
}
|
||||
|
||||
err := Context.updater.DoUpdate()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
returnOK(w)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
go finishUpdate()
|
||||
}
|
||||
|
||||
// Convert version.json data to our JSON response
|
||||
func getVersionResp(info update.VersionInfo) []byte {
|
||||
ret := make(map[string]interface{})
|
||||
ret["can_autoupdate"] = false
|
||||
ret["new_version"] = info.NewVersion
|
||||
ret["announcement"] = info.Announcement
|
||||
ret["announcement_url"] = info.AnnouncementURL
|
||||
|
||||
if info.CanAutoUpdate {
|
||||
canUpdate := true
|
||||
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
|
||||
if runtime.GOOS != "windows" &&
|
||||
((tlsConf.Enabled && (tlsConf.PortHTTPS < 1024 ||
|
||||
tlsConf.PortDNSOverTLS < 1024 ||
|
||||
tlsConf.PortDNSOverQUIC < 1024)) ||
|
||||
config.BindPort < 1024 ||
|
||||
config.DNS.Port < 1024) {
|
||||
// On UNIX, if we're running under a regular user,
|
||||
// but with CAP_NET_BIND_SERVICE set on a binary file,
|
||||
// and we're listening on ports <1024,
|
||||
// we won't be able to restart after we replace the binary file,
|
||||
// because we'll lose CAP_NET_BIND_SERVICE capability.
|
||||
canUpdate, _ = util.HaveAdminRights()
|
||||
}
|
||||
ret["can_autoupdate"] = canUpdate
|
||||
}
|
||||
|
||||
d, _ := json.Marshal(ret)
|
||||
return d
|
||||
}
|
||||
|
||||
// Complete an update procedure
|
||||
func finishUpdate() {
|
||||
log.Info("Stopping all tasks")
|
||||
cleanup()
|
||||
cleanupAlways()
|
||||
|
||||
exeName := "AdGuardHome"
|
||||
if runtime.GOOS == "windows" {
|
||||
exeName = "AdGuardHome.exe"
|
||||
}
|
||||
curBinName := filepath.Join(Context.workDir, exeName)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
if Context.runningAsService {
|
||||
// Note:
|
||||
// we can't restart the service via "kardianos/service" package - it kills the process first
|
||||
// we can't start a new instance - Windows doesn't allow it
|
||||
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("exec.Command() failed: %s", err)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
cmd := exec.Command(curBinName, os.Args[1:]...)
|
||||
log.Info("Restarting: %v", cmd.Args)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("exec.Command() failed: %s", err)
|
||||
}
|
||||
os.Exit(0)
|
||||
} else {
|
||||
log.Info("Restarting: %v", os.Args)
|
||||
err := syscall.Exec(curBinName, os.Args, os.Environ())
|
||||
if err != nil {
|
||||
log.Fatalf("syscall.Exec() failed: %s", err)
|
||||
}
|
||||
// Unreachable code
|
||||
}
|
||||
}
|
||||
102
internal/home/control_update_test.go
Normal file
102
internal/home/control_update_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
// +build ignore
|
||||
|
||||
package home
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDoUpdate(t *testing.T) {
|
||||
config.DNS.Port = 0
|
||||
Context.workDir = "..." // set absolute path
|
||||
newver := "v0.96"
|
||||
|
||||
data := `{
|
||||
"version": "v0.96",
|
||||
"announcement": "AdGuard Home v0.96 is now available!",
|
||||
"announcement_url": "",
|
||||
"download_windows_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_amd64.zip",
|
||||
"download_windows_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_386.zip",
|
||||
"download_darwin_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_amd64.zip",
|
||||
"download_darwin_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_386.zip",
|
||||
"download_linux_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz",
|
||||
"download_linux_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_386.tar.gz",
|
||||
"download_linux_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz",
|
||||
"download_linux_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
|
||||
"download_linux_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz",
|
||||
"download_linux_mips": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz",
|
||||
"download_linux_mipsle": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz",
|
||||
"download_linux_mips64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz",
|
||||
"download_linux_mips64le": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz",
|
||||
"download_freebsd_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz",
|
||||
"download_freebsd_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz",
|
||||
"download_freebsd_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz",
|
||||
"download_freebsd_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz",
|
||||
"download_freebsd_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz"
|
||||
}`
|
||||
uu, err := getUpdateInfo([]byte(data))
|
||||
if err != nil {
|
||||
t.Fatalf("getUpdateInfo: %s", err)
|
||||
}
|
||||
|
||||
u := updateInfo{
|
||||
pkgURL: "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
pkgName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz",
|
||||
newVer: newver,
|
||||
updateDir: Context.workDir + "/agh-update-" + newver,
|
||||
backupDir: Context.workDir + "/agh-backup",
|
||||
configName: Context.workDir + "/AdGuardHome.yaml",
|
||||
updateConfigName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/internal/AdGuardHome.yaml",
|
||||
curBinName: Context.workDir + "/AdGuardHome",
|
||||
bkpBinName: Context.workDir + "/agh-backup/AdGuardHome",
|
||||
newBinName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/internal/AdGuardHome",
|
||||
}
|
||||
|
||||
assert.Equal(t, uu.pkgURL, u.pkgURL)
|
||||
assert.Equal(t, uu.pkgName, u.pkgName)
|
||||
assert.Equal(t, uu.newVer, u.newVer)
|
||||
assert.Equal(t, uu.updateDir, u.updateDir)
|
||||
assert.Equal(t, uu.backupDir, u.backupDir)
|
||||
assert.Equal(t, uu.configName, u.configName)
|
||||
assert.Equal(t, uu.updateConfigName, u.updateConfigName)
|
||||
assert.Equal(t, uu.curBinName, u.curBinName)
|
||||
assert.Equal(t, uu.bkpBinName, u.bkpBinName)
|
||||
assert.Equal(t, uu.newBinName, u.newBinName)
|
||||
|
||||
e := doUpdate(&u)
|
||||
if e != nil {
|
||||
t.Fatalf("FAILED: %s", e)
|
||||
}
|
||||
os.RemoveAll(u.backupDir)
|
||||
}
|
||||
|
||||
func TestTargzFileUnpack(t *testing.T) {
|
||||
fn := "../dist/AdGuardHome_linux_amd64.tar.gz"
|
||||
outdir := "../test-unpack"
|
||||
defer os.RemoveAll(outdir)
|
||||
_ = os.Mkdir(outdir, 0755)
|
||||
files, e := targzFileUnpack(fn, outdir)
|
||||
if e != nil {
|
||||
t.Fatalf("FAILED: %s", e)
|
||||
}
|
||||
t.Logf("%v", files)
|
||||
}
|
||||
|
||||
func TestZipFileUnpack(t *testing.T) {
|
||||
fn := "../dist/AdGuardHome_windows_amd64.zip"
|
||||
outdir := "../test-unpack"
|
||||
_ = os.Mkdir(outdir, 0755)
|
||||
files, e := zipFileUnpack(fn, outdir)
|
||||
if e != nil {
|
||||
t.Fatalf("FAILED: %s", e)
|
||||
}
|
||||
t.Logf("%v", files)
|
||||
os.RemoveAll(outdir)
|
||||
}
|
||||
388
internal/home/dns.go
Normal file
388
internal/home/dns.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
// Called by other modules when configuration is changed
|
||||
func onConfigModified() {
|
||||
_ = config.write()
|
||||
}
|
||||
|
||||
// initDNSServer creates an instance of the dnsforward.Server
|
||||
// Please note that we must do it even if we don't start it
|
||||
// so that we had access to the query log and the stats
|
||||
func initDNSServer() error {
|
||||
var err error
|
||||
baseDir := Context.getDataDir()
|
||||
|
||||
statsConf := stats.Config{
|
||||
Filename: filepath.Join(baseDir, "stats.db"),
|
||||
LimitDays: config.DNS.StatsInterval,
|
||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
}
|
||||
Context.stats, err = stats.New(statsConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't initialize statistics module")
|
||||
}
|
||||
conf := querylog.Config{
|
||||
Enabled: config.DNS.QueryLogEnabled,
|
||||
FileEnabled: config.DNS.QueryLogFileEnabled,
|
||||
BaseDir: baseDir,
|
||||
Interval: config.DNS.QueryLogInterval,
|
||||
MemSize: config.DNS.QueryLogMemSize,
|
||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
}
|
||||
Context.queryLog = querylog.New(conf)
|
||||
|
||||
filterConf := config.DNS.DnsfilterConf
|
||||
bindhost := config.DNS.BindHost
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
bindhost = "127.0.0.1"
|
||||
}
|
||||
filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
|
||||
filterConf.AutoHosts = &Context.autoHosts
|
||||
filterConf.ConfigModified = onConfigModified
|
||||
filterConf.HTTPRegister = httpRegister
|
||||
Context.dnsFilter = dnsfilter.New(&filterConf, nil)
|
||||
|
||||
p := dnsforward.DNSCreateParams{
|
||||
DNSFilter: Context.dnsFilter,
|
||||
Stats: Context.stats,
|
||||
QueryLog: Context.queryLog,
|
||||
}
|
||||
if Context.dhcpServer != nil {
|
||||
p.DHCPServer = Context.dhcpServer
|
||||
}
|
||||
Context.dnsServer = dnsforward.NewServer(p)
|
||||
Context.clients.dnsServer = Context.dnsServer
|
||||
dnsConfig := generateServerConfig()
|
||||
err = Context.dnsServer.Prepare(&dnsConfig)
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
return fmt.Errorf("dnsServer.Prepare: %s", err)
|
||||
}
|
||||
|
||||
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
|
||||
Context.whois = initWhois(&Context.clients)
|
||||
|
||||
Context.filters.Init()
|
||||
return nil
|
||||
}
|
||||
|
||||
func isRunning() bool {
|
||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||
}
|
||||
|
||||
// nolint (gocyclo)
|
||||
// Return TRUE if IP is within public Internet IP range
|
||||
func isPublicIP(ip net.IP) bool {
|
||||
ip4 := ip.To4()
|
||||
if ip4 != nil {
|
||||
switch ip4[0] {
|
||||
case 0:
|
||||
return false //software
|
||||
case 10:
|
||||
return false //private network
|
||||
case 127:
|
||||
return false //loopback
|
||||
case 169:
|
||||
if ip4[1] == 254 {
|
||||
return false //link-local
|
||||
}
|
||||
case 172:
|
||||
if ip4[1] >= 16 && ip4[1] <= 31 {
|
||||
return false //private network
|
||||
}
|
||||
case 192:
|
||||
if (ip4[1] == 0 && ip4[2] == 0) || //private network
|
||||
(ip4[1] == 0 && ip4[2] == 2) || //documentation
|
||||
(ip4[1] == 88 && ip4[2] == 99) || //reserved
|
||||
(ip4[1] == 168) { //private network
|
||||
return false
|
||||
}
|
||||
case 198:
|
||||
if (ip4[1] == 18 || ip4[2] == 19) || //private network
|
||||
(ip4[1] == 51 || ip4[2] == 100) { //documentation
|
||||
return false
|
||||
}
|
||||
case 203:
|
||||
if ip4[1] == 0 && ip4[2] == 113 { //documentation
|
||||
return false
|
||||
}
|
||||
case 224:
|
||||
if ip4[1] == 0 && ip4[2] == 0 { //multicast
|
||||
return false
|
||||
}
|
||||
case 255:
|
||||
if ip4[1] == 255 && ip4[2] == 255 && ip4[3] == 255 { //subnet
|
||||
return false
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if ip.IsLoopback() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func onDNSRequest(d *proxy.DNSContext) {
|
||||
ip := dnsforward.GetIPString(d.Addr)
|
||||
if ip == "" {
|
||||
// This would be quite weird if we get here
|
||||
return
|
||||
}
|
||||
|
||||
ipAddr := net.ParseIP(ip)
|
||||
if !ipAddr.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
if isPublicIP(ipAddr) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func generateServerConfig() dnsforward.ServerConfig {
|
||||
newconfig := dnsforward.ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
|
||||
TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
|
||||
FilteringConfig: config.DNS.FilteringConfig,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
OnDNSRequest: onDNSRequest,
|
||||
}
|
||||
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
if tlsConf.Enabled {
|
||||
newconfig.TLSConfig = tlsConf.TLSConfig
|
||||
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
newconfig.TLSListenAddr = &net.TCPAddr{
|
||||
IP: net.ParseIP(config.DNS.BindHost),
|
||||
Port: tlsConf.PortDNSOverTLS,
|
||||
}
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverQUIC != 0 {
|
||||
newconfig.QUICListenAddr = &net.UDPAddr{
|
||||
IP: net.ParseIP(config.DNS.BindHost),
|
||||
Port: int(tlsConf.PortDNSOverQUIC),
|
||||
}
|
||||
}
|
||||
}
|
||||
newconfig.TLSv12Roots = Context.tlsRoots
|
||||
newconfig.TLSCiphers = Context.tlsCiphers
|
||||
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||
|
||||
newconfig.FilterHandler = applyAdditionalFiltering
|
||||
newconfig.GetCustomUpstreamByClient = Context.clients.FindUpstreams
|
||||
return newconfig
|
||||
}
|
||||
|
||||
type DNSEncryption struct {
|
||||
https string
|
||||
tls string
|
||||
quic string
|
||||
}
|
||||
|
||||
func getDNSEncryption() DNSEncryption {
|
||||
dnsEncryption := DNSEncryption{}
|
||||
|
||||
tlsConf := tlsConfigSettings{}
|
||||
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
|
||||
if tlsConf.Enabled && len(tlsConf.ServerName) != 0 {
|
||||
|
||||
if tlsConf.PortHTTPS != 0 {
|
||||
addr := tlsConf.ServerName
|
||||
if tlsConf.PortHTTPS != 443 {
|
||||
addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS)
|
||||
}
|
||||
addr = fmt.Sprintf("https://%s/dns-query", addr)
|
||||
dnsEncryption.https = addr
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS)
|
||||
dnsEncryption.tls = addr
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverQUIC != 0 {
|
||||
addr := fmt.Sprintf("quic://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverQUIC)
|
||||
dnsEncryption.quic = addr
|
||||
}
|
||||
}
|
||||
|
||||
return dnsEncryption
|
||||
}
|
||||
|
||||
// Get the list of DNS addresses the server is listening on
|
||||
func getDNSAddresses() []string {
|
||||
dnsAddresses := []string{}
|
||||
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
ifaces, e := util.GetValidNetInterfacesForWeb()
|
||||
if e != nil {
|
||||
log.Error("Couldn't get network interfaces: %v", e)
|
||||
return []string{}
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
for _, addr := range iface.Addresses {
|
||||
addDNSAddress(&dnsAddresses, addr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
|
||||
}
|
||||
|
||||
dnsEncryption := getDNSEncryption()
|
||||
if dnsEncryption.https != "" {
|
||||
dnsAddresses = append(dnsAddresses, dnsEncryption.https)
|
||||
}
|
||||
if dnsEncryption.tls != "" {
|
||||
dnsAddresses = append(dnsAddresses, dnsEncryption.tls)
|
||||
}
|
||||
if dnsEncryption.quic != "" {
|
||||
dnsAddresses = append(dnsAddresses, dnsEncryption.quic)
|
||||
}
|
||||
|
||||
return dnsAddresses
|
||||
}
|
||||
|
||||
// If a client has his own settings, apply them
|
||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||
|
||||
if len(clientAddr) == 0 {
|
||||
return
|
||||
}
|
||||
setts.ClientIP = clientAddr
|
||||
|
||||
c, ok := Context.clients.Find(clientAddr)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("Using settings for client %s with IP %s", c.Name, clientAddr)
|
||||
|
||||
if c.UseOwnBlockedServices {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
|
||||
}
|
||||
|
||||
setts.ClientName = c.Name
|
||||
setts.ClientTags = c.Tags
|
||||
|
||||
if !c.UseOwnSettings {
|
||||
return
|
||||
}
|
||||
|
||||
setts.FilteringEnabled = c.FilteringEnabled
|
||||
setts.SafeSearchEnabled = c.SafeSearchEnabled
|
||||
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
|
||||
setts.ParentalEnabled = c.ParentalEnabled
|
||||
}
|
||||
|
||||
func startDNSServer() error {
|
||||
if isRunning() {
|
||||
return fmt.Errorf("unable to start forwarding DNS server: Already running")
|
||||
}
|
||||
|
||||
enableFilters(false)
|
||||
|
||||
Context.clients.Start()
|
||||
|
||||
err := Context.dnsServer.Start()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||
}
|
||||
|
||||
Context.dnsFilter.Start()
|
||||
Context.filters.Start()
|
||||
Context.stats.Start()
|
||||
Context.queryLog.Start()
|
||||
|
||||
const topClientsNumber = 100 // the number of clients to get
|
||||
topClients := Context.stats.GetTopClientsIP(topClientsNumber)
|
||||
for _, ip := range topClients {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
if !ipAddr.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
if isPublicIP(ipAddr) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func reconfigureDNSServer() error {
|
||||
newconfig := generateServerConfig()
|
||||
err := Context.dnsServer.Reconfigure(&newconfig)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stopDNSServer() error {
|
||||
if !isRunning() {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := Context.dnsServer.Stop()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
|
||||
}
|
||||
|
||||
closeDNSServer()
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeDNSServer() {
|
||||
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
|
||||
if Context.dnsServer != nil {
|
||||
Context.dnsServer.Close()
|
||||
Context.dnsServer = nil
|
||||
}
|
||||
|
||||
if Context.dnsFilter != nil {
|
||||
Context.dnsFilter.Close()
|
||||
Context.dnsFilter = nil
|
||||
}
|
||||
|
||||
if Context.stats != nil {
|
||||
Context.stats.Close()
|
||||
Context.stats = nil
|
||||
}
|
||||
|
||||
if Context.queryLog != nil {
|
||||
Context.queryLog.Close()
|
||||
Context.queryLog = nil
|
||||
}
|
||||
|
||||
Context.filters.Close()
|
||||
|
||||
log.Debug("Closed all DNS modules")
|
||||
}
|
||||
713
internal/home/filter.go
Normal file
713
internal/home/filter.go
Normal file
@@ -0,0 +1,713 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
var (
|
||||
nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
|
||||
)
|
||||
|
||||
// Filtering - module object
|
||||
type Filtering struct {
|
||||
// conf FilteringConf
|
||||
refreshStatus uint32 // 0:none; 1:in progress
|
||||
refreshLock sync.Mutex
|
||||
filterTitleRegexp *regexp.Regexp
|
||||
}
|
||||
|
||||
// Init - initialize the module
|
||||
func (f *Filtering) Init() {
|
||||
f.filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
|
||||
_ = os.MkdirAll(filepath.Join(Context.getDataDir(), filterDir), 0755)
|
||||
f.loadFilters(config.Filters)
|
||||
f.loadFilters(config.WhitelistFilters)
|
||||
deduplicateFilters()
|
||||
updateUniqueFilterID(config.Filters)
|
||||
updateUniqueFilterID(config.WhitelistFilters)
|
||||
}
|
||||
|
||||
// Start - start the module
|
||||
func (f *Filtering) Start() {
|
||||
f.RegisterFilteringHandlers()
|
||||
|
||||
// Here we should start updating filters,
|
||||
// but currently we can't wake up the periodic task to do so.
|
||||
// So for now we just start this periodic task from here.
|
||||
go f.periodicallyRefreshFilters()
|
||||
}
|
||||
|
||||
// Close - close the module
|
||||
func (f *Filtering) Close() {
|
||||
}
|
||||
|
||||
func defaultFilters() []filter {
|
||||
return []filter{
|
||||
{Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard DNS filter"},
|
||||
{Filter: dnsfilter.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway Default Blocklist"},
|
||||
{Filter: dnsfilter.Filter{ID: 4}, Enabled: false, URL: "https://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"},
|
||||
}
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type filter struct {
|
||||
Enabled bool
|
||||
URL string // URL or a file path
|
||||
Name string `yaml:"name"`
|
||||
RulesCount int `yaml:"-"`
|
||||
LastUpdated time.Time `yaml:"-"`
|
||||
checksum uint32 // checksum of the file data
|
||||
white bool
|
||||
|
||||
dnsfilter.Filter `yaml:",inline"`
|
||||
}
|
||||
|
||||
// Creates a helper object for working with the user rules
|
||||
func userFilter() filter {
|
||||
f := filter{
|
||||
// User filter always has constant ID=0
|
||||
Enabled: true,
|
||||
}
|
||||
f.Filter.Data = []byte(strings.Join(config.UserRules, "\n"))
|
||||
return f
|
||||
}
|
||||
|
||||
const (
|
||||
statusFound = 1
|
||||
statusEnabledChanged = 2
|
||||
statusURLChanged = 4
|
||||
statusURLExists = 8
|
||||
statusUpdateRequired = 0x10
|
||||
)
|
||||
|
||||
// Update properties for a filter specified by its URL
|
||||
// Return status* flags.
|
||||
func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool) int {
|
||||
r := 0
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
|
||||
filters := &config.Filters
|
||||
if whitelist {
|
||||
filters = &config.WhitelistFilters
|
||||
}
|
||||
|
||||
for i := range *filters {
|
||||
filt := &(*filters)[i]
|
||||
if filt.URL != url {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("filter: set properties: %s: {%s %s %v}",
|
||||
filt.URL, newf.Name, newf.URL, newf.Enabled)
|
||||
filt.Name = newf.Name
|
||||
|
||||
if filt.URL != newf.URL {
|
||||
r |= statusURLChanged | statusUpdateRequired
|
||||
if filterExistsNoLock(newf.URL) {
|
||||
return statusURLExists
|
||||
}
|
||||
filt.URL = newf.URL
|
||||
filt.unload()
|
||||
filt.LastUpdated = time.Time{}
|
||||
filt.checksum = 0
|
||||
filt.RulesCount = 0
|
||||
}
|
||||
|
||||
if filt.Enabled != newf.Enabled {
|
||||
r |= statusEnabledChanged
|
||||
filt.Enabled = newf.Enabled
|
||||
if filt.Enabled {
|
||||
if (r & statusURLChanged) == 0 {
|
||||
e := f.load(filt)
|
||||
if e != nil {
|
||||
// This isn't a fatal error,
|
||||
// because it may occur when someone removes the file from disk.
|
||||
filt.LastUpdated = time.Time{}
|
||||
filt.checksum = 0
|
||||
filt.RulesCount = 0
|
||||
r |= statusUpdateRequired
|
||||
}
|
||||
}
|
||||
} else {
|
||||
filt.unload()
|
||||
}
|
||||
}
|
||||
|
||||
return r | statusFound
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Return TRUE if a filter with this URL exists
|
||||
func filterExists(url string) bool {
|
||||
config.RLock()
|
||||
r := filterExistsNoLock(url)
|
||||
config.RUnlock()
|
||||
return r
|
||||
}
|
||||
|
||||
func filterExistsNoLock(url string) bool {
|
||||
for _, f := range config.Filters {
|
||||
if f.URL == url {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, f := range config.WhitelistFilters {
|
||||
if f.URL == url {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Add a filter
|
||||
// Return FALSE if a filter with this URL exists
|
||||
func filterAdd(f filter) bool {
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
|
||||
// Check for duplicates
|
||||
if filterExistsNoLock(f.URL) {
|
||||
return false
|
||||
}
|
||||
|
||||
if f.white {
|
||||
config.WhitelistFilters = append(config.WhitelistFilters, f)
|
||||
} else {
|
||||
config.Filters = append(config.Filters, f)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Load filters from the disk
|
||||
// And if any filter has zero ID, assign a new one
|
||||
func (f *Filtering) loadFilters(array []filter) {
|
||||
for i := range array {
|
||||
filter := &array[i] // otherwise we're operating on a copy
|
||||
if filter.ID == 0 {
|
||||
filter.ID = assignUniqueFilterID()
|
||||
}
|
||||
|
||||
if !filter.Enabled {
|
||||
// No need to load a filter that is not enabled
|
||||
continue
|
||||
}
|
||||
|
||||
err := f.load(filter)
|
||||
if err != nil {
|
||||
log.Error("Couldn't load filter %d contents due to %s", filter.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func deduplicateFilters() {
|
||||
// Deduplicate filters
|
||||
i := 0 // output index, used for deletion later
|
||||
urls := map[string]bool{}
|
||||
for _, filter := range config.Filters {
|
||||
if _, ok := urls[filter.URL]; !ok {
|
||||
// we didn't see it before, keep it
|
||||
urls[filter.URL] = true // remember the URL
|
||||
config.Filters[i] = filter
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// all entries we want to keep are at front, delete the rest
|
||||
config.Filters = config.Filters[:i]
|
||||
}
|
||||
|
||||
// Set the next filter ID to max(filter.ID) + 1
|
||||
func updateUniqueFilterID(filters []filter) {
|
||||
for _, filter := range filters {
|
||||
if nextFilterID < filter.ID {
|
||||
nextFilterID = filter.ID + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assignUniqueFilterID() int64 {
|
||||
value := nextFilterID
|
||||
nextFilterID++
|
||||
return value
|
||||
}
|
||||
|
||||
// Sets up a timer that will be checking for filters updates periodically
|
||||
func (f *Filtering) periodicallyRefreshFilters() {
|
||||
const maxInterval = 1 * 60 * 60
|
||||
intval := 5 // use a dynamically increasing time interval
|
||||
for {
|
||||
isNetworkErr := false
|
||||
if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) {
|
||||
f.refreshLock.Lock()
|
||||
_, isNetworkErr = f.refreshFiltersIfNecessary(FilterRefreshBlocklists | FilterRefreshAllowlists)
|
||||
f.refreshLock.Unlock()
|
||||
f.refreshStatus = 0
|
||||
if !isNetworkErr {
|
||||
intval = maxInterval
|
||||
}
|
||||
}
|
||||
|
||||
if isNetworkErr {
|
||||
intval *= 2
|
||||
if intval > maxInterval {
|
||||
intval = maxInterval
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(intval) * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh filters
|
||||
// flags: FilterRefresh*
|
||||
// important:
|
||||
// TRUE: ignore the fact that we're currently updating the filters
|
||||
func (f *Filtering) refreshFilters(flags int, important bool) (int, error) {
|
||||
set := atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1)
|
||||
if !important && !set {
|
||||
return 0, fmt.Errorf("filters update procedure is already running")
|
||||
}
|
||||
|
||||
f.refreshLock.Lock()
|
||||
nUpdated, _ := f.refreshFiltersIfNecessary(flags)
|
||||
f.refreshLock.Unlock()
|
||||
f.refreshStatus = 0
|
||||
return nUpdated, nil
|
||||
}
|
||||
|
||||
func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) {
|
||||
var updateFilters []filter
|
||||
var updateFlags []bool // 'true' if filter data has changed
|
||||
|
||||
now := time.Now()
|
||||
config.RLock()
|
||||
for i := range *filters {
|
||||
f := &(*filters)[i] // otherwise we will be operating on a copy
|
||||
|
||||
if !f.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
expireTime := f.LastUpdated.Unix() + int64(config.DNS.FiltersUpdateIntervalHours)*60*60
|
||||
if !force && expireTime > now.Unix() {
|
||||
continue
|
||||
}
|
||||
|
||||
var uf filter
|
||||
uf.ID = f.ID
|
||||
uf.URL = f.URL
|
||||
uf.Name = f.Name
|
||||
uf.checksum = f.checksum
|
||||
updateFilters = append(updateFilters, uf)
|
||||
}
|
||||
config.RUnlock()
|
||||
|
||||
if len(updateFilters) == 0 {
|
||||
return 0, nil, nil, false
|
||||
}
|
||||
|
||||
nfail := 0
|
||||
for i := range updateFilters {
|
||||
uf := &updateFilters[i]
|
||||
updated, err := f.update(uf)
|
||||
updateFlags = append(updateFlags, updated)
|
||||
if err != nil {
|
||||
nfail++
|
||||
log.Printf("Failed to update filter %s: %s\n", uf.URL, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if nfail == len(updateFilters) {
|
||||
return 0, nil, nil, true
|
||||
}
|
||||
|
||||
updateCount := 0
|
||||
for i := range updateFilters {
|
||||
uf := &updateFilters[i]
|
||||
updated := updateFlags[i]
|
||||
|
||||
config.Lock()
|
||||
for k := range *filters {
|
||||
f := &(*filters)[k]
|
||||
if f.ID != uf.ID || f.URL != uf.URL {
|
||||
continue
|
||||
}
|
||||
f.LastUpdated = uf.LastUpdated
|
||||
if !updated {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info("Updated filter #%d. Rules: %d -> %d",
|
||||
f.ID, f.RulesCount, uf.RulesCount)
|
||||
f.Name = uf.Name
|
||||
f.RulesCount = uf.RulesCount
|
||||
f.checksum = uf.checksum
|
||||
updateCount++
|
||||
}
|
||||
config.Unlock()
|
||||
}
|
||||
|
||||
return updateCount, updateFilters, updateFlags, false
|
||||
}
|
||||
|
||||
const (
|
||||
FilterRefreshForce = 1 // ignore last file modification date
|
||||
FilterRefreshAllowlists = 2 // update allow-lists
|
||||
FilterRefreshBlocklists = 4 // update block-lists
|
||||
)
|
||||
|
||||
// Checks filters updates if necessary
|
||||
// If force is true, it ignores the filter.LastUpdated field value
|
||||
// flags: FilterRefresh*
|
||||
//
|
||||
// Algorithm:
|
||||
// . Get the list of filters to be updated
|
||||
// . For each filter run the download and checksum check operation
|
||||
// . Store downloaded data in a temporary file inside data/filters directory
|
||||
// . For each filter:
|
||||
// . If filter data hasn't changed, just set new update time on file
|
||||
// . If filter data has changed:
|
||||
// . rename the temporary file (<temp> -> 1.txt)
|
||||
// Note that this method works only on UNIX.
|
||||
// On Windows we don't pass files to dnsfilter - we pass the whole data.
|
||||
// . Pass new filters to dnsfilter object - it analyzes new data while the old filters are still active
|
||||
// . dnsfilter activates new filters
|
||||
//
|
||||
// Return the number of updated filters
|
||||
// Return TRUE - there was a network error and nothing could be updated
|
||||
func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) {
|
||||
log.Debug("Filters: updating...")
|
||||
|
||||
updateCount := 0
|
||||
var updateFilters []filter
|
||||
var updateFlags []bool
|
||||
netError := false
|
||||
netErrorW := false
|
||||
force := false
|
||||
if (flags & FilterRefreshForce) != 0 {
|
||||
force = true
|
||||
}
|
||||
if (flags & FilterRefreshBlocklists) != 0 {
|
||||
updateCount, updateFilters, updateFlags, netError = f.refreshFiltersArray(&config.Filters, force)
|
||||
}
|
||||
if (flags & FilterRefreshAllowlists) != 0 {
|
||||
updateCountW := 0
|
||||
var updateFiltersW []filter
|
||||
var updateFlagsW []bool
|
||||
updateCountW, updateFiltersW, updateFlagsW, netErrorW = f.refreshFiltersArray(&config.WhitelistFilters, force)
|
||||
updateCount += updateCountW
|
||||
updateFilters = append(updateFilters, updateFiltersW...)
|
||||
updateFlags = append(updateFlags, updateFlagsW...)
|
||||
}
|
||||
if netError && netErrorW {
|
||||
return 0, true
|
||||
}
|
||||
|
||||
if updateCount != 0 {
|
||||
enableFilters(false)
|
||||
|
||||
for i := range updateFilters {
|
||||
uf := &updateFilters[i]
|
||||
updated := updateFlags[i]
|
||||
if !updated {
|
||||
continue
|
||||
}
|
||||
_ = os.Remove(uf.Path() + ".old")
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Filters: update finished")
|
||||
return updateCount, false
|
||||
}
|
||||
|
||||
// Allows printable UTF-8 text with CR, LF, TAB characters
|
||||
func isPrintableText(data []byte, len int) bool {
|
||||
for i := 0; i < len; i++ {
|
||||
c := data[i]
|
||||
if (c >= ' ' && c != 0x7f) || c == '\n' || c == '\r' || c == '\t' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
|
||||
func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||
rulesCount := 0
|
||||
name := ""
|
||||
seenTitle := false
|
||||
r := bufio.NewReader(file)
|
||||
checksum := uint32(0)
|
||||
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line))
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
//
|
||||
|
||||
} else if line[0] == '!' {
|
||||
m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1)
|
||||
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
|
||||
name = m[0][1]
|
||||
seenTitle = true
|
||||
}
|
||||
|
||||
} else if line[0] == '#' {
|
||||
//
|
||||
|
||||
} else {
|
||||
rulesCount++
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return rulesCount, checksum, name
|
||||
}
|
||||
|
||||
// Perform upgrade on a filter and update LastUpdated value
|
||||
func (f *Filtering) update(filter *filter) (bool, error) {
|
||||
b, err := f.updateIntl(filter)
|
||||
filter.LastUpdated = time.Now()
|
||||
if !b {
|
||||
e := os.Chtimes(filter.Path(), filter.LastUpdated, filter.LastUpdated)
|
||||
if e != nil {
|
||||
log.Error("os.Chtimes(): %v", e)
|
||||
}
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
|
||||
// nolint(gocyclo)
|
||||
func (f *Filtering) updateIntl(filter *filter) (bool, error) {
|
||||
log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
|
||||
|
||||
tmpFile, err := ioutil.TempFile(filepath.Join(Context.getDataDir(), filterDir), "")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
if tmpFile != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
}()
|
||||
|
||||
var reader io.Reader
|
||||
if filepath.IsAbs(filter.URL) {
|
||||
f, err := os.Open(filter.URL)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("open file: %s", err)
|
||||
}
|
||||
defer f.Close()
|
||||
reader = f
|
||||
} else {
|
||||
resp, err := Context.client.Get(filter.URL)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
|
||||
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
|
||||
}
|
||||
reader = resp.Body
|
||||
}
|
||||
|
||||
htmlTest := true
|
||||
firstChunk := make([]byte, 4*1024)
|
||||
firstChunkLen := 0
|
||||
buf := make([]byte, 64*1024)
|
||||
total := 0
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
total += n
|
||||
|
||||
if htmlTest {
|
||||
// gather full buffer firstChunk and perform its data tests
|
||||
num := util.MinInt(n, len(firstChunk)-firstChunkLen)
|
||||
copied := copy(firstChunk[firstChunkLen:], buf[:num])
|
||||
firstChunkLen += copied
|
||||
|
||||
if firstChunkLen == len(firstChunk) || err == io.EOF {
|
||||
if !isPrintableText(firstChunk, firstChunkLen) {
|
||||
return false, fmt.Errorf("data contains non-printable characters")
|
||||
}
|
||||
|
||||
s := strings.ToLower(string(firstChunk))
|
||||
if strings.Index(s, "<html") >= 0 ||
|
||||
strings.Index(s, "<!doctype") >= 0 {
|
||||
return false, fmt.Errorf("data is HTML, not plain text")
|
||||
}
|
||||
|
||||
htmlTest = false
|
||||
firstChunk = nil
|
||||
}
|
||||
}
|
||||
|
||||
_, err2 := tmpFile.Write(buf[:n])
|
||||
if err2 != nil {
|
||||
return false, err2
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Extract filter name and count number of rules
|
||||
_, _ = tmpFile.Seek(0, io.SeekStart)
|
||||
rulesCount, checksum, filterName := f.parseFilterContents(tmpFile)
|
||||
// Check if the filter has been really changed
|
||||
if filter.checksum == checksum {
|
||||
log.Tracef("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Printf("Filter %d has been updated: %d bytes, %d rules",
|
||||
filter.ID, total, rulesCount)
|
||||
if len(filter.Name) == 0 {
|
||||
filter.Name = filterName
|
||||
}
|
||||
filter.RulesCount = rulesCount
|
||||
filter.checksum = checksum
|
||||
filterFilePath := filter.Path()
|
||||
log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
|
||||
// Closing the file before renaming it is necessary on Windows
|
||||
_ = tmpFile.Close()
|
||||
err = os.Rename(tmpFile.Name(), filterFilePath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
tmpFile = nil
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// loads filter contents from the file in dataDir
|
||||
func (f *Filtering) load(filter *filter) error {
|
||||
filterFilePath := filter.Path()
|
||||
log.Tracef("Loading filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
|
||||
if _, err := os.Stat(filterFilePath); os.IsNotExist(err) {
|
||||
// do nothing, file doesn't exist
|
||||
return err
|
||||
}
|
||||
|
||||
file, err := os.Open(filterFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
st, _ := file.Stat()
|
||||
|
||||
log.Tracef("File %s, id %d, length %d",
|
||||
filterFilePath, filter.ID, st.Size())
|
||||
rulesCount, checksum, _ := f.parseFilterContents(file)
|
||||
|
||||
filter.RulesCount = rulesCount
|
||||
filter.checksum = checksum
|
||||
filter.LastUpdated = filter.LastTimeUpdated()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear filter rules
|
||||
func (filter *filter) unload() {
|
||||
filter.RulesCount = 0
|
||||
filter.checksum = 0
|
||||
}
|
||||
|
||||
// Path to the filter contents
|
||||
func (filter *filter) Path() string {
|
||||
return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
|
||||
}
|
||||
|
||||
// LastTimeUpdated returns the time when the filter was last time updated
|
||||
func (filter *filter) LastTimeUpdated() time.Time {
|
||||
filterFilePath := filter.Path()
|
||||
s, err := os.Stat(filterFilePath)
|
||||
if os.IsNotExist(err) {
|
||||
// if the filter file does not exist, return 0001-01-01
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// if the filter file does not exist, return 0001-01-01
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// filter file modified time
|
||||
return s.ModTime()
|
||||
}
|
||||
|
||||
func enableFilters(async bool) {
|
||||
var filters []dnsfilter.Filter
|
||||
var whiteFilters []dnsfilter.Filter
|
||||
if config.DNS.FilteringEnabled {
|
||||
// convert array of filters
|
||||
|
||||
userFilter := userFilter()
|
||||
f := dnsfilter.Filter{
|
||||
ID: userFilter.ID,
|
||||
Data: userFilter.Data,
|
||||
}
|
||||
filters = append(filters, f)
|
||||
|
||||
for _, filter := range config.Filters {
|
||||
if !filter.Enabled {
|
||||
continue
|
||||
}
|
||||
f := dnsfilter.Filter{
|
||||
ID: filter.ID,
|
||||
FilePath: filter.Path(),
|
||||
}
|
||||
filters = append(filters, f)
|
||||
}
|
||||
for _, filter := range config.WhitelistFilters {
|
||||
if !filter.Enabled {
|
||||
continue
|
||||
}
|
||||
f := dnsfilter.Filter{
|
||||
ID: filter.ID,
|
||||
FilePath: filter.Path(),
|
||||
}
|
||||
whiteFilters = append(whiteFilters, f)
|
||||
}
|
||||
}
|
||||
|
||||
_ = Context.dnsFilter.SetFilters(filters, whiteFilters, async)
|
||||
}
|
||||
65
internal/home/filter_test.go
Normal file
65
internal/home/filter_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testStartFilterListener() net.Listener {
|
||||
http.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
|
||||
content := `||example.org^$third-party
|
||||
# Inline comment example
|
||||
||example.com^$third-party
|
||||
0.0.0.0 example.com
|
||||
`
|
||||
_, _ = w.Write([]byte(content))
|
||||
})
|
||||
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go func() { _ = http.Serve(listener, nil) }()
|
||||
return listener
|
||||
}
|
||||
|
||||
func TestFilters(t *testing.T) {
|
||||
l := testStartFilterListener()
|
||||
defer func() { _ = l.Close() }()
|
||||
|
||||
dir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
Context = homeContext{}
|
||||
Context.workDir = dir
|
||||
Context.client = &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
Context.filters.Init()
|
||||
|
||||
f := filter{
|
||||
URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port),
|
||||
}
|
||||
|
||||
// download
|
||||
ok, err := Context.filters.update(&f)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 3, f.RulesCount)
|
||||
|
||||
// refresh
|
||||
ok, err = Context.filters.update(&f)
|
||||
assert.True(t, !ok && err == nil)
|
||||
|
||||
err = Context.filters.load(&f)
|
||||
assert.True(t, err == nil)
|
||||
|
||||
f.unload()
|
||||
_ = os.Remove(f.Path())
|
||||
}
|
||||
668
internal/home/home.go
Normal file
668
internal/home/home.go
Normal file
@@ -0,0 +1,668 @@
|
||||
// Package home contains AdGuard Home's HTTP API methods.
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/update"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
|
||||
"github.com/joomcode/errorx"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/isdelve"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
const (
|
||||
// Used in config to indicate that syslog or eventlog (win) should be used for logger output
|
||||
configSyslog = "syslog"
|
||||
)
|
||||
|
||||
// Update-related variables
|
||||
var (
|
||||
versionString = "dev"
|
||||
updateChannel = "none"
|
||||
versionCheckURL = ""
|
||||
ARMVersion = ""
|
||||
)
|
||||
|
||||
// Global context
|
||||
type homeContext struct {
|
||||
// Modules
|
||||
// --
|
||||
|
||||
clients clientsContainer // per-client-settings module
|
||||
stats stats.Stats // statistics module
|
||||
queryLog querylog.QueryLog // query log module
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
whois *Whois // WHOIS module
|
||||
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
|
||||
dhcpServer *dhcpd.Server // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters Filtering // DNS filtering module
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
tls *TLSMod // TLS module
|
||||
autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files
|
||||
updater *update.Updater
|
||||
|
||||
// Runtime properties
|
||||
// --
|
||||
|
||||
configFilename string // Config filename (can be overridden via the command line arguments)
|
||||
workDir string // Location of our directory, used to protect against CWD being somewhere else
|
||||
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
|
||||
pidFileName string // PID file name. Empty if no PID file was created.
|
||||
disableUpdate bool // If set, don't check for updates
|
||||
controlLock sync.Mutex
|
||||
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
|
||||
tlsCiphers []uint16 // list of TLS ciphers to use
|
||||
transport *http.Transport
|
||||
client *http.Client
|
||||
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
|
||||
// runningAsService flag is set to true when options are passed from the service runner
|
||||
runningAsService bool
|
||||
}
|
||||
|
||||
// getDataDir returns path to the directory where we store databases and filters
|
||||
func (c *homeContext) getDataDir() string {
|
||||
return filepath.Join(c.workDir, dataDir)
|
||||
}
|
||||
|
||||
// Context - a global context object
|
||||
var Context homeContext
|
||||
|
||||
// Main is the entry point
|
||||
func Main(version string, channel string, armVer string) {
|
||||
// Init update-related global variables
|
||||
versionString = version
|
||||
updateChannel = channel
|
||||
ARMVersion = armVer
|
||||
versionCheckURL = "https://static.adguard.com/adguardhome/" + updateChannel + "/version.json"
|
||||
|
||||
// config can be specified, which reads options from there, but other command line flags have to override config values
|
||||
// therefore, we must do it manually instead of using a lib
|
||||
args := loadOptions()
|
||||
|
||||
Context.appSignalChannel = make(chan os.Signal)
|
||||
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
|
||||
go func() {
|
||||
for {
|
||||
sig := <-Context.appSignalChannel
|
||||
log.Info("Received signal '%s'", sig)
|
||||
switch sig {
|
||||
case syscall.SIGHUP:
|
||||
Context.clients.Reload()
|
||||
Context.tls.Reload()
|
||||
|
||||
default:
|
||||
cleanup()
|
||||
cleanupAlways()
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if args.serviceControlAction != "" {
|
||||
handleServiceControlAction(args)
|
||||
return
|
||||
}
|
||||
|
||||
// run the protection
|
||||
run(args)
|
||||
}
|
||||
|
||||
// version - returns the current version string
|
||||
func version() string {
|
||||
msg := "AdGuard Home, version %s, channel %s, arch %s %s"
|
||||
if ARMVersion != "" {
|
||||
msg = msg + " v" + ARMVersion
|
||||
}
|
||||
return fmt.Sprintf(msg, versionString, updateChannel, runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
// run initializes configuration and runs the AdGuard Home
|
||||
// run is a blocking method!
|
||||
// nolint
|
||||
func run(args options) {
|
||||
// configure config filename
|
||||
initConfigFilename(args)
|
||||
|
||||
// configure working dir and config path
|
||||
initWorkingDir(args)
|
||||
|
||||
// configure log level and output
|
||||
configureLogger(args)
|
||||
|
||||
// Go memory hacks
|
||||
memoryUsage(args)
|
||||
|
||||
// print the first message after logger is configured
|
||||
log.Println(version())
|
||||
log.Debug("Current working directory is %s", Context.workDir)
|
||||
if args.runningAsService {
|
||||
log.Info("AdGuard Home is running as a service")
|
||||
}
|
||||
Context.runningAsService = args.runningAsService
|
||||
Context.disableUpdate = args.disableUpdate
|
||||
|
||||
Context.firstRun = detectFirstRun()
|
||||
if Context.firstRun {
|
||||
log.Info("This is the first time AdGuard Home is launched")
|
||||
checkPermissions()
|
||||
}
|
||||
|
||||
initConfig()
|
||||
|
||||
Context.tlsRoots = util.LoadSystemRootCAs()
|
||||
Context.tlsCiphers = util.InitTLSCiphers()
|
||||
Context.transport = &http.Transport{
|
||||
DialContext: customDialContext,
|
||||
Proxy: getHTTPProxy,
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: Context.tlsRoots,
|
||||
},
|
||||
}
|
||||
Context.client = &http.Client{
|
||||
Timeout: time.Minute * 5,
|
||||
Transport: Context.transport,
|
||||
}
|
||||
|
||||
if !Context.firstRun {
|
||||
// Do the upgrade if necessary
|
||||
err := upgradeConfig()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = parseConfig()
|
||||
if err != nil {
|
||||
log.Error("Failed to parse configuration, exiting")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if args.checkConfig {
|
||||
log.Info("Configuration file is OK")
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
|
||||
// 'clients' module uses 'dnsfilter' module's static data (dnsfilter.BlockedSvcKnown()),
|
||||
// so we have to initialize dnsfilter's static data first,
|
||||
// but also avoid relying on automatic Go init() function
|
||||
dnsfilter.InitModule()
|
||||
|
||||
config.DHCP.WorkDir = Context.workDir
|
||||
config.DHCP.HTTPRegister = httpRegister
|
||||
config.DHCP.ConfigModified = onConfigModified
|
||||
if runtime.GOOS != "windows" {
|
||||
Context.dhcpServer = dhcpd.Create(config.DHCP)
|
||||
if Context.dhcpServer == nil {
|
||||
log.Fatalf("Can't initialize DHCP module")
|
||||
}
|
||||
}
|
||||
Context.autoHosts.Init("")
|
||||
|
||||
Context.updater = update.NewUpdater(update.Config{
|
||||
Client: Context.client,
|
||||
WorkDir: Context.workDir,
|
||||
VersionURL: versionCheckURL,
|
||||
VersionString: versionString,
|
||||
OS: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
ARMVersion: ARMVersion,
|
||||
ConfigName: config.getConfigFilename(),
|
||||
})
|
||||
|
||||
Context.clients.Init(config.Clients, Context.dhcpServer, &Context.autoHosts)
|
||||
config.Clients = nil
|
||||
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
|
||||
config.RlimitNoFile != 0 {
|
||||
util.SetRlimit(config.RlimitNoFile)
|
||||
}
|
||||
|
||||
// override bind host/port from the console
|
||||
if args.bindHost != "" {
|
||||
config.BindHost = args.bindHost
|
||||
}
|
||||
if args.bindPort != 0 {
|
||||
config.BindPort = args.bindPort
|
||||
}
|
||||
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
|
||||
Context.pidFileName = args.pidFile
|
||||
}
|
||||
|
||||
if !Context.firstRun {
|
||||
// Save the updated config
|
||||
err := config.write()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if config.DebugPProf {
|
||||
mux := http.NewServeMux()
|
||||
util.PProfRegisterWebHandlers(mux)
|
||||
go func() {
|
||||
log.Info("pprof: listening on localhost:6060")
|
||||
err := http.ListenAndServe("localhost:6060", mux)
|
||||
log.Error("Error while running the pprof server: %s", err)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
err := os.MkdirAll(Context.getDataDir(), 0755)
|
||||
if err != nil {
|
||||
log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err)
|
||||
}
|
||||
|
||||
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
|
||||
GLMode = args.glinetMode
|
||||
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
|
||||
if Context.auth == nil {
|
||||
log.Fatalf("Couldn't initialize Auth module")
|
||||
}
|
||||
config.Users = nil
|
||||
|
||||
Context.tls = tlsCreate(config.TLS)
|
||||
if Context.tls == nil {
|
||||
log.Fatalf("Can't initialize TLS module")
|
||||
}
|
||||
|
||||
webConf := WebConfig{
|
||||
firstRun: Context.firstRun,
|
||||
BindHost: config.BindHost,
|
||||
BindPort: config.BindPort,
|
||||
}
|
||||
Context.web = CreateWeb(&webConf)
|
||||
if Context.web == nil {
|
||||
log.Fatalf("Can't initialize Web module")
|
||||
}
|
||||
|
||||
if !Context.firstRun {
|
||||
err := initDNSServer()
|
||||
if err != nil {
|
||||
log.Fatalf("%s", err)
|
||||
}
|
||||
Context.tls.Start()
|
||||
Context.autoHosts.Start()
|
||||
|
||||
go func() {
|
||||
err := startDNSServer()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if Context.dhcpServer != nil {
|
||||
_ = Context.dhcpServer.Start()
|
||||
}
|
||||
}
|
||||
|
||||
Context.web.Start()
|
||||
|
||||
// wait indefinitely for other go-routines to complete their job
|
||||
select {}
|
||||
}
|
||||
|
||||
// StartMods - initialize and start DNS after installation
|
||||
func StartMods() error {
|
||||
err := initDNSServer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Context.tls.Start()
|
||||
|
||||
err = startDNSServer()
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the current user permissions are enough to run AdGuard Home
|
||||
func checkPermissions() {
|
||||
log.Info("Checking if AdGuard Home has necessary permissions")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// On Windows we need to have admin rights to run properly
|
||||
|
||||
admin, _ := util.HaveAdminRights()
|
||||
if //noinspection ALL
|
||||
admin || isdelve.Enabled {
|
||||
// Don't forget that for this to work you need to add "delve" tag explicitly
|
||||
// https://stackoverflow.com/questions/47879070/how-can-i-see-if-the-goland-debugger-is-running-in-the-program
|
||||
return
|
||||
}
|
||||
|
||||
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
|
||||
}
|
||||
|
||||
// We should check if AdGuard Home is able to bind to port 53
|
||||
ok, err := util.CanBindPort(53)
|
||||
|
||||
if ok {
|
||||
log.Info("AdGuard Home can bind to port 53")
|
||||
return
|
||||
}
|
||||
|
||||
if opErr, ok := err.(*net.OpError); ok {
|
||||
if sysErr, ok := opErr.Err.(*os.SyscallError); ok {
|
||||
if errno, ok := sysErr.Err.(syscall.Errno); ok && errno == syscall.EACCES {
|
||||
msg := `Permission check failed.
|
||||
|
||||
AdGuard Home is not allowed to bind to privileged ports (for instance, port 53).
|
||||
Please note, that this is crucial for a server to be able to use privileged ports.
|
||||
|
||||
You have two options:
|
||||
1. Run AdGuard Home with root privileges
|
||||
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
|
||||
https://github.com/AdguardTeam/AdGuardHome/internal/wiki/Getting-Started#running-without-superuser`
|
||||
|
||||
log.Fatal(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf(`AdGuard failed to bind to port 53 due to %v
|
||||
|
||||
Please note, that this is crucial for a DNS server to be able to use that port.`, err)
|
||||
|
||||
log.Info(msg)
|
||||
}
|
||||
|
||||
// Write PID to a file
|
||||
func writePIDFile(fn string) bool {
|
||||
data := fmt.Sprintf("%d", os.Getpid())
|
||||
err := ioutil.WriteFile(fn, []byte(data), 0644)
|
||||
if err != nil {
|
||||
log.Error("Couldn't write PID to file %s: %v", fn, err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func initConfigFilename(args options) {
|
||||
// config file path can be overridden by command-line arguments:
|
||||
if args.configFilename != "" {
|
||||
Context.configFilename = args.configFilename
|
||||
} else {
|
||||
// Default config file name
|
||||
Context.configFilename = "AdGuardHome.yaml"
|
||||
}
|
||||
}
|
||||
|
||||
// initWorkingDir initializes the workDir
|
||||
// if no command-line arguments specified, we use the directory where our binary file is located
|
||||
func initWorkingDir(args options) {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if args.workDir != "" {
|
||||
// If there is a custom config file, use it's directory as our working dir
|
||||
Context.workDir = args.workDir
|
||||
} else {
|
||||
Context.workDir = filepath.Dir(execPath)
|
||||
}
|
||||
}
|
||||
|
||||
// configureLogger configures logger level and output
|
||||
func configureLogger(args options) {
|
||||
ls := getLogSettings()
|
||||
|
||||
// command-line arguments can override config settings
|
||||
if args.verbose || config.Verbose {
|
||||
ls.Verbose = true
|
||||
}
|
||||
if args.logFile != "" {
|
||||
ls.LogFile = args.logFile
|
||||
} else if config.LogFile != "" {
|
||||
ls.LogFile = config.LogFile
|
||||
}
|
||||
|
||||
// Handle default log settings overrides
|
||||
ls.LogCompress = config.LogCompress
|
||||
ls.LogLocalTime = config.LogLocalTime
|
||||
ls.LogMaxBackups = config.LogMaxBackups
|
||||
ls.LogMaxSize = config.LogMaxSize
|
||||
ls.LogMaxAge = config.LogMaxAge
|
||||
|
||||
// log.SetLevel(log.INFO) - default
|
||||
if ls.Verbose {
|
||||
log.SetLevel(log.DEBUG)
|
||||
}
|
||||
|
||||
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if nothing else is configured
|
||||
// Otherwise, we'll simply loose the log output
|
||||
ls.LogFile = configSyslog
|
||||
}
|
||||
|
||||
// logs are written to stdout (default)
|
||||
if ls.LogFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if ls.LogFile == configSyslog {
|
||||
// Use syslog where it is possible and eventlog on Windows
|
||||
err := util.ConfigureSyslog(serviceName)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot initialize syslog: %s", err)
|
||||
}
|
||||
} else {
|
||||
logFilePath := filepath.Join(Context.workDir, ls.LogFile)
|
||||
if filepath.IsAbs(ls.LogFile) {
|
||||
logFilePath = ls.LogFile
|
||||
}
|
||||
|
||||
_, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot create a log file: %s", err)
|
||||
}
|
||||
|
||||
log.SetOutput(&lumberjack.Logger{
|
||||
Filename: logFilePath,
|
||||
Compress: ls.LogCompress, // disabled by default
|
||||
LocalTime: ls.LogLocalTime,
|
||||
MaxBackups: ls.LogMaxBackups,
|
||||
MaxSize: ls.LogMaxSize, // megabytes
|
||||
MaxAge: ls.LogMaxAge, //days
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func cleanup() {
|
||||
log.Info("Stopping AdGuard Home")
|
||||
|
||||
if Context.web != nil {
|
||||
Context.web.Close()
|
||||
Context.web = nil
|
||||
}
|
||||
if Context.auth != nil {
|
||||
Context.auth.Close()
|
||||
Context.auth = nil
|
||||
}
|
||||
|
||||
err := stopDNSServer()
|
||||
if err != nil {
|
||||
log.Error("Couldn't stop DNS server: %s", err)
|
||||
}
|
||||
|
||||
if Context.dhcpServer != nil {
|
||||
Context.dhcpServer.Stop()
|
||||
}
|
||||
|
||||
Context.autoHosts.Close()
|
||||
|
||||
if Context.tls != nil {
|
||||
Context.tls.Close()
|
||||
Context.tls = nil
|
||||
}
|
||||
}
|
||||
|
||||
// This function is called before application exits
|
||||
func cleanupAlways() {
|
||||
if len(Context.pidFileName) != 0 {
|
||||
_ = os.Remove(Context.pidFileName)
|
||||
}
|
||||
log.Info("Stopped")
|
||||
}
|
||||
|
||||
func exitWithError() {
|
||||
os.Exit(64)
|
||||
}
|
||||
|
||||
// loadOptions reads command line arguments and initializes configuration
|
||||
func loadOptions() options {
|
||||
o, f, err := parse(os.Args[0], os.Args[1:])
|
||||
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
_ = printHelp(os.Args[0])
|
||||
exitWithError()
|
||||
} else if f != nil {
|
||||
err = f()
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
exitWithError()
|
||||
} else {
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
||||
|
||||
// prints IP addresses which user can use to open the admin interface
|
||||
// proto is either "http" or "https"
|
||||
func printHTTPAddresses(proto string) {
|
||||
var address string
|
||||
|
||||
tlsConf := tlsConfigSettings{}
|
||||
if Context.tls != nil {
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
}
|
||||
|
||||
port := strconv.Itoa(config.BindPort)
|
||||
if proto == "https" {
|
||||
port = strconv.Itoa(tlsConf.PortHTTPS)
|
||||
}
|
||||
|
||||
if proto == "https" && tlsConf.ServerName != "" {
|
||||
if tlsConf.PortHTTPS == 443 {
|
||||
log.Printf("Go to https://%s", tlsConf.ServerName)
|
||||
} else {
|
||||
log.Printf("Go to https://%s:%s", tlsConf.ServerName, port)
|
||||
}
|
||||
} else if config.BindHost == "0.0.0.0" {
|
||||
log.Println("AdGuard Home is available on the following addresses:")
|
||||
ifaces, err := util.GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
// That's weird, but we'll ignore it
|
||||
address = net.JoinHostPort(config.BindHost, port)
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
return
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
for _, addr := range iface.Addresses {
|
||||
address = net.JoinHostPort(addr, strconv.Itoa(config.BindPort))
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
address = net.JoinHostPort(config.BindHost, port)
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------
|
||||
// first run / install
|
||||
// -------------------
|
||||
func detectFirstRun() bool {
|
||||
configfile := Context.configFilename
|
||||
if !filepath.IsAbs(configfile) {
|
||||
configfile = filepath.Join(Context.workDir, Context.configFilename)
|
||||
}
|
||||
_, err := os.Stat(configfile)
|
||||
if !os.IsNotExist(err) {
|
||||
// do nothing, file exists
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Connect to a remote server resolving hostname using our own DNS server
|
||||
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log.Tracef("network:%v addr:%v", network, addr)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Minute * 5,
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
return con, err
|
||||
}
|
||||
|
||||
addrs, e := Context.dnsServer.Resolve(host)
|
||||
log.Debug("dnsServer.Resolve: %s: %v", host, addrs)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("couldn't lookup host: %s", host)
|
||||
}
|
||||
|
||||
var dialErrs []error
|
||||
for _, a := range addrs {
|
||||
addr = net.JoinHostPort(a.String(), port)
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
dialErrs = append(dialErrs, err)
|
||||
continue
|
||||
}
|
||||
return con, err
|
||||
}
|
||||
return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
|
||||
}
|
||||
|
||||
func getHTTPProxy(req *http.Request) (*url.URL, error) {
|
||||
if len(config.ProxyURL) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return url.Parse(config.ProxyURL)
|
||||
}
|
||||
189
internal/home/home_test.go
Normal file
189
internal/home/home_test.go
Normal file
@@ -0,0 +1,189 @@
|
||||
// +build !race
|
||||
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxyutil"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const yamlConf = `bind_host: 127.0.0.1
|
||||
bind_port: 3000
|
||||
users: []
|
||||
language: en
|
||||
rlimit_nofile: 0
|
||||
web_session_ttl: 720
|
||||
dns:
|
||||
bind_host: 127.0.0.1
|
||||
port: 5354
|
||||
statistics_interval: 90
|
||||
querylog_enabled: true
|
||||
querylog_interval: 90
|
||||
querylog_size_memory: 0
|
||||
protection_enabled: true
|
||||
blocking_mode: null_ip
|
||||
blocked_response_ttl: 0
|
||||
ratelimit: 100
|
||||
ratelimit_whitelist: []
|
||||
refuse_any: false
|
||||
bootstrap_dns:
|
||||
- 1.1.1.1:53
|
||||
all_servers: false
|
||||
allowed_clients: []
|
||||
disallowed_clients: []
|
||||
blocked_hosts: []
|
||||
parental_block_host: family-block.dns.adguard.com
|
||||
safebrowsing_block_host: standard-block.dns.adguard.com
|
||||
cache_size: 0
|
||||
upstream_dns:
|
||||
- https://1.1.1.1/dns-query
|
||||
filtering_enabled: true
|
||||
filters_update_interval: 168
|
||||
parental_sensitivity: 13
|
||||
parental_enabled: true
|
||||
safesearch_enabled: false
|
||||
safebrowsing_enabled: false
|
||||
safebrowsing_cache_size: 1048576
|
||||
safesearch_cache_size: 1048576
|
||||
parental_cache_size: 1048576
|
||||
cache_time: 30
|
||||
rewrites: []
|
||||
blocked_services: []
|
||||
tls:
|
||||
enabled: false
|
||||
server_name: www.example.com
|
||||
force_https: false
|
||||
port_https: 443
|
||||
port_dns_over_tls: 853
|
||||
allow_unencrypted_doh: true
|
||||
certificate_chain: ""
|
||||
private_key: ""
|
||||
certificate_path: ""
|
||||
private_key_path: ""
|
||||
filters:
|
||||
- enabled: true
|
||||
url: https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt
|
||||
name: AdGuard Simplified Domain Names filter
|
||||
id: 1
|
||||
- enabled: false
|
||||
url: https://hosts-file.net/ad_servers.txt
|
||||
name: hpHosts - Ad and Tracking servers only
|
||||
id: 2
|
||||
- enabled: false
|
||||
url: https://adaway.org/hosts.txt
|
||||
name: adaway
|
||||
id: 3
|
||||
user_rules:
|
||||
- ""
|
||||
dhcp:
|
||||
enabled: false
|
||||
interface_name: ""
|
||||
gateway_ip: ""
|
||||
subnet_mask: ""
|
||||
range_start: ""
|
||||
range_end: ""
|
||||
lease_duration: 86400
|
||||
icmp_timeout_msec: 1000
|
||||
clients: []
|
||||
log_file: ""
|
||||
verbose: false
|
||||
schema_version: 5
|
||||
`
|
||||
|
||||
// . Create a configuration file
|
||||
// . Start AGH instance
|
||||
// . Check Web server
|
||||
// . Check DNS server
|
||||
// . Check DNS server with DOH
|
||||
// . Wait until the filters are downloaded
|
||||
// . Stop and cleanup
|
||||
func TestHome(t *testing.T) {
|
||||
// Init new context
|
||||
Context = homeContext{}
|
||||
|
||||
dir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
fn := filepath.Join(dir, "AdGuardHome.yaml")
|
||||
|
||||
// Prepare the test config
|
||||
assert.True(t, ioutil.WriteFile(fn, []byte(yamlConf), 0644) == nil)
|
||||
fn, _ = filepath.Abs(fn)
|
||||
|
||||
config = configuration{} // the global variable is dirty because of the previous tests run
|
||||
args := options{}
|
||||
args.configFilename = fn
|
||||
args.workDir = dir
|
||||
go run(args)
|
||||
|
||||
var err error
|
||||
var resp *http.Response
|
||||
h := http.Client{}
|
||||
for i := 0; i != 50; i++ {
|
||||
resp, err = h.Get("http://127.0.0.1:3000/")
|
||||
if err == nil && resp.StatusCode != 404 {
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
assert.Truef(t, err == nil, "%s", err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
resp, err = h.Get("http://127.0.0.1:3000/control/status")
|
||||
assert.Truef(t, err == nil, "%s", err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// test DNS over UDP
|
||||
r, err := upstream.NewResolver("127.0.0.1:5354", 3*time.Second)
|
||||
assert.Nil(t, err)
|
||||
addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com")
|
||||
assert.Nil(t, err)
|
||||
haveIP := len(addrs) != 0
|
||||
assert.True(t, haveIP)
|
||||
|
||||
// test DNS over HTTP without encryption
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{{Name: "static.adguard.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}
|
||||
buf, err := req.Pack()
|
||||
assert.True(t, err == nil, "%s", err)
|
||||
requestURL := "http://127.0.0.1:3000/dns-query?dns=" + base64.RawURLEncoding.EncodeToString(buf)
|
||||
resp, err = http.DefaultClient.Get(requestURL)
|
||||
assert.True(t, err == nil, "%s", err)
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
assert.True(t, err == nil, "%s", err)
|
||||
assert.True(t, resp.StatusCode == http.StatusOK)
|
||||
response := dns.Msg{}
|
||||
err = response.Unpack(body)
|
||||
assert.True(t, err == nil, "%s", err)
|
||||
addrs = nil
|
||||
proxyutil.AppendIPAddrs(&addrs, response.Answer)
|
||||
haveIP = len(addrs) != 0
|
||||
assert.True(t, haveIP)
|
||||
|
||||
for i := 1; ; i++ {
|
||||
st, err := os.Stat(filepath.Join(dir, "data", "filters", "1.txt"))
|
||||
if err == nil && st.Size() != 0 {
|
||||
break
|
||||
}
|
||||
if i == 5 {
|
||||
assert.True(t, false)
|
||||
break
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
cleanup()
|
||||
cleanupAlways()
|
||||
}
|
||||
94
internal/home/i18n.go
Normal file
94
internal/home/i18n.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// --------------------
|
||||
// internationalization
|
||||
// --------------------
|
||||
var allowedLanguages = map[string]bool{
|
||||
"be": true,
|
||||
"bg": true,
|
||||
"cs": true,
|
||||
"da": true,
|
||||
"de": true,
|
||||
"en": true,
|
||||
"es": true,
|
||||
"fa": true,
|
||||
"fr": true,
|
||||
"hr": true,
|
||||
"hu": true,
|
||||
"id": true,
|
||||
"it": true,
|
||||
"ja": true,
|
||||
"ko": true,
|
||||
"nl": true,
|
||||
"no": true,
|
||||
"pl": true,
|
||||
"pt-br": true,
|
||||
"pt-pt": true,
|
||||
"ro": true,
|
||||
"ru": true,
|
||||
"si-lk": true,
|
||||
"sk": true,
|
||||
"sl": true,
|
||||
"sr-cs": true,
|
||||
"sv": true,
|
||||
"th": true,
|
||||
"tr": true,
|
||||
"vi": true,
|
||||
"zh-cn": true,
|
||||
"zh-hk": true,
|
||||
"zh-tw": true,
|
||||
}
|
||||
|
||||
func isLanguageAllowed(language string) bool {
|
||||
l := strings.ToLower(language)
|
||||
return allowedLanguages[l]
|
||||
}
|
||||
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
log.Printf("config.Language is %s", config.Language)
|
||||
_, err := fmt.Fprintf(w, "%s\n", config.Language)
|
||||
if err != nil {
|
||||
errorText := fmt.Sprintf("Unable to write response json: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errorText := fmt.Sprintf("failed to read request body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
language := strings.TrimSpace(string(body))
|
||||
if language == "" {
|
||||
errorText := fmt.Sprintf("empty language specified")
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !isLanguageAllowed(language) {
|
||||
errorText := fmt.Sprintf("unknown language specified: %s", language)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
config.Language = language
|
||||
onConfigModified()
|
||||
returnOK(w)
|
||||
}
|
||||
44
internal/home/memory.go
Normal file
44
internal/home/memory.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// memoryUsage implements a couple of not really beautiful hacks which purpose is to
|
||||
// make OS reclaim the memory freed by AdGuard Home as soon as possible.
|
||||
// See this for the details on the performance hits & gains:
|
||||
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/2044#issuecomment-687042211
|
||||
func memoryUsage(args options) {
|
||||
if args.disableMemoryOptimization {
|
||||
log.Info("Memory optimization is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Makes Go allocate heap at a slower pace
|
||||
// By default we keep it at 50%
|
||||
debug.SetGCPercent(50)
|
||||
|
||||
// madvdontneed: setting madvdontneed=1 will use MADV_DONTNEED
|
||||
// instead of MADV_FREE on Linux when returning memory to the
|
||||
// kernel. This is less efficient, but causes RSS numbers to drop
|
||||
// more quickly.
|
||||
_ = os.Setenv("GODEBUG", "madvdontneed=1")
|
||||
|
||||
// periodically call "debug.FreeOSMemory" so
|
||||
// that the OS could reclaim the free memory
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for {
|
||||
select {
|
||||
case t := <-ticker.C:
|
||||
t.Second()
|
||||
log.Debug("Free OS memory")
|
||||
debug.FreeOSMemory()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
113
internal/home/mobileconfig.go
Normal file
113
internal/home/mobileconfig.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"howett.net/plist"
|
||||
)
|
||||
|
||||
type DNSSettings struct {
|
||||
DNSProtocol string
|
||||
ServerURL string `plist:",omitempty"`
|
||||
ServerName string `plist:",omitempty"`
|
||||
}
|
||||
|
||||
type PayloadContent struct {
|
||||
Name string
|
||||
PayloadDescription string
|
||||
PayloadDisplayName string
|
||||
PayloadIdentifier string
|
||||
PayloadType string
|
||||
PayloadUUID string
|
||||
PayloadVersion int
|
||||
DNSSettings DNSSettings
|
||||
}
|
||||
|
||||
type MobileConfig struct {
|
||||
PayloadContent []PayloadContent
|
||||
PayloadDescription string
|
||||
PayloadDisplayName string
|
||||
PayloadIdentifier string
|
||||
PayloadRemovalDisallowed bool
|
||||
PayloadType string
|
||||
PayloadUUID string
|
||||
PayloadVersion int
|
||||
}
|
||||
|
||||
func genUUIDv4() string {
|
||||
return uuid.NewV4().String()
|
||||
}
|
||||
|
||||
const (
|
||||
dnsProtoHTTPS = "HTTPS"
|
||||
dnsProtoTLS = "TLS"
|
||||
)
|
||||
|
||||
func getMobileConfig(d DNSSettings) ([]byte, error) {
|
||||
var name string
|
||||
switch d.DNSProtocol {
|
||||
case dnsProtoHTTPS:
|
||||
name = fmt.Sprintf("%s DoH", d.ServerName)
|
||||
case dnsProtoTLS:
|
||||
name = fmt.Sprintf("%s DoT", d.ServerName)
|
||||
default:
|
||||
return nil, fmt.Errorf("bad dns protocol %q", d.DNSProtocol)
|
||||
}
|
||||
|
||||
data := MobileConfig{
|
||||
PayloadContent: []PayloadContent{{
|
||||
Name: name,
|
||||
PayloadDescription: "Configures device to use AdGuard Home",
|
||||
PayloadDisplayName: name,
|
||||
PayloadIdentifier: fmt.Sprintf("com.apple.dnsSettings.managed.%s", genUUIDv4()),
|
||||
PayloadType: "com.apple.dnsSettings.managed",
|
||||
PayloadUUID: genUUIDv4(),
|
||||
PayloadVersion: 1,
|
||||
DNSSettings: d,
|
||||
}},
|
||||
PayloadDescription: "Adds AdGuard Home to Big Sur and iOS 14 or newer systems",
|
||||
PayloadDisplayName: name,
|
||||
PayloadIdentifier: genUUIDv4(),
|
||||
PayloadRemovalDisallowed: false,
|
||||
PayloadType: "Configuration",
|
||||
PayloadUUID: genUUIDv4(),
|
||||
PayloadVersion: 1,
|
||||
}
|
||||
|
||||
return plist.MarshalIndent(data, plist.XMLFormat, "\t")
|
||||
}
|
||||
|
||||
func handleMobileConfig(w http.ResponseWriter, d DNSSettings) {
|
||||
mobileconfig, err := getMobileConfig(d)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "plist.MarshalIndent: %s", err)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
_, _ = w.Write(mobileconfig)
|
||||
}
|
||||
|
||||
func handleMobileConfigDoh(w http.ResponseWriter, r *http.Request) {
|
||||
handleMobileConfig(w, DNSSettings{
|
||||
DNSProtocol: dnsProtoHTTPS,
|
||||
ServerURL: fmt.Sprintf("https://%s/dns-query", r.Host),
|
||||
})
|
||||
}
|
||||
|
||||
func handleMobileConfigDot(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
|
||||
var host string
|
||||
host, _, err = net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "getting host: %s", err)
|
||||
}
|
||||
|
||||
handleMobileConfig(w, DNSSettings{
|
||||
DNSProtocol: dnsProtoTLS,
|
||||
ServerName: host,
|
||||
})
|
||||
}
|
||||
33
internal/home/mobileconfig_test.go
Normal file
33
internal/home/mobileconfig_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"howett.net/plist"
|
||||
)
|
||||
|
||||
func TestHandleMobileConfigDot(t *testing.T) {
|
||||
var err error
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handleMobileConfigDot(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var mc MobileConfig
|
||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if assert.Equal(t, 1, len(mc.PayloadContent)) {
|
||||
assert.Equal(t, "example.com DoT", mc.PayloadContent[0].Name)
|
||||
assert.Equal(t, "example.com DoT", mc.PayloadContent[0].PayloadDisplayName)
|
||||
assert.Equal(t, "example.com", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||
}
|
||||
}
|
||||
325
internal/home/options.go
Normal file
325
internal/home/options.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// options passed from command-line arguments
|
||||
type options struct {
|
||||
verbose bool // is verbose logging enabled
|
||||
configFilename string // path to the config file
|
||||
workDir string // path to the working directory where we will store the filters data and the querylog
|
||||
bindHost string // host address to bind HTTP server on
|
||||
bindPort int // port to serve HTTP pages on
|
||||
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
|
||||
pidFile string // File name to save PID to
|
||||
checkConfig bool // Check configuration and exit
|
||||
disableUpdate bool // If set, don't check for updates
|
||||
|
||||
// service control action (see service.ControlAction array + "status" command)
|
||||
serviceControlAction string
|
||||
|
||||
// runningAsService flag is set to true when options are passed from the service runner
|
||||
runningAsService bool
|
||||
|
||||
// disableMemoryOptimization - disables memory optimization hacks
|
||||
// see memoryUsage() function for the details
|
||||
disableMemoryOptimization bool
|
||||
|
||||
glinetMode bool // Activate GL-Inet compatibility mode
|
||||
}
|
||||
|
||||
// functions used for their side-effects
|
||||
type effect func() error
|
||||
|
||||
type arg struct {
|
||||
description string // a short, English description of the argument
|
||||
longName string // the name of the argument used after '--'
|
||||
shortName string // the name of the argument used after '-'
|
||||
|
||||
// only one of updateWithValue, updateNoValue, and effect should be present
|
||||
|
||||
updateWithValue func(o options, v string) (options, error) // the mutator for arguments with parameters
|
||||
updateNoValue func(o options) (options, error) // the mutator for arguments without parameters
|
||||
effect func(o options, exec string) (f effect, err error) // the side-effect closure generator
|
||||
|
||||
serialize func(o options) []string // the re-serialization function back to arguments (return nil for omit)
|
||||
}
|
||||
|
||||
// {type}SliceOrNil functions check their parameter of type {type}
|
||||
// against its zero value and return nil if the parameter value is
|
||||
// zero otherwise they return a string slice of the parameter
|
||||
|
||||
func stringSliceOrNil(s string) []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{s}
|
||||
}
|
||||
|
||||
func intSliceOrNil(i int) []string {
|
||||
if i == 0 {
|
||||
return nil
|
||||
}
|
||||
return []string{strconv.Itoa(i)}
|
||||
}
|
||||
|
||||
func boolSliceOrNil(b bool) []string {
|
||||
if b {
|
||||
return []string{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var args []arg
|
||||
|
||||
var configArg = arg{
|
||||
"Path to the config file",
|
||||
"config", "c",
|
||||
func(o options, v string) (options, error) { o.configFilename = v; return o, nil },
|
||||
nil,
|
||||
nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.configFilename) },
|
||||
}
|
||||
|
||||
var workDirArg = arg{
|
||||
"Path to the working directory",
|
||||
"work-dir", "w",
|
||||
func(o options, v string) (options, error) { o.workDir = v; return o, nil }, nil, nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.workDir) },
|
||||
}
|
||||
|
||||
var hostArg = arg{
|
||||
"Host address to bind HTTP server on",
|
||||
"host", "h",
|
||||
func(o options, v string) (options, error) { o.bindHost = v; return o, nil }, nil, nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.bindHost) },
|
||||
}
|
||||
|
||||
var portArg = arg{
|
||||
"Port to serve HTTP pages on",
|
||||
"port", "p",
|
||||
func(o options, v string) (options, error) {
|
||||
var err error
|
||||
var p int
|
||||
minPort, maxPort := 0, 1<<16-1
|
||||
if p, err = strconv.Atoi(v); err != nil {
|
||||
err = fmt.Errorf("port '%s' is not a number", v)
|
||||
} else if p < minPort || p > maxPort {
|
||||
err = fmt.Errorf("port %d not in range %d - %d", p, minPort, maxPort)
|
||||
} else {
|
||||
o.bindPort = p
|
||||
}
|
||||
return o, err
|
||||
}, nil, nil,
|
||||
func(o options) []string { return intSliceOrNil(o.bindPort) },
|
||||
}
|
||||
|
||||
var serviceArg = arg{
|
||||
"Service control action: status, install, uninstall, start, stop, restart, reload (configuration)",
|
||||
"service", "s",
|
||||
func(o options, v string) (options, error) {
|
||||
o.serviceControlAction = v
|
||||
return o, nil
|
||||
}, nil, nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.serviceControlAction) },
|
||||
}
|
||||
|
||||
var logfileArg = arg{
|
||||
"Path to log file. If empty: write to stdout; if 'syslog': write to system log",
|
||||
"logfile", "l",
|
||||
func(o options, v string) (options, error) { o.logFile = v; return o, nil }, nil, nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.logFile) },
|
||||
}
|
||||
|
||||
var pidfileArg = arg{
|
||||
"Path to a file where PID is stored",
|
||||
"pidfile", "",
|
||||
func(o options, v string) (options, error) { o.pidFile = v; return o, nil }, nil, nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.pidFile) },
|
||||
}
|
||||
|
||||
var checkConfigArg = arg{
|
||||
"Check configuration and exit",
|
||||
"check-config", "",
|
||||
nil, func(o options) (options, error) { o.checkConfig = true; return o, nil }, nil,
|
||||
func(o options) []string { return boolSliceOrNil(o.checkConfig) },
|
||||
}
|
||||
|
||||
var noCheckUpdateArg = arg{
|
||||
"Don't check for updates",
|
||||
"no-check-update", "",
|
||||
nil, func(o options) (options, error) { o.disableUpdate = true; return o, nil }, nil,
|
||||
func(o options) []string { return boolSliceOrNil(o.disableUpdate) },
|
||||
}
|
||||
|
||||
var disableMemoryOptimizationArg = arg{
|
||||
"Disable memory optimization",
|
||||
"no-mem-optimization", "",
|
||||
nil, func(o options) (options, error) { o.disableMemoryOptimization = true; return o, nil }, nil,
|
||||
func(o options) []string { return boolSliceOrNil(o.disableMemoryOptimization) },
|
||||
}
|
||||
|
||||
var verboseArg = arg{
|
||||
"Enable verbose output",
|
||||
"verbose", "v",
|
||||
nil, func(o options) (options, error) { o.verbose = true; return o, nil }, nil,
|
||||
func(o options) []string { return boolSliceOrNil(o.verbose) },
|
||||
}
|
||||
|
||||
var glinetArg = arg{
|
||||
"Run in GL-Inet compatibility mode",
|
||||
"glinet", "",
|
||||
nil, func(o options) (options, error) { o.glinetMode = true; return o, nil }, nil,
|
||||
func(o options) []string { return boolSliceOrNil(o.glinetMode) },
|
||||
}
|
||||
|
||||
var versionArg = arg{
|
||||
"Show the version and exit",
|
||||
"version", "",
|
||||
nil, nil, func(o options, exec string) (effect, error) {
|
||||
return func() error { fmt.Println(version()); os.Exit(0); return nil }, nil
|
||||
},
|
||||
func(o options) []string { return nil },
|
||||
}
|
||||
|
||||
var helpArg = arg{
|
||||
"Print this help",
|
||||
"help", "",
|
||||
nil, nil, func(o options, exec string) (effect, error) {
|
||||
return func() error { _ = printHelp(exec); os.Exit(64); return nil }, nil
|
||||
},
|
||||
func(o options) []string { return nil },
|
||||
}
|
||||
|
||||
func init() {
|
||||
args = []arg{
|
||||
configArg,
|
||||
workDirArg,
|
||||
hostArg,
|
||||
portArg,
|
||||
serviceArg,
|
||||
logfileArg,
|
||||
pidfileArg,
|
||||
checkConfigArg,
|
||||
noCheckUpdateArg,
|
||||
disableMemoryOptimizationArg,
|
||||
verboseArg,
|
||||
glinetArg,
|
||||
versionArg,
|
||||
helpArg,
|
||||
}
|
||||
}
|
||||
|
||||
func getUsageLines(exec string, args []arg) []string {
|
||||
usage := []string{
|
||||
"Usage:",
|
||||
"",
|
||||
fmt.Sprintf("%s [options]", exec),
|
||||
"",
|
||||
"Options:",
|
||||
}
|
||||
for _, arg := range args {
|
||||
val := ""
|
||||
if arg.updateWithValue != nil {
|
||||
val = " VALUE"
|
||||
}
|
||||
if arg.shortName != "" {
|
||||
usage = append(usage, fmt.Sprintf(" -%s, %-30s %s",
|
||||
arg.shortName,
|
||||
"--"+arg.longName+val,
|
||||
arg.description))
|
||||
} else {
|
||||
usage = append(usage, fmt.Sprintf(" %-34s %s",
|
||||
"--"+arg.longName+val,
|
||||
arg.description))
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func printHelp(exec string) error {
|
||||
for _, line := range getUsageLines(exec, args) {
|
||||
_, err := fmt.Println(line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func argMatches(a arg, v string) bool {
|
||||
return v == "--"+a.longName || (a.shortName != "" && v == "-"+a.shortName)
|
||||
}
|
||||
|
||||
func parse(exec string, ss []string) (o options, f effect, err error) {
|
||||
for i := 0; i < len(ss); i++ {
|
||||
v := ss[i]
|
||||
knownParam := false
|
||||
for _, arg := range args {
|
||||
if argMatches(arg, v) {
|
||||
if arg.updateWithValue != nil {
|
||||
if i+1 >= len(ss) {
|
||||
return o, f, fmt.Errorf("got %s without argument", v)
|
||||
}
|
||||
i++
|
||||
o, err = arg.updateWithValue(o, ss[i])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if arg.updateNoValue != nil {
|
||||
o, err = arg.updateNoValue(o)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if arg.effect != nil {
|
||||
var eff effect
|
||||
eff, err = arg.effect(o, exec)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if eff != nil {
|
||||
prevf := f
|
||||
f = func() error {
|
||||
var err error
|
||||
if prevf != nil {
|
||||
err = prevf()
|
||||
}
|
||||
if err == nil {
|
||||
err = eff()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
knownParam = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !knownParam {
|
||||
return o, f, fmt.Errorf("unknown option %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func shortestFlag(a arg) string {
|
||||
if a.shortName != "" {
|
||||
return "-" + a.shortName
|
||||
}
|
||||
return "--" + a.longName
|
||||
}
|
||||
|
||||
func serialize(o options) []string {
|
||||
ss := []string{}
|
||||
for _, arg := range args {
|
||||
s := arg.serialize(o)
|
||||
if s != nil {
|
||||
ss = append(ss, append([]string{shortestFlag(arg)}, s...)...)
|
||||
}
|
||||
}
|
||||
return ss
|
||||
}
|
||||
251
internal/home/options_test.go
Normal file
251
internal/home/options_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testParseOk(t *testing.T, ss ...string) options {
|
||||
o, _, err := parse("", ss)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
func testParseErr(t *testing.T, descr string, ss ...string) {
|
||||
_, _, err := parse("", ss)
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error because %s but no error returned", descr)
|
||||
}
|
||||
}
|
||||
|
||||
func testParseParamMissing(t *testing.T, param string) {
|
||||
testParseErr(t, fmt.Sprintf("%s parameter missing", param), param)
|
||||
}
|
||||
|
||||
func TestParseVerbose(t *testing.T) {
|
||||
if testParseOk(t).verbose {
|
||||
t.Fatal("empty is not verbose")
|
||||
}
|
||||
if !testParseOk(t, "-v").verbose {
|
||||
t.Fatal("-v is verbose")
|
||||
}
|
||||
if !testParseOk(t, "--verbose").verbose {
|
||||
t.Fatal("--verbose is verbose")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigFilename(t *testing.T) {
|
||||
if testParseOk(t).configFilename != "" {
|
||||
t.Fatal("empty is no config filename")
|
||||
}
|
||||
if testParseOk(t, "-c", "path").configFilename != "path" {
|
||||
t.Fatal("-c is config filename")
|
||||
}
|
||||
testParseParamMissing(t, "-c")
|
||||
if testParseOk(t, "--config", "path").configFilename != "path" {
|
||||
t.Fatal("--configFilename is config filename")
|
||||
}
|
||||
testParseParamMissing(t, "--config")
|
||||
}
|
||||
|
||||
func TestParseWorkDir(t *testing.T) {
|
||||
if testParseOk(t).workDir != "" {
|
||||
t.Fatal("empty is no work dir")
|
||||
}
|
||||
if testParseOk(t, "-w", "path").workDir != "path" {
|
||||
t.Fatal("-w is work dir")
|
||||
}
|
||||
testParseParamMissing(t, "-w")
|
||||
if testParseOk(t, "--work-dir", "path").workDir != "path" {
|
||||
t.Fatal("--work-dir is work dir")
|
||||
}
|
||||
testParseParamMissing(t, "--work-dir")
|
||||
}
|
||||
|
||||
func TestParseBindHost(t *testing.T) {
|
||||
if testParseOk(t).bindHost != "" {
|
||||
t.Fatal("empty is no host")
|
||||
}
|
||||
if testParseOk(t, "-h", "addr").bindHost != "addr" {
|
||||
t.Fatal("-h is host")
|
||||
}
|
||||
testParseParamMissing(t, "-h")
|
||||
if testParseOk(t, "--host", "addr").bindHost != "addr" {
|
||||
t.Fatal("--host is host")
|
||||
}
|
||||
testParseParamMissing(t, "--host")
|
||||
}
|
||||
|
||||
func TestParseBindPort(t *testing.T) {
|
||||
if testParseOk(t).bindPort != 0 {
|
||||
t.Fatal("empty is port 0")
|
||||
}
|
||||
if testParseOk(t, "-p", "65535").bindPort != 65535 {
|
||||
t.Fatal("-p is port")
|
||||
}
|
||||
testParseParamMissing(t, "-p")
|
||||
if testParseOk(t, "--port", "65535").bindPort != 65535 {
|
||||
t.Fatal("--port is port")
|
||||
}
|
||||
testParseParamMissing(t, "--port")
|
||||
}
|
||||
|
||||
func TestParseBindPortBad(t *testing.T) {
|
||||
testParseErr(t, "not an int", "-p", "x")
|
||||
testParseErr(t, "hex not supported", "-p", "0x100")
|
||||
testParseErr(t, "port negative", "-p", "-1")
|
||||
testParseErr(t, "port too high", "-p", "65536")
|
||||
testParseErr(t, "port too high", "-p", "4294967297") // 2^32 + 1
|
||||
testParseErr(t, "port too high", "-p", "18446744073709551617") // 2^64 + 1
|
||||
}
|
||||
|
||||
func TestParseLogfile(t *testing.T) {
|
||||
if testParseOk(t).logFile != "" {
|
||||
t.Fatal("empty is no log file")
|
||||
}
|
||||
if testParseOk(t, "-l", "path").logFile != "path" {
|
||||
t.Fatal("-l is log file")
|
||||
}
|
||||
if testParseOk(t, "--logfile", "path").logFile != "path" {
|
||||
t.Fatal("--logfile is log file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePidfile(t *testing.T) {
|
||||
if testParseOk(t).pidFile != "" {
|
||||
t.Fatal("empty is no pid file")
|
||||
}
|
||||
if testParseOk(t, "--pidfile", "path").pidFile != "path" {
|
||||
t.Fatal("--pidfile is pid file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCheckConfig(t *testing.T) {
|
||||
if testParseOk(t).checkConfig {
|
||||
t.Fatal("empty is not check config")
|
||||
}
|
||||
if !testParseOk(t, "--check-config").checkConfig {
|
||||
t.Fatal("--check-config is check config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDisableUpdate(t *testing.T) {
|
||||
if testParseOk(t).disableUpdate {
|
||||
t.Fatal("empty is not disable update")
|
||||
}
|
||||
if !testParseOk(t, "--no-check-update").disableUpdate {
|
||||
t.Fatal("--no-check-update is disable update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDisableMemoryOptimization(t *testing.T) {
|
||||
if testParseOk(t).disableMemoryOptimization {
|
||||
t.Fatal("empty is not disable update")
|
||||
}
|
||||
if !testParseOk(t, "--no-mem-optimization").disableMemoryOptimization {
|
||||
t.Fatal("--no-mem-optimization is disable update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseService(t *testing.T) {
|
||||
if testParseOk(t).serviceControlAction != "" {
|
||||
t.Fatal("empty is no service command")
|
||||
}
|
||||
if testParseOk(t, "-s", "command").serviceControlAction != "command" {
|
||||
t.Fatal("-s is service command")
|
||||
}
|
||||
if testParseOk(t, "--service", "command").serviceControlAction != "command" {
|
||||
t.Fatal("--service is service command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseGLInet(t *testing.T) {
|
||||
if testParseOk(t).glinetMode {
|
||||
t.Fatal("empty is not GL-Inet mode")
|
||||
}
|
||||
if !testParseOk(t, "--glinet").glinetMode {
|
||||
t.Fatal("--glinet is GL-Inet mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUnknown(t *testing.T) {
|
||||
testParseErr(t, "unknown word", "x")
|
||||
testParseErr(t, "unknown short", "-x")
|
||||
testParseErr(t, "unknown long", "--x")
|
||||
testParseErr(t, "unknown triple", "---x")
|
||||
testParseErr(t, "unknown plus", "+x")
|
||||
testParseErr(t, "unknown dash", "-")
|
||||
}
|
||||
|
||||
func testSerialize(t *testing.T, o options, ss ...string) {
|
||||
result := serialize(o)
|
||||
if len(result) != len(ss) {
|
||||
t.Fatalf("expected %s but got %s", ss, result)
|
||||
}
|
||||
for i, r := range result {
|
||||
if r != ss[i] {
|
||||
t.Fatalf("expected %s but got %s", ss, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeEmpty(t *testing.T) {
|
||||
testSerialize(t, options{})
|
||||
}
|
||||
|
||||
func TestSerializeConfigFilename(t *testing.T) {
|
||||
testSerialize(t, options{configFilename: "path"}, "-c", "path")
|
||||
}
|
||||
|
||||
func TestSerializeWorkDir(t *testing.T) {
|
||||
testSerialize(t, options{workDir: "path"}, "-w", "path")
|
||||
}
|
||||
|
||||
func TestSerializeBindHost(t *testing.T) {
|
||||
testSerialize(t, options{bindHost: "addr"}, "-h", "addr")
|
||||
}
|
||||
|
||||
func TestSerializeBindPort(t *testing.T) {
|
||||
testSerialize(t, options{bindPort: 666}, "-p", "666")
|
||||
}
|
||||
|
||||
func TestSerializeLogfile(t *testing.T) {
|
||||
testSerialize(t, options{logFile: "path"}, "-l", "path")
|
||||
}
|
||||
|
||||
func TestSerializePidfile(t *testing.T) {
|
||||
testSerialize(t, options{pidFile: "path"}, "--pidfile", "path")
|
||||
}
|
||||
|
||||
func TestSerializeCheckConfig(t *testing.T) {
|
||||
testSerialize(t, options{checkConfig: true}, "--check-config")
|
||||
}
|
||||
|
||||
func TestSerializeDisableUpdate(t *testing.T) {
|
||||
testSerialize(t, options{disableUpdate: true}, "--no-check-update")
|
||||
}
|
||||
|
||||
func TestSerializeService(t *testing.T) {
|
||||
testSerialize(t, options{serviceControlAction: "run"}, "-s", "run")
|
||||
}
|
||||
|
||||
func TestSerializeGLInet(t *testing.T) {
|
||||
testSerialize(t, options{glinetMode: true}, "--glinet")
|
||||
}
|
||||
|
||||
func TestSerializeDisableMemoryOptimization(t *testing.T) {
|
||||
testSerialize(t, options{disableMemoryOptimization: true}, "--no-mem-optimization")
|
||||
}
|
||||
|
||||
func TestSerializeMultiple(t *testing.T) {
|
||||
testSerialize(t, options{
|
||||
serviceControlAction: "run",
|
||||
configFilename: "config",
|
||||
workDir: "work",
|
||||
pidFile: "pid",
|
||||
disableUpdate: true,
|
||||
disableMemoryOptimization: true,
|
||||
}, "-c", "config", "-w", "work", "-s", "run", "--pidfile", "pid", "--no-check-update", "--no-mem-optimization")
|
||||
}
|
||||
129
internal/home/rdns.go
Normal file
129
internal/home/rdns.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// RDNS - module context
|
||||
type RDNS struct {
|
||||
dnsServer *dnsforward.Server
|
||||
clients *clientsContainer
|
||||
ipChannel chan string // pass data from DNS request handling thread to rDNS thread
|
||||
|
||||
// Contains IP addresses of clients to be resolved by rDNS
|
||||
// If IP address is resolved, it stays here while it's inside Clients.
|
||||
// If it's removed from Clients, this IP address will be resolved once again.
|
||||
// If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP.
|
||||
ipAddrs cache.Cache
|
||||
}
|
||||
|
||||
// InitRDNS - create module context
|
||||
func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS {
|
||||
r := RDNS{}
|
||||
r.dnsServer = dnsServer
|
||||
r.clients = clients
|
||||
|
||||
cconf := cache.Config{}
|
||||
cconf.EnableLRU = true
|
||||
cconf.MaxCount = 10000
|
||||
r.ipAddrs = cache.New(cconf)
|
||||
|
||||
r.ipChannel = make(chan string, 256)
|
||||
go r.workerLoop()
|
||||
return &r
|
||||
}
|
||||
|
||||
// Begin - add IP address to rDNS queue
|
||||
func (r *RDNS) Begin(ip string) {
|
||||
now := uint64(time.Now().Unix())
|
||||
expire := r.ipAddrs.Get([]byte(ip))
|
||||
if len(expire) != 0 {
|
||||
exp := binary.BigEndian.Uint64(expire)
|
||||
if exp > now {
|
||||
return
|
||||
}
|
||||
// TTL expired
|
||||
}
|
||||
expire = make([]byte, 8)
|
||||
const ttl = 1 * 60 * 60
|
||||
binary.BigEndian.PutUint64(expire, now+ttl)
|
||||
_ = r.ipAddrs.Set([]byte(ip), expire)
|
||||
|
||||
if r.clients.Exists(ip, ClientSourceRDNS) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("rDNS: adding %s", ip)
|
||||
select {
|
||||
case r.ipChannel <- ip:
|
||||
//
|
||||
default:
|
||||
log.Tracef("rDNS: queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// Use rDNS to get hostname by IP address
|
||||
func (r *RDNS) resolve(ip string) string {
|
||||
log.Tracef("Resolving host for %s", ip)
|
||||
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{
|
||||
Qtype: dns.TypePTR,
|
||||
Qclass: dns.ClassINET,
|
||||
},
|
||||
}
|
||||
var err error
|
||||
req.Question[0].Name, err = dns.ReverseAddr(ip)
|
||||
if err != nil {
|
||||
log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
resp, err := r.dnsServer.Exchange(&req)
|
||||
if err != nil {
|
||||
log.Debug("Error while making an rDNS lookup for %s: %s", ip, err)
|
||||
return ""
|
||||
}
|
||||
if len(resp.Answer) == 0 {
|
||||
log.Debug("No answer for rDNS lookup of %s", ip)
|
||||
return ""
|
||||
}
|
||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||
if !ok {
|
||||
log.Debug("not a PTR response for %s", ip)
|
||||
return ""
|
||||
}
|
||||
|
||||
log.Tracef("PTR response for %s: %s", ip, ptr.String())
|
||||
if strings.HasSuffix(ptr.Ptr, ".") {
|
||||
ptr.Ptr = ptr.Ptr[:len(ptr.Ptr)-1]
|
||||
}
|
||||
|
||||
return ptr.Ptr
|
||||
}
|
||||
|
||||
// Wait for a signal and then synchronously resolve hostname by IP address
|
||||
// Add the hostname:IP pair to "Clients" array
|
||||
func (r *RDNS) workerLoop() {
|
||||
for {
|
||||
var ip string
|
||||
ip = <-r.ipChannel
|
||||
|
||||
host := r.resolve(ip)
|
||||
if len(host) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
}
|
||||
}
|
||||
21
internal/home/rdns_test.go
Normal file
21
internal/home/rdns_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestResolveRDNS(t *testing.T) {
|
||||
dns := &dnsforward.Server{}
|
||||
conf := &dnsforward.ServerConfig{}
|
||||
conf.UpstreamDNS = []string{"8.8.8.8"}
|
||||
err := dns.Prepare(conf)
|
||||
assert.True(t, err == nil, "%s", err)
|
||||
|
||||
clients := &clientsContainer{}
|
||||
rdns := InitRDNS(dns, clients)
|
||||
r := rdns.resolve("1.1.1.1")
|
||||
assert.True(t, r == "one.one.one.one", "%s", r)
|
||||
}
|
||||
517
internal/home/service.go
Normal file
517
internal/home/service.go
Normal file
@@ -0,0 +1,517 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
const (
|
||||
launchdStdoutPath = "/var/log/AdGuardHome.stdout.log"
|
||||
launchdStderrPath = "/var/log/AdGuardHome.stderr.log"
|
||||
serviceName = "AdGuardHome"
|
||||
serviceDisplayName = "AdGuard Home service"
|
||||
serviceDescription = "AdGuard Home: Network-level blocker"
|
||||
)
|
||||
|
||||
// Represents the program that will be launched by a service or daemon
|
||||
type program struct {
|
||||
opts options
|
||||
}
|
||||
|
||||
// Start should quickly start the program
|
||||
func (p *program) Start(s service.Service) error {
|
||||
// Start should not block. Do the actual work async.
|
||||
args := p.opts
|
||||
args.runningAsService = true
|
||||
go run(args)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the program
|
||||
func (p *program) Stop(s service.Service) error {
|
||||
// Stop should not block. Return with a few seconds.
|
||||
if Context.appSignalChannel == nil {
|
||||
os.Exit(0)
|
||||
}
|
||||
Context.appSignalChannel <- syscall.SIGINT
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check the service's status
|
||||
// Note: on OpenWrt 'service' utility may not exist - we use our service script directly in this case.
|
||||
func svcStatus(s service.Service) (service.Status, error) {
|
||||
status, err := s.Status()
|
||||
if err != nil && service.Platform() == "unix-systemv" {
|
||||
code, err := runInitdCommand("status")
|
||||
if err != nil {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
if code != 0 {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
return service.StatusRunning, nil
|
||||
}
|
||||
return status, err
|
||||
}
|
||||
|
||||
// Perform an action on the service
|
||||
// Note: on OpenWrt 'service' utility may not exist - we use our service script directly in this case.
|
||||
func svcAction(s service.Service, action string) error {
|
||||
err := service.Control(s, action)
|
||||
if err != nil && service.Platform() == "unix-systemv" &&
|
||||
(action == "start" || action == "stop" || action == "restart") {
|
||||
_, err := runInitdCommand(action)
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Send SIGHUP to a process with ID taken from our pid-file
|
||||
// If pid-file doesn't exist, find our PID using 'ps' command
|
||||
func sendSigReload() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Error("Not implemented on Windows")
|
||||
return
|
||||
}
|
||||
|
||||
pidfile := fmt.Sprintf("/var/run/%s.pid", serviceName)
|
||||
data, err := ioutil.ReadFile(pidfile)
|
||||
if os.IsNotExist(err) {
|
||||
code, psdata, err := util.RunCommand("ps", "-C", serviceName, "-o", "pid=")
|
||||
if err != nil || code != 0 {
|
||||
log.Error("Can't find AdGuardHome process: %s code:%d", err, code)
|
||||
return
|
||||
}
|
||||
data = []byte(psdata)
|
||||
|
||||
} else if err != nil {
|
||||
log.Error("Can't read PID file %s: %s", pidfile, err)
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(string(data), "\n", 2)
|
||||
if len(parts) == 0 {
|
||||
log.Error("Can't read PID file %s: bad value", pidfile)
|
||||
return
|
||||
}
|
||||
|
||||
pid, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
log.Error("Can't read PID file %s: %s", pidfile, err)
|
||||
return
|
||||
}
|
||||
err = util.SendProcessSignal(pid, syscall.SIGHUP)
|
||||
if err != nil {
|
||||
log.Error("Can't send signal to PID %d: %s", pid, err)
|
||||
return
|
||||
}
|
||||
log.Debug("Sent signal to PID %d", pid)
|
||||
}
|
||||
|
||||
// handleServiceControlAction one of the possible control actions:
|
||||
// install -- installs a service/daemon
|
||||
// uninstall -- uninstalls it
|
||||
// status -- prints the service status
|
||||
// start -- starts the previously installed service
|
||||
// stop -- stops the previously installed service
|
||||
// restart - restarts the previously installed service
|
||||
// run - this is a special command that is not supposed to be used directly
|
||||
// it is specified when we register a service, and it indicates to the app
|
||||
// that it is being run as a service/daemon.
|
||||
func handleServiceControlAction(opts options) {
|
||||
action := opts.serviceControlAction
|
||||
log.Printf("Service control action: %s", action)
|
||||
|
||||
if action == "reload" {
|
||||
sendSigReload()
|
||||
return
|
||||
}
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatal("Unable to find the path to the current directory")
|
||||
}
|
||||
runOpts := opts
|
||||
runOpts.serviceControlAction = "run"
|
||||
svcConfig := &service.Config{
|
||||
Name: serviceName,
|
||||
DisplayName: serviceDisplayName,
|
||||
Description: serviceDescription,
|
||||
WorkingDirectory: pwd,
|
||||
Arguments: serialize(runOpts),
|
||||
}
|
||||
configureService(svcConfig)
|
||||
prg := &program{runOpts}
|
||||
s, err := service.New(prg, svcConfig)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if action == "status" {
|
||||
handleServiceStatusCommand(s)
|
||||
} else if action == "run" {
|
||||
err = s.Run()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to run service: %s", err)
|
||||
}
|
||||
} else if action == "install" {
|
||||
initConfigFilename(opts)
|
||||
initWorkingDir(opts)
|
||||
handleServiceInstallCommand(s)
|
||||
} else if action == "uninstall" {
|
||||
handleServiceUninstallCommand(s)
|
||||
} else {
|
||||
err = svcAction(s, action)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Action %s has been done successfully on %s", action, service.ChosenSystem().String())
|
||||
}
|
||||
|
||||
// handleServiceStatusCommand handles service "status" command
|
||||
func handleServiceStatusCommand(s service.Service) {
|
||||
status, errSt := svcStatus(s)
|
||||
if errSt != nil {
|
||||
log.Fatalf("failed to get service status: %s", errSt)
|
||||
}
|
||||
|
||||
switch status {
|
||||
case service.StatusUnknown:
|
||||
log.Printf("Service status is unknown")
|
||||
case service.StatusStopped:
|
||||
log.Printf("Service is stopped")
|
||||
case service.StatusRunning:
|
||||
log.Printf("Service is running")
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceStatusCommand handles service "install" command
|
||||
func handleServiceInstallCommand(s service.Service) {
|
||||
err := svcAction(s, "install")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if util.IsOpenWrt() {
|
||||
// On OpenWrt it is important to run enable after the service installation
|
||||
// Otherwise, the service won't start on the system startup
|
||||
_, err := runInitdCommand("enable")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start automatically after install
|
||||
err = svcAction(s, "start")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to start the service: %s", err)
|
||||
}
|
||||
log.Printf("Service has been started")
|
||||
|
||||
if detectFirstRun() {
|
||||
log.Printf(`Almost ready!
|
||||
AdGuard Home is successfully installed and will automatically start on boot.
|
||||
There are a few more things that must be configured before you can use it.
|
||||
Click on the link below and follow the Installation Wizard steps to finish setup.`)
|
||||
printHTTPAddresses("http")
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceStatusCommand handles service "uninstall" command
|
||||
func handleServiceUninstallCommand(s service.Service) {
|
||||
if util.IsOpenWrt() {
|
||||
// On OpenWrt it is important to run disable command first
|
||||
// as it will remove the symlink
|
||||
_, err := runInitdCommand("disable")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := svcAction(s, "uninstall")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
// Removing log files on cleanup and ignore errors
|
||||
err := os.Remove(launchdStdoutPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
log.Printf("cannot remove %s", launchdStdoutPath)
|
||||
}
|
||||
err = os.Remove(launchdStderrPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
log.Printf("cannot remove %s", launchdStderrPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// configureService defines additional settings of the service
|
||||
func configureService(c *service.Config) {
|
||||
c.Option = service.KeyValue{}
|
||||
|
||||
// OS X
|
||||
// Redefines the launchd config file template
|
||||
// The purpose is to enable stdout/stderr redirect by default
|
||||
c.Option["LaunchdConfig"] = launchdConfig
|
||||
// This key is used to start the job as soon as it has been loaded. For daemons this means execution at boot time, for agents execution at login.
|
||||
c.Option["RunAtLoad"] = true
|
||||
|
||||
// POSIX
|
||||
// Redirect StdErr & StdOut to files.
|
||||
c.Option["LogOutput"] = true
|
||||
|
||||
// Use modified service file templates
|
||||
c.Option["SystemdScript"] = systemdScript
|
||||
c.Option["SysvScript"] = sysvScript
|
||||
|
||||
// On OpenWrt we're using a different type of sysvScript
|
||||
if util.IsOpenWrt() {
|
||||
c.Option["SysvScript"] = openWrtScript
|
||||
} else if util.IsFreeBSD() {
|
||||
c.Option["SysvScript"] = freeBSDScript
|
||||
}
|
||||
}
|
||||
|
||||
// runInitdCommand runs init.d service command
|
||||
// returns command code or error if any
|
||||
func runInitdCommand(action string) (int, error) {
|
||||
confPath := "/etc/init.d/" + serviceName
|
||||
code, _, err := util.RunCommand("sh", "-c", confPath+" "+action)
|
||||
return code, err
|
||||
}
|
||||
|
||||
// Basically the same template as the one defined in github.com/kardianos/service
|
||||
// but with two additional keys - StandardOutPath and StandardErrorPath
|
||||
var launchdConfig = `<?xml version='1.0' encoding='UTF-8'?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple Computer//DTD PLIST 1.0//EN"
|
||||
"http://www.apple.com/DTDs/PropertyList-1.0.dtd" >
|
||||
<plist version='1.0'>
|
||||
<dict>
|
||||
<key>Label</key><string>{{html .Name}}</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>{{html .Path}}</string>
|
||||
{{range .Config.Arguments}}
|
||||
<string>{{html .}}</string>
|
||||
{{end}}
|
||||
</array>
|
||||
{{if .UserName}}<key>UserName</key><string>{{html .UserName}}</string>{{end}}
|
||||
{{if .ChRoot}}<key>RootDirectory</key><string>{{html .ChRoot}}</string>{{end}}
|
||||
{{if .WorkingDirectory}}<key>WorkingDirectory</key><string>{{html .WorkingDirectory}}</string>{{end}}
|
||||
<key>SessionCreate</key><{{bool .SessionCreate}}/>
|
||||
<key>KeepAlive</key><{{bool .KeepAlive}}/>
|
||||
<key>RunAtLoad</key><{{bool .RunAtLoad}}/>
|
||||
<key>Disabled</key><false/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>` + launchdStdoutPath + `</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>` + launchdStderrPath + `</string>
|
||||
</dict>
|
||||
</plist>
|
||||
`
|
||||
|
||||
// Note: we should keep it in sync with the template from service_systemd_linux.go file
|
||||
// Add "After=" setting for systemd service file, because we must be started only after network is online
|
||||
// Set "RestartSec" to 10
|
||||
const systemdScript = `[Unit]
|
||||
Description={{.Description}}
|
||||
ConditionFileIsExecutable={{.Path|cmdEscape}}
|
||||
After=syslog.target network-online.target
|
||||
|
||||
[Service]
|
||||
StartLimitInterval=5
|
||||
StartLimitBurst=10
|
||||
ExecStart={{.Path|cmdEscape}}{{range .Arguments}} {{.|cmd}}{{end}}
|
||||
{{if .ChRoot}}RootDirectory={{.ChRoot|cmd}}{{end}}
|
||||
{{if .WorkingDirectory}}WorkingDirectory={{.WorkingDirectory|cmdEscape}}{{end}}
|
||||
{{if .UserName}}User={{.UserName}}{{end}}
|
||||
{{if .ReloadSignal}}ExecReload=/bin/kill -{{.ReloadSignal}} "$MAINPID"{{end}}
|
||||
{{if .PIDFile}}PIDFile={{.PIDFile|cmd}}{{end}}
|
||||
{{if and .LogOutput .HasOutputFileSupport -}}
|
||||
StandardOutput=file:/var/log/{{.Name}}.out
|
||||
StandardError=file:/var/log/{{.Name}}.err
|
||||
{{- end}}
|
||||
Restart=always
|
||||
RestartSec=10
|
||||
EnvironmentFile=-/etc/sysconfig/{{.Name}}
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
`
|
||||
|
||||
// Note: we should keep it in sync with the template from service_sysv_linux.go file
|
||||
// Use "ps | grep -v grep | grep $(get_pid)" because "ps PID" may not work on OpenWrt
|
||||
const sysvScript = `#!/bin/sh
|
||||
# For RedHat and cousins:
|
||||
# chkconfig: - 99 01
|
||||
# description: {{.Description}}
|
||||
# processname: {{.Path}}
|
||||
|
||||
### BEGIN INIT INFO
|
||||
# Provides: {{.Path}}
|
||||
# Required-Start:
|
||||
# Required-Stop:
|
||||
# Default-Start: 2 3 4 5
|
||||
# Default-Stop: 0 1 6
|
||||
# Short-Description: {{.DisplayName}}
|
||||
# Description: {{.Description}}
|
||||
### END INIT INFO
|
||||
|
||||
cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}"
|
||||
|
||||
name=$(basename $(readlink -f $0))
|
||||
pid_file="/var/run/$name.pid"
|
||||
stdout_log="/var/log/$name.log"
|
||||
stderr_log="/var/log/$name.err"
|
||||
|
||||
[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
|
||||
|
||||
get_pid() {
|
||||
cat "$pid_file"
|
||||
}
|
||||
|
||||
is_running() {
|
||||
[ -f "$pid_file" ] && ps | grep -v grep | grep $(get_pid) > /dev/null 2>&1
|
||||
}
|
||||
|
||||
case "$1" in
|
||||
start)
|
||||
if is_running; then
|
||||
echo "Already started"
|
||||
else
|
||||
echo "Starting $name"
|
||||
{{if .WorkingDirectory}}cd '{{.WorkingDirectory}}'{{end}}
|
||||
$cmd >> "$stdout_log" 2>> "$stderr_log" &
|
||||
echo $! > "$pid_file"
|
||||
if ! is_running; then
|
||||
echo "Unable to start, see $stdout_log and $stderr_log"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
stop)
|
||||
if is_running; then
|
||||
echo -n "Stopping $name.."
|
||||
kill $(get_pid)
|
||||
for i in $(seq 1 10)
|
||||
do
|
||||
if ! is_running; then
|
||||
break
|
||||
fi
|
||||
echo -n "."
|
||||
sleep 1
|
||||
done
|
||||
echo
|
||||
if is_running; then
|
||||
echo "Not stopped; may still be shutting down or shutdown may have failed"
|
||||
exit 1
|
||||
else
|
||||
echo "Stopped"
|
||||
if [ -f "$pid_file" ]; then
|
||||
rm "$pid_file"
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "Not running"
|
||||
fi
|
||||
;;
|
||||
restart)
|
||||
$0 stop
|
||||
if is_running; then
|
||||
echo "Unable to stop, will not attempt to start"
|
||||
exit 1
|
||||
fi
|
||||
$0 start
|
||||
;;
|
||||
status)
|
||||
if is_running; then
|
||||
echo "Running"
|
||||
else
|
||||
echo "Stopped"
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {start|stop|restart|status}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
exit 0
|
||||
`
|
||||
|
||||
// OpenWrt procd init script
|
||||
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1386
|
||||
const openWrtScript = `#!/bin/sh /etc/rc.common
|
||||
|
||||
USE_PROCD=1
|
||||
|
||||
START=95
|
||||
STOP=01
|
||||
|
||||
cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}"
|
||||
name="{{.Name}}"
|
||||
pid_file="/var/run/${name}.pid"
|
||||
|
||||
start_service() {
|
||||
echo "Starting ${name}"
|
||||
|
||||
procd_open_instance
|
||||
procd_set_param command ${cmd}
|
||||
procd_set_param respawn # respawn automatically if something died
|
||||
procd_set_param stdout 1 # forward stdout of the command to logd
|
||||
procd_set_param stderr 1 # same for stderr
|
||||
procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop
|
||||
|
||||
procd_close_instance
|
||||
echo "${name} has been started"
|
||||
}
|
||||
|
||||
stop_service() {
|
||||
echo "Stopping ${name}"
|
||||
}
|
||||
|
||||
EXTRA_COMMANDS="status"
|
||||
EXTRA_HELP=" status Print the service status"
|
||||
|
||||
get_pid() {
|
||||
cat "${pid_file}"
|
||||
}
|
||||
|
||||
is_running() {
|
||||
[ -f "${pid_file}" ] && ps | grep -v grep | grep $(get_pid) >/dev/null 2>&1
|
||||
}
|
||||
|
||||
status() {
|
||||
if is_running; then
|
||||
echo "Running"
|
||||
else
|
||||
echo "Stopped"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
`
|
||||
const freeBSDScript = `#!/bin/sh
|
||||
# PROVIDE: {{.Name}}
|
||||
# REQUIRE: networking
|
||||
# KEYWORD: shutdown
|
||||
. /etc/rc.subr
|
||||
name="{{.Name}}"
|
||||
{{.Name}}_env="IS_DAEMON=1"
|
||||
{{.Name}}_user="root"
|
||||
pidfile="/var/run/${name}.pid"
|
||||
command="/usr/sbin/daemon"
|
||||
command_args="-P ${pidfile} -r -f {{.WorkingDirectory}}/{{.Name}}"
|
||||
run_rc_command "$1"
|
||||
`
|
||||
547
internal/home/tls.go
Normal file
547
internal/home/tls.go
Normal file
@@ -0,0 +1,547 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
var tlsWebHandlersRegistered = false
|
||||
|
||||
// TLSMod - TLS module object
|
||||
type TLSMod struct {
|
||||
certLastMod time.Time // last modification time of the certificate file
|
||||
conf tlsConfigSettings
|
||||
confLock sync.Mutex
|
||||
status tlsConfigStatus
|
||||
}
|
||||
|
||||
// Create TLS module
|
||||
func tlsCreate(conf tlsConfigSettings) *TLSMod {
|
||||
t := &TLSMod{}
|
||||
t.conf = conf
|
||||
if t.conf.Enabled {
|
||||
if !t.load() {
|
||||
// Something is not valid - return an empty TLS config
|
||||
return &TLSMod{conf: tlsConfigSettings{
|
||||
Enabled: conf.Enabled,
|
||||
ServerName: conf.ServerName,
|
||||
PortHTTPS: conf.PortHTTPS,
|
||||
PortDNSOverTLS: conf.PortDNSOverTLS,
|
||||
PortDNSOverQUIC: conf.PortDNSOverQUIC,
|
||||
AllowUnencryptedDOH: conf.AllowUnencryptedDOH,
|
||||
}}
|
||||
}
|
||||
t.setCertFileTime()
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TLSMod) load() bool {
|
||||
if !tlsLoadConfig(&t.conf, &t.status) {
|
||||
log.Error("failed to load TLS config: %s", t.status.WarningValidation)
|
||||
return false
|
||||
}
|
||||
|
||||
// validate current TLS config and update warnings (it could have been loaded from file)
|
||||
data := validateCertificates(string(t.conf.CertificateChainData), string(t.conf.PrivateKeyData), t.conf.ServerName)
|
||||
if !data.ValidPair {
|
||||
log.Error("failed to validate certificate: %s", data.WarningValidation)
|
||||
return false
|
||||
}
|
||||
t.status = data
|
||||
return true
|
||||
}
|
||||
|
||||
// Close - close module
|
||||
func (t *TLSMod) Close() {
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write config
|
||||
func (t *TLSMod) WriteDiskConfig(conf *tlsConfigSettings) {
|
||||
t.confLock.Lock()
|
||||
*conf = t.conf
|
||||
t.confLock.Unlock()
|
||||
}
|
||||
|
||||
func (t *TLSMod) setCertFileTime() {
|
||||
if len(t.conf.CertificatePath) == 0 {
|
||||
return
|
||||
}
|
||||
fi, err := os.Stat(t.conf.CertificatePath)
|
||||
if err != nil {
|
||||
log.Error("TLS: %s", err)
|
||||
return
|
||||
}
|
||||
t.certLastMod = fi.ModTime().UTC()
|
||||
}
|
||||
|
||||
// Start - start the module
|
||||
func (t *TLSMod) Start() {
|
||||
if !tlsWebHandlersRegistered {
|
||||
tlsWebHandlersRegistered = true
|
||||
t.registerWebHandlers()
|
||||
}
|
||||
|
||||
t.confLock.Lock()
|
||||
tlsConf := t.conf
|
||||
t.confLock.Unlock()
|
||||
Context.web.TLSConfigChanged(tlsConf)
|
||||
}
|
||||
|
||||
// Reload - reload certificate file
|
||||
func (t *TLSMod) Reload() {
|
||||
t.confLock.Lock()
|
||||
tlsConf := t.conf
|
||||
t.confLock.Unlock()
|
||||
|
||||
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
|
||||
return
|
||||
}
|
||||
fi, err := os.Stat(tlsConf.CertificatePath)
|
||||
if err != nil {
|
||||
log.Error("TLS: %s", err)
|
||||
return
|
||||
}
|
||||
if fi.ModTime().UTC().Equal(t.certLastMod) {
|
||||
log.Debug("TLS: certificate file isn't modified")
|
||||
return
|
||||
}
|
||||
log.Debug("TLS: certificate file is modified")
|
||||
|
||||
t.confLock.Lock()
|
||||
r := t.load()
|
||||
t.confLock.Unlock()
|
||||
if !r {
|
||||
return
|
||||
}
|
||||
|
||||
t.certLastMod = fi.ModTime().UTC()
|
||||
|
||||
_ = reconfigureDNSServer()
|
||||
Context.web.TLSConfigChanged(tlsConf)
|
||||
}
|
||||
|
||||
// Set certificate and private key data
|
||||
func tlsLoadConfig(tls *tlsConfigSettings, status *tlsConfigStatus) bool {
|
||||
tls.CertificateChainData = []byte(tls.CertificateChain)
|
||||
tls.PrivateKeyData = []byte(tls.PrivateKey)
|
||||
|
||||
var err error
|
||||
if tls.CertificatePath != "" {
|
||||
if tls.CertificateChain != "" {
|
||||
status.WarningValidation = "certificate data and file can't be set together"
|
||||
return false
|
||||
}
|
||||
tls.CertificateChainData, err = ioutil.ReadFile(tls.CertificatePath)
|
||||
if err != nil {
|
||||
status.WarningValidation = err.Error()
|
||||
return false
|
||||
}
|
||||
status.ValidCert = true
|
||||
}
|
||||
|
||||
if tls.PrivateKeyPath != "" {
|
||||
if tls.PrivateKey != "" {
|
||||
status.WarningValidation = "private key data and file can't be set together"
|
||||
return false
|
||||
}
|
||||
tls.PrivateKeyData, err = ioutil.ReadFile(tls.PrivateKeyPath)
|
||||
if err != nil {
|
||||
status.WarningValidation = err.Error()
|
||||
return false
|
||||
}
|
||||
status.ValidKey = true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
type tlsConfigStatus struct {
|
||||
ValidCert bool `json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates
|
||||
ValidChain bool `json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA
|
||||
Subject string `json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
|
||||
Issuer string `json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
|
||||
NotBefore time.Time `json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
|
||||
NotAfter time.Time `json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
|
||||
DNSNames []string `json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
|
||||
|
||||
// key status
|
||||
ValidKey bool `json:"valid_key"` // ValidKey is true if the key is a valid private key
|
||||
KeyType string `json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
|
||||
|
||||
// is usable? set by validator
|
||||
ValidPair bool `json:"valid_pair"` // ValidPair is true if both certificate and private key are correct
|
||||
|
||||
// warnings
|
||||
WarningValidation string `json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type tlsConfig struct {
|
||||
tlsConfigSettings `json:",inline"`
|
||||
tlsConfigStatus `json:",inline"`
|
||||
}
|
||||
|
||||
func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
t.confLock.Lock()
|
||||
data := tlsConfig{
|
||||
tlsConfigSettings: t.conf,
|
||||
tlsConfigStatus: t.status,
|
||||
}
|
||||
t.confLock.Unlock()
|
||||
marshalTLS(w, data)
|
||||
}
|
||||
|
||||
func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
setts, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !WebCheckPortAvailable(setts.PortHTTPS) {
|
||||
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS)
|
||||
return
|
||||
}
|
||||
|
||||
status := tlsConfigStatus{}
|
||||
if tlsLoadConfig(&setts, &status) {
|
||||
status = validateCertificates(string(setts.CertificateChainData), string(setts.PrivateKeyData), setts.ServerName)
|
||||
}
|
||||
|
||||
data := tlsConfig{
|
||||
tlsConfigSettings: setts,
|
||||
tlsConfigStatus: status,
|
||||
}
|
||||
marshalTLS(w, data)
|
||||
}
|
||||
|
||||
func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !WebCheckPortAvailable(data.PortHTTPS) {
|
||||
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
||||
return
|
||||
}
|
||||
|
||||
status := tlsConfigStatus{}
|
||||
if !tlsLoadConfig(&data, &status) {
|
||||
data2 := tlsConfig{
|
||||
tlsConfigSettings: data,
|
||||
tlsConfigStatus: t.status,
|
||||
}
|
||||
marshalTLS(w, data2)
|
||||
return
|
||||
}
|
||||
status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName)
|
||||
restartHTTPS := false
|
||||
t.confLock.Lock()
|
||||
if !reflect.DeepEqual(t.conf, data) {
|
||||
log.Printf("tls config settings have changed, will restart HTTPS server")
|
||||
restartHTTPS = true
|
||||
}
|
||||
// Note: don't do just `t.conf = data` because we must preserve all other members of t.conf
|
||||
t.conf.Enabled = data.Enabled
|
||||
t.conf.ServerName = data.ServerName
|
||||
t.conf.ForceHTTPS = data.ForceHTTPS
|
||||
t.conf.PortHTTPS = data.PortHTTPS
|
||||
t.conf.PortDNSOverTLS = data.PortDNSOverTLS
|
||||
t.conf.PortDNSOverQUIC = data.PortDNSOverQUIC
|
||||
t.conf.CertificateChain = data.CertificateChain
|
||||
t.conf.CertificatePath = data.CertificatePath
|
||||
t.conf.CertificateChainData = data.CertificateChainData
|
||||
t.conf.PrivateKey = data.PrivateKey
|
||||
t.conf.PrivateKeyPath = data.PrivateKeyPath
|
||||
t.conf.PrivateKeyData = data.PrivateKeyData
|
||||
t.status = status
|
||||
t.confLock.Unlock()
|
||||
t.setCertFileTime()
|
||||
onConfigModified()
|
||||
err = reconfigureDNSServer()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
}
|
||||
data2 := tlsConfig{
|
||||
tlsConfigSettings: data,
|
||||
tlsConfigStatus: t.status,
|
||||
}
|
||||
marshalTLS(w, data2)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||
if restartHTTPS {
|
||||
go func() {
|
||||
Context.web.TLSConfigChanged(data)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func verifyCertChain(data *tlsConfigStatus, certChain string, serverName string) error {
|
||||
log.Tracef("TLS: got certificate: %d bytes", len(certChain))
|
||||
|
||||
// now do a more extended validation
|
||||
var certs []*pem.Block // PEM-encoded certificates
|
||||
var skippedBytes []string // skipped bytes
|
||||
|
||||
pemblock := []byte(certChain)
|
||||
for {
|
||||
var decoded *pem.Block
|
||||
decoded, pemblock = pem.Decode(pemblock)
|
||||
if decoded == nil {
|
||||
break
|
||||
}
|
||||
if decoded.Type == "CERTIFICATE" {
|
||||
certs = append(certs, decoded)
|
||||
} else {
|
||||
// ignore "this result of append is never used" warning
|
||||
// nolint
|
||||
skippedBytes = append(skippedBytes, decoded.Type)
|
||||
}
|
||||
}
|
||||
|
||||
var parsedCerts []*x509.Certificate
|
||||
|
||||
for _, cert := range certs {
|
||||
parsed, err := x509.ParseCertificate(cert.Bytes)
|
||||
if err != nil {
|
||||
data.WarningValidation = fmt.Sprintf("Failed to parse certificate: %s", err)
|
||||
return errors.New(data.WarningValidation)
|
||||
}
|
||||
parsedCerts = append(parsedCerts, parsed)
|
||||
}
|
||||
|
||||
if len(parsedCerts) == 0 {
|
||||
data.WarningValidation = fmt.Sprintf("You have specified an empty certificate")
|
||||
return errors.New(data.WarningValidation)
|
||||
}
|
||||
|
||||
data.ValidCert = true
|
||||
|
||||
// spew.Dump(parsedCerts)
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
DNSName: serverName,
|
||||
Roots: Context.tlsRoots,
|
||||
}
|
||||
|
||||
log.Printf("number of certs - %d", len(parsedCerts))
|
||||
if len(parsedCerts) > 1 {
|
||||
// set up an intermediate
|
||||
pool := x509.NewCertPool()
|
||||
for _, cert := range parsedCerts[1:] {
|
||||
log.Printf("got an intermediate cert")
|
||||
pool.AddCert(cert)
|
||||
}
|
||||
opts.Intermediates = pool
|
||||
}
|
||||
|
||||
// TODO: save it as a warning rather than error it out -- shouldn't be a big problem
|
||||
mainCert := parsedCerts[0]
|
||||
_, err := mainCert.Verify(opts)
|
||||
if err != nil {
|
||||
// let self-signed certs through
|
||||
data.WarningValidation = fmt.Sprintf("Your certificate does not verify: %s", err)
|
||||
} else {
|
||||
data.ValidChain = true
|
||||
}
|
||||
// spew.Dump(chains)
|
||||
|
||||
// update status
|
||||
if mainCert != nil {
|
||||
notAfter := mainCert.NotAfter
|
||||
data.Subject = mainCert.Subject.String()
|
||||
data.Issuer = mainCert.Issuer.String()
|
||||
data.NotAfter = notAfter
|
||||
data.NotBefore = mainCert.NotBefore
|
||||
data.DNSNames = mainCert.DNSNames
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePkey(data *tlsConfigStatus, pkey string) error {
|
||||
// now do a more extended validation
|
||||
var key *pem.Block // PEM-encoded certificates
|
||||
var skippedBytes []string // skipped bytes
|
||||
|
||||
// go through all pem blocks, but take first valid pem block and drop the rest
|
||||
pemblock := []byte(pkey)
|
||||
for {
|
||||
var decoded *pem.Block
|
||||
decoded, pemblock = pem.Decode(pemblock)
|
||||
if decoded == nil {
|
||||
break
|
||||
}
|
||||
if decoded.Type == "PRIVATE KEY" || strings.HasSuffix(decoded.Type, " PRIVATE KEY") {
|
||||
key = decoded
|
||||
break
|
||||
} else {
|
||||
// ignore "this result of append is never used"
|
||||
// nolint
|
||||
skippedBytes = append(skippedBytes, decoded.Type)
|
||||
}
|
||||
}
|
||||
|
||||
if key == nil {
|
||||
data.WarningValidation = "No valid keys were found"
|
||||
return errors.New(data.WarningValidation)
|
||||
}
|
||||
|
||||
// parse the decoded key
|
||||
_, keytype, err := parsePrivateKey(key.Bytes)
|
||||
if err != nil {
|
||||
data.WarningValidation = fmt.Sprintf("Failed to parse private key: %s", err)
|
||||
return errors.New(data.WarningValidation)
|
||||
}
|
||||
|
||||
data.ValidKey = true
|
||||
data.KeyType = keytype
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process certificate data and its private key.
|
||||
// All parameters are optional.
|
||||
// On error, return partially set object
|
||||
// with 'WarningValidation' field containing error description.
|
||||
func validateCertificates(certChain, pkey, serverName string) tlsConfigStatus {
|
||||
var data tlsConfigStatus
|
||||
|
||||
// check only public certificate separately from the key
|
||||
if certChain != "" {
|
||||
if verifyCertChain(&data, certChain, serverName) != nil {
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
// validate private key (right now the only validation possible is just parsing it)
|
||||
if pkey != "" {
|
||||
if validatePkey(&data, pkey) != nil {
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
// if both are set, validate both in unison
|
||||
if pkey != "" && certChain != "" {
|
||||
_, err := tls.X509KeyPair([]byte(certChain), []byte(pkey))
|
||||
if err != nil {
|
||||
data.WarningValidation = fmt.Sprintf("Invalid certificate or key: %s", err)
|
||||
return data
|
||||
}
|
||||
data.ValidPair = true
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
|
||||
// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys.
|
||||
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
|
||||
func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
|
||||
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
|
||||
return key, "RSA", nil
|
||||
}
|
||||
|
||||
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
|
||||
switch key := key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return key, "RSA", nil
|
||||
case *ecdsa.PrivateKey:
|
||||
return key, "ECDSA", nil
|
||||
default:
|
||||
return nil, "", errors.New("tls: found unknown private key type in PKCS#8 wrapping")
|
||||
}
|
||||
}
|
||||
|
||||
if key, err := x509.ParseECPrivateKey(der); err == nil {
|
||||
return key, "ECDSA", nil
|
||||
}
|
||||
|
||||
return nil, "", errors.New("tls: failed to parse private key")
|
||||
}
|
||||
|
||||
// unmarshalTLS handles base64-encoded certificates transparently
|
||||
func unmarshalTLS(r *http.Request) (tlsConfigSettings, error) {
|
||||
data := tlsConfigSettings{}
|
||||
err := json.NewDecoder(r.Body).Decode(&data)
|
||||
if err != nil {
|
||||
return data, errorx.Decorate(err, "Failed to parse new TLS config json")
|
||||
}
|
||||
|
||||
if data.CertificateChain != "" {
|
||||
certPEM, err := base64.StdEncoding.DecodeString(data.CertificateChain)
|
||||
if err != nil {
|
||||
return data, errorx.Decorate(err, "Failed to base64-decode certificate chain")
|
||||
}
|
||||
data.CertificateChain = string(certPEM)
|
||||
if data.CertificatePath != "" {
|
||||
return data, fmt.Errorf("certificate data and file can't be set together")
|
||||
}
|
||||
}
|
||||
|
||||
if data.PrivateKey != "" {
|
||||
keyPEM, err := base64.StdEncoding.DecodeString(data.PrivateKey)
|
||||
if err != nil {
|
||||
return data, errorx.Decorate(err, "Failed to base64-decode private key")
|
||||
}
|
||||
|
||||
data.PrivateKey = string(keyPEM)
|
||||
if data.PrivateKeyPath != "" {
|
||||
return data, fmt.Errorf("private key data and file can't be set together")
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func marshalTLS(w http.ResponseWriter, data tlsConfig) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if data.CertificateChain != "" {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain))
|
||||
data.CertificateChain = encoded
|
||||
}
|
||||
|
||||
if data.PrivateKey != "" {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(data.PrivateKey))
|
||||
data.PrivateKey = encoded
|
||||
}
|
||||
|
||||
err := json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Failed to marshal json with TLS status: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// registerWebHandlers registers HTTP handlers for TLS configuration
|
||||
func (t *TLSMod) registerWebHandlers() {
|
||||
httpRegister("GET", "/control/tls/status", t.handleTLSStatus)
|
||||
httpRegister("POST", "/control/tls/configure", t.handleTLSConfigure)
|
||||
httpRegister("POST", "/control/tls/validate", t.handleTLSValidate)
|
||||
}
|
||||
437
internal/home/upgrade.go
Normal file
437
internal/home/upgrade.go
Normal file
@@ -0,0 +1,437 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
|
||||
"github.com/AdguardTeam/golibs/file"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const currentSchemaVersion = 7 // used for upgrading from old configs to new config
|
||||
|
||||
// Performs necessary upgrade operations if needed
|
||||
func upgradeConfig() error {
|
||||
// read a config file into an interface map, so we can manipulate values without losing any
|
||||
diskConfig := map[string]interface{}{}
|
||||
body, err := readConfigFile()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(body, &diskConfig)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't parse config file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
schemaVersionInterface, ok := diskConfig["schema_version"]
|
||||
log.Tracef("got schema version %v", schemaVersionInterface)
|
||||
if !ok {
|
||||
// no schema version, set it to 0
|
||||
schemaVersionInterface = 0
|
||||
}
|
||||
|
||||
schemaVersion, ok := schemaVersionInterface.(int)
|
||||
if !ok {
|
||||
err = fmt.Errorf("configuration file contains non-integer schema_version, abort")
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if schemaVersion == currentSchemaVersion {
|
||||
// do nothing
|
||||
return nil
|
||||
}
|
||||
|
||||
return upgradeConfigSchema(schemaVersion, &diskConfig)
|
||||
}
|
||||
|
||||
// Upgrade from oldVersion to newVersion
|
||||
func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) error {
|
||||
switch oldVersion {
|
||||
case 0:
|
||||
err := upgradeSchema0to1(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fallthrough
|
||||
case 1:
|
||||
err := upgradeSchema1to2(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fallthrough
|
||||
case 2:
|
||||
err := upgradeSchema2to3(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fallthrough
|
||||
case 3:
|
||||
err := upgradeSchema3to4(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fallthrough
|
||||
case 4:
|
||||
err := upgradeSchema4to5(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fallthrough
|
||||
case 5:
|
||||
err := upgradeSchema5to6(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fallthrough
|
||||
case 6:
|
||||
err := upgradeSchema6to7(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
err := fmt.Errorf("configuration file contains unknown schema_version, abort")
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
body, err := yaml.Marshal(diskConfig)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't generate YAML file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
config.fileData = body
|
||||
err = file.SafeWrite(configFile, body)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't save YAML config: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// The first schema upgrade:
|
||||
// No more "dnsfilter.txt", filters are now kept in data/filters/
|
||||
func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("%s(): called", util.FuncName())
|
||||
|
||||
dnsFilterPath := filepath.Join(Context.workDir, "dnsfilter.txt")
|
||||
if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) {
|
||||
log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath)
|
||||
err = os.Remove(dnsFilterPath)
|
||||
if err != nil {
|
||||
log.Printf("Cannot remove %s due to %s", dnsFilterPath, err)
|
||||
// not fatal, move on
|
||||
}
|
||||
}
|
||||
|
||||
(*diskConfig)["schema_version"] = 1
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Second schema upgrade:
|
||||
// coredns is now dns in config
|
||||
// delete 'Corefile', since we don't use that anymore
|
||||
func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("%s(): called", util.FuncName())
|
||||
|
||||
coreFilePath := filepath.Join(Context.workDir, "Corefile")
|
||||
if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) {
|
||||
log.Printf("Deleting %s as we don't need it anymore", coreFilePath)
|
||||
err = os.Remove(coreFilePath)
|
||||
if err != nil {
|
||||
log.Printf("Cannot remove %s due to %s", coreFilePath, err)
|
||||
// not fatal, move on
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := (*diskConfig)["dns"]; !ok {
|
||||
(*diskConfig)["dns"] = (*diskConfig)["coredns"]
|
||||
delete((*diskConfig), "coredns")
|
||||
}
|
||||
(*diskConfig)["schema_version"] = 2
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Third schema upgrade:
|
||||
// Bootstrap DNS becomes an array
|
||||
func upgradeSchema2to3(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("%s(): called", util.FuncName())
|
||||
|
||||
// Let's read dns configuration from diskConfig
|
||||
dnsConfig, ok := (*diskConfig)["dns"]
|
||||
if !ok {
|
||||
return fmt.Errorf("no DNS configuration in config file")
|
||||
}
|
||||
|
||||
// Convert interface{} to map[string]interface{}
|
||||
newDNSConfig := make(map[string]interface{})
|
||||
|
||||
switch v := dnsConfig.(type) {
|
||||
case map[interface{}]interface{}:
|
||||
for k, v := range v {
|
||||
newDNSConfig[fmt.Sprint(k)] = v
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("DNS configuration is not a map")
|
||||
}
|
||||
|
||||
// Replace bootstrap_dns value filed with new array contains old bootstrap_dns inside
|
||||
if bootstrapDNS, ok := (newDNSConfig)["bootstrap_dns"]; ok {
|
||||
newBootstrapConfig := []string{fmt.Sprint(bootstrapDNS)}
|
||||
(newDNSConfig)["bootstrap_dns"] = newBootstrapConfig
|
||||
(*diskConfig)["dns"] = newDNSConfig
|
||||
} else {
|
||||
return fmt.Errorf("no bootstrap DNS in DNS config")
|
||||
}
|
||||
|
||||
// Bump schema version
|
||||
(*diskConfig)["schema_version"] = 3
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add use_global_blocked_services=true setting for existing "clients" array
|
||||
func upgradeSchema3to4(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("%s(): called", util.FuncName())
|
||||
|
||||
(*diskConfig)["schema_version"] = 4
|
||||
|
||||
clients, ok := (*diskConfig)["clients"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch arr := clients.(type) {
|
||||
case []interface{}:
|
||||
|
||||
for i := range arr {
|
||||
|
||||
switch c := arr[i].(type) {
|
||||
|
||||
case map[interface{}]interface{}:
|
||||
c["use_global_blocked_services"] = true
|
||||
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Replace "auth_name", "auth_pass" string settings with an array:
|
||||
// users:
|
||||
// - name: "..."
|
||||
// password: "..."
|
||||
// ...
|
||||
func upgradeSchema4to5(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("%s(): called", util.FuncName())
|
||||
|
||||
(*diskConfig)["schema_version"] = 5
|
||||
|
||||
name, ok := (*diskConfig)["auth_name"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
nameStr, ok := name.(string)
|
||||
if !ok {
|
||||
log.Fatal("Please use double quotes in your user name in \"auth_name\" and restart AdGuardHome")
|
||||
return nil
|
||||
}
|
||||
|
||||
pass, ok := (*diskConfig)["auth_pass"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
passStr, ok := pass.(string)
|
||||
if !ok {
|
||||
log.Fatal("Please use double quotes in your password in \"auth_pass\" and restart AdGuardHome")
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(nameStr) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(passStr), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
log.Fatalf("Can't use password \"%s\": bcrypt.GenerateFromPassword: %s", passStr, err)
|
||||
return nil
|
||||
}
|
||||
u := User{
|
||||
Name: nameStr,
|
||||
PasswordHash: string(hash),
|
||||
}
|
||||
users := []User{u}
|
||||
(*diskConfig)["users"] = users
|
||||
return nil
|
||||
}
|
||||
|
||||
// clients:
|
||||
// ...
|
||||
// ip: 127.0.0.1
|
||||
// mac: ...
|
||||
//
|
||||
// ->
|
||||
//
|
||||
// clients:
|
||||
// ...
|
||||
// ids:
|
||||
// - 127.0.0.1
|
||||
// - ...
|
||||
func upgradeSchema5to6(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("%s(): called", util.FuncName())
|
||||
|
||||
(*diskConfig)["schema_version"] = 6
|
||||
|
||||
clients, ok := (*diskConfig)["clients"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch arr := clients.(type) {
|
||||
case []interface{}:
|
||||
|
||||
for i := range arr {
|
||||
|
||||
switch c := arr[i].(type) {
|
||||
|
||||
case map[interface{}]interface{}:
|
||||
_ip, ok := c["ip"]
|
||||
ids := []string{}
|
||||
if ok {
|
||||
ip, ok := _ip.(string)
|
||||
if !ok {
|
||||
log.Fatalf("client.ip is not a string: %v", _ip)
|
||||
return nil
|
||||
}
|
||||
if len(ip) != 0 {
|
||||
ids = append(ids, ip)
|
||||
}
|
||||
}
|
||||
|
||||
_mac, ok := c["mac"]
|
||||
if ok {
|
||||
mac, ok := _mac.(string)
|
||||
if !ok {
|
||||
log.Fatalf("client.mac is not a string: %v", _mac)
|
||||
return nil
|
||||
}
|
||||
if len(mac) != 0 {
|
||||
ids = append(ids, mac)
|
||||
}
|
||||
}
|
||||
|
||||
c["ids"] = ids
|
||||
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dhcp:
|
||||
// enabled: false
|
||||
// interface_name: vboxnet0
|
||||
// gateway_ip: 192.168.56.1
|
||||
// ...
|
||||
//
|
||||
// ->
|
||||
//
|
||||
// dhcp:
|
||||
// enabled: false
|
||||
// interface_name: vboxnet0
|
||||
// dhcpv4:
|
||||
// gateway_ip: 192.168.56.1
|
||||
// ...
|
||||
func upgradeSchema6to7(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("Upgrade yaml: 6 to 7")
|
||||
|
||||
(*diskConfig)["schema_version"] = 7
|
||||
|
||||
_dhcp, ok := (*diskConfig)["dhcp"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch dhcp := _dhcp.(type) {
|
||||
case map[interface{}]interface{}:
|
||||
dhcpv4 := map[string]interface{}{}
|
||||
val, ok := dhcp["gateway_ip"].(string)
|
||||
if !ok {
|
||||
log.Fatalf("expecting dhcp.%s to be a string", "gateway_ip")
|
||||
return nil
|
||||
}
|
||||
dhcpv4["gateway_ip"] = val
|
||||
delete(dhcp, "gateway_ip")
|
||||
|
||||
val, ok = dhcp["subnet_mask"].(string)
|
||||
if !ok {
|
||||
log.Fatalf("expecting dhcp.%s to be a string", "subnet_mask")
|
||||
return nil
|
||||
}
|
||||
dhcpv4["subnet_mask"] = val
|
||||
delete(dhcp, "subnet_mask")
|
||||
|
||||
val, ok = dhcp["range_start"].(string)
|
||||
if !ok {
|
||||
log.Fatalf("expecting dhcp.%s to be a string", "range_start")
|
||||
return nil
|
||||
}
|
||||
dhcpv4["range_start"] = val
|
||||
delete(dhcp, "range_start")
|
||||
|
||||
val, ok = dhcp["range_end"].(string)
|
||||
if !ok {
|
||||
log.Fatalf("expecting dhcp.%s to be a string", "range_end")
|
||||
return nil
|
||||
}
|
||||
dhcpv4["range_end"] = val
|
||||
delete(dhcp, "range_end")
|
||||
|
||||
intVal, ok := dhcp["lease_duration"].(int)
|
||||
if !ok {
|
||||
log.Fatalf("expecting dhcp.%s to be an integer", "lease_duration")
|
||||
return nil
|
||||
}
|
||||
dhcpv4["lease_duration"] = intVal
|
||||
delete(dhcp, "lease_duration")
|
||||
|
||||
intVal, ok = dhcp["icmp_timeout_msec"].(int)
|
||||
if !ok {
|
||||
log.Fatalf("expecting dhcp.%s to be an integer", "icmp_timeout_msec")
|
||||
return nil
|
||||
}
|
||||
dhcpv4["icmp_timeout_msec"] = intVal
|
||||
delete(dhcp, "icmp_timeout_msec")
|
||||
|
||||
dhcp["dhcpv4"] = dhcpv4
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
230
internal/home/upgrade_test.go
Normal file
230
internal/home/upgrade_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUpgrade1to2(t *testing.T) {
|
||||
// let's create test config for 1 schema version
|
||||
diskConfig := createTestDiskConfig(1)
|
||||
|
||||
// update config
|
||||
err := upgradeSchema1to2(&diskConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Can't upgrade schema version from 1 to 2")
|
||||
}
|
||||
|
||||
// ensure that schema version was bumped
|
||||
compareSchemaVersion(t, diskConfig["schema_version"], 2)
|
||||
|
||||
// old coredns entry should be removed
|
||||
_, ok := diskConfig["coredns"]
|
||||
if ok {
|
||||
t.Fatalf("Core DNS config was not removed after upgrade schema version from 1 to 2")
|
||||
}
|
||||
|
||||
// pull out new dns config
|
||||
dnsMap, ok := diskConfig["dns"]
|
||||
if !ok {
|
||||
t.Fatalf("No DNS config after upgrade schema version from 1 to 2")
|
||||
}
|
||||
|
||||
// cast dns configurations to maps and compare them
|
||||
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(1))
|
||||
newDNSConfig := castInterfaceToMap(t, dnsMap)
|
||||
compareConfigs(t, &oldDNSConfig, &newDNSConfig)
|
||||
|
||||
// exclude dns config and schema version from disk config comparison
|
||||
oldExcludedEntries := []string{"coredns", "schema_version"}
|
||||
newExcludedEntries := []string{"dns", "schema_version"}
|
||||
oldDiskConfig := createTestDiskConfig(1)
|
||||
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, oldExcludedEntries, newExcludedEntries)
|
||||
}
|
||||
|
||||
func TestUpgrade2to3(t *testing.T) {
|
||||
// let's create test config
|
||||
diskConfig := createTestDiskConfig(2)
|
||||
|
||||
// upgrade schema from 2 to 3
|
||||
err := upgradeSchema2to3(&diskConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Can't update schema version from 2 to 3: %s", err)
|
||||
}
|
||||
|
||||
// check new schema version
|
||||
compareSchemaVersion(t, diskConfig["schema_version"], 3)
|
||||
|
||||
// pull out new dns configuration
|
||||
dnsMap, ok := diskConfig["dns"]
|
||||
if !ok {
|
||||
t.Fatalf("No dns config in new configuration")
|
||||
}
|
||||
|
||||
// cast dns configuration to map
|
||||
newDNSConfig := castInterfaceToMap(t, dnsMap)
|
||||
|
||||
// check if bootstrap DNS becomes an array
|
||||
bootstrapDNS := newDNSConfig["bootstrap_dns"]
|
||||
switch v := bootstrapDNS.(type) {
|
||||
case []string:
|
||||
if len(v) != 1 {
|
||||
t.Fatalf("Wrong count of bootsrap DNS servers: %d", len(v))
|
||||
}
|
||||
|
||||
if v[0] != "8.8.8.8:53" {
|
||||
t.Fatalf("Bootsrap DNS server is not 8.8.8.8:53 : %s", v[0])
|
||||
}
|
||||
default:
|
||||
t.Fatalf("Wrong type for bootsrap DNS: %T", v)
|
||||
}
|
||||
|
||||
// exclude bootstrap DNS from DNS configs comparison
|
||||
excludedEntries := []string{"bootstrap_dns"}
|
||||
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(2))
|
||||
compareConfigsWithoutEntries(t, &oldDNSConfig, &newDNSConfig, excludedEntries, excludedEntries)
|
||||
|
||||
// excluded dns config and schema version from disk config comparison
|
||||
excludedEntries = []string{"dns", "schema_version"}
|
||||
oldDiskConfig := createTestDiskConfig(2)
|
||||
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, excludedEntries, excludedEntries)
|
||||
}
|
||||
|
||||
func castInterfaceToMap(t *testing.T, oldConfig interface{}) (newConfig map[string]interface{}) {
|
||||
newConfig = make(map[string]interface{})
|
||||
switch v := oldConfig.(type) {
|
||||
case map[interface{}]interface{}:
|
||||
for key, value := range v {
|
||||
newConfig[fmt.Sprint(key)] = value
|
||||
}
|
||||
case map[string]interface{}:
|
||||
for key, value := range v {
|
||||
newConfig[key] = value
|
||||
}
|
||||
default:
|
||||
t.Fatalf("DNS configuration is not a map")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// compareConfigsWithoutEntry removes entries from configs and returns result of compareConfigs
|
||||
func compareConfigsWithoutEntries(t *testing.T, oldConfig, newConfig *map[string]interface{}, oldKey, newKey []string) {
|
||||
for _, k := range oldKey {
|
||||
delete(*oldConfig, k)
|
||||
}
|
||||
for _, k := range newKey {
|
||||
delete(*newConfig, k)
|
||||
}
|
||||
compareConfigs(t, oldConfig, newConfig)
|
||||
}
|
||||
|
||||
// compares configs before and after schema upgrade
|
||||
func compareConfigs(t *testing.T, oldConfig, newConfig *map[string]interface{}) {
|
||||
if len(*oldConfig) != len(*newConfig) {
|
||||
t.Fatalf("wrong config entries count! Before upgrade: %d; After upgrade: %d", len(*oldConfig), len(*oldConfig))
|
||||
}
|
||||
|
||||
// Check old and new entries
|
||||
for k, v := range *newConfig {
|
||||
switch value := v.(type) {
|
||||
case string:
|
||||
if value != (*oldConfig)[k] {
|
||||
t.Fatalf("wrong value for string %s. Before update: %s; After update: %s", k, (*oldConfig)[k], value)
|
||||
}
|
||||
case int:
|
||||
if value != (*oldConfig)[k] {
|
||||
t.Fatalf("wrong value for int %s. Before update: %d; After update: %d", k, (*oldConfig)[k], value)
|
||||
}
|
||||
case []string:
|
||||
for i, line := range value {
|
||||
if len((*oldConfig)[k].([]string)) != len(value) {
|
||||
t.Fatalf("wrong array length for %s. Before update: %d; After update: %d", k, len((*oldConfig)[k].([]string)), len(value))
|
||||
}
|
||||
if (*oldConfig)[k].([]string)[i] != line {
|
||||
t.Fatalf("wrong data for string array %s. Before update: %s; After update: %s", k, (*oldConfig)[k].([]string)[i], line)
|
||||
}
|
||||
}
|
||||
case bool:
|
||||
if v != (*oldConfig)[k].(bool) {
|
||||
t.Fatalf("wrong boolean value for %s", k)
|
||||
}
|
||||
case []filter:
|
||||
if len((*oldConfig)[k].([]filter)) != len(value) {
|
||||
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filter)), len(value))
|
||||
}
|
||||
for i, newFilter := range value {
|
||||
oldFilter := (*oldConfig)[k].([]filter)[i]
|
||||
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RulesCount != newFilter.RulesCount {
|
||||
t.Fatalf("old filter %s not equals new filter %s", oldFilter.Name, newFilter.Name)
|
||||
}
|
||||
}
|
||||
default:
|
||||
t.Fatalf("uknown data type for %s: %T", k, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compareSchemaVersion check if newSchemaVersion equals schemaVersion
|
||||
func compareSchemaVersion(t *testing.T, newSchemaVersion interface{}, schemaVersion int) {
|
||||
switch v := newSchemaVersion.(type) {
|
||||
case int:
|
||||
if v != schemaVersion {
|
||||
t.Fatalf("Wrong schema version in new config file")
|
||||
}
|
||||
default:
|
||||
t.Fatalf("Schema version is not an integer after update")
|
||||
}
|
||||
}
|
||||
|
||||
func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{}) {
|
||||
diskConfig = make(map[string]interface{})
|
||||
diskConfig["language"] = "en"
|
||||
diskConfig["filters"] = []filter{
|
||||
{
|
||||
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
|
||||
Name: "Latvian filter",
|
||||
RulesCount: 100,
|
||||
},
|
||||
{
|
||||
URL: "https://easylist.to/easylistgermany/easylistgermany.txt",
|
||||
Name: "Germany filter",
|
||||
RulesCount: 200,
|
||||
},
|
||||
}
|
||||
diskConfig["user_rules"] = []string{}
|
||||
diskConfig["schema_version"] = schemaVersion
|
||||
diskConfig["bind_host"] = "0.0.0.0"
|
||||
diskConfig["bind_port"] = 80
|
||||
diskConfig["auth_name"] = "name"
|
||||
diskConfig["auth_pass"] = "pass"
|
||||
dnsConfig := createTestDNSConfig(schemaVersion)
|
||||
if schemaVersion > 1 {
|
||||
diskConfig["dns"] = dnsConfig
|
||||
} else {
|
||||
diskConfig["coredns"] = dnsConfig
|
||||
}
|
||||
return diskConfig
|
||||
}
|
||||
|
||||
func createTestDNSConfig(schemaVersion int) map[interface{}]interface{} {
|
||||
dnsConfig := make(map[interface{}]interface{})
|
||||
dnsConfig["port"] = 53
|
||||
dnsConfig["blocked_response_ttl"] = 10
|
||||
dnsConfig["querylog_enabled"] = true
|
||||
dnsConfig["ratelimit"] = 20
|
||||
dnsConfig["bootstrap_dns"] = "8.8.8.8:53"
|
||||
if schemaVersion > 2 {
|
||||
dnsConfig["bootstrap_dns"] = []string{"8.8.8.8:53"}
|
||||
}
|
||||
dnsConfig["parental_sensitivity"] = 13
|
||||
dnsConfig["ratelimit_whitelist"] = []string{}
|
||||
dnsConfig["upstream_dns"] = []string{"tls://1.1.1.1", "tls://1.0.0.1", "8.8.8.8"}
|
||||
dnsConfig["filtering_enabled"] = true
|
||||
dnsConfig["refuse_any"] = true
|
||||
dnsConfig["parental_enabled"] = true
|
||||
dnsConfig["bind_host"] = "0.0.0.0"
|
||||
dnsConfig["protection_enabled"] = true
|
||||
dnsConfig["safesearch_enabled"] = true
|
||||
dnsConfig["safebrowsing_enabled"] = true
|
||||
return dnsConfig
|
||||
}
|
||||
209
internal/home/web.go
Normal file
209
internal/home/web.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
golog "log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/gobuffalo/packr"
|
||||
)
|
||||
|
||||
type WebConfig struct {
|
||||
firstRun bool
|
||||
BindHost string
|
||||
BindPort int
|
||||
PortHTTPS int
|
||||
}
|
||||
|
||||
// HTTPSServer - HTTPS Server
|
||||
type HTTPSServer struct {
|
||||
server *http.Server
|
||||
cond *sync.Cond
|
||||
condLock sync.Mutex
|
||||
shutdown bool // if TRUE, don't restart the server
|
||||
enabled bool
|
||||
cert tls.Certificate
|
||||
}
|
||||
|
||||
// Web - module object
|
||||
type Web struct {
|
||||
conf *WebConfig
|
||||
forceHTTPS bool
|
||||
portHTTPS int
|
||||
httpServer *http.Server // HTTP module
|
||||
httpsServer HTTPSServer // HTTPS module
|
||||
errLogger *golog.Logger
|
||||
}
|
||||
|
||||
// Proxy between Go's "log" and "golibs/log"
|
||||
type logWriter struct {
|
||||
}
|
||||
|
||||
// HTTP server calls this function to log an error
|
||||
func (w *logWriter) Write(p []byte) (int, error) {
|
||||
log.Debug("Web: %s", string(p))
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// CreateWeb - create module
|
||||
func CreateWeb(conf *WebConfig) *Web {
|
||||
log.Info("Initialize web module")
|
||||
|
||||
w := Web{}
|
||||
w.conf = conf
|
||||
|
||||
lw := logWriter{}
|
||||
w.errLogger = golog.New(&lw, "", 0)
|
||||
|
||||
// Initialize and run the admin Web interface
|
||||
box := packr.NewBox("../../build/static")
|
||||
|
||||
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
|
||||
http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
|
||||
|
||||
// add handlers for /install paths, we only need them when we're not configured yet
|
||||
if conf.firstRun {
|
||||
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
|
||||
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
|
||||
w.registerInstallHandlers()
|
||||
} else {
|
||||
registerControlHandlers()
|
||||
}
|
||||
|
||||
w.httpsServer.cond = sync.NewCond(&w.httpsServer.condLock)
|
||||
return &w
|
||||
}
|
||||
|
||||
// WebCheckPortAvailable - check if port is available
|
||||
// BUT: if we are already using this port, no need
|
||||
func WebCheckPortAvailable(port int) bool {
|
||||
alreadyRunning := false
|
||||
if Context.web.httpsServer.server != nil {
|
||||
alreadyRunning = true
|
||||
}
|
||||
if !alreadyRunning {
|
||||
err := util.CheckPortAvailable(config.BindHost, port)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TLSConfigChanged - called when TLS configuration has changed
|
||||
func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
|
||||
log.Debug("Web: applying new TLS configuration")
|
||||
web.conf.PortHTTPS = tlsConf.PortHTTPS
|
||||
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
|
||||
web.portHTTPS = tlsConf.PortHTTPS
|
||||
|
||||
enabled := tlsConf.Enabled &&
|
||||
tlsConf.PortHTTPS != 0 &&
|
||||
len(tlsConf.PrivateKeyData) != 0 &&
|
||||
len(tlsConf.CertificateChainData) != 0
|
||||
var cert tls.Certificate
|
||||
var err error
|
||||
if enabled {
|
||||
cert, err = tls.X509KeyPair(tlsConf.CertificateChainData, tlsConf.PrivateKeyData)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
web.httpsServer.cond.L.Lock()
|
||||
if web.httpsServer.server != nil {
|
||||
_ = web.httpsServer.server.Shutdown(context.TODO())
|
||||
}
|
||||
web.httpsServer.enabled = enabled
|
||||
web.httpsServer.cert = cert
|
||||
web.httpsServer.cond.Broadcast()
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
}
|
||||
|
||||
// Start - start serving HTTP requests
|
||||
func (web *Web) Start() {
|
||||
// for https, we have a separate goroutine loop
|
||||
go web.tlsServerLoop()
|
||||
|
||||
// this loop is used as an ability to change listening host and/or port
|
||||
for !web.httpsServer.shutdown {
|
||||
printHTTPAddresses("http")
|
||||
|
||||
// we need to have new instance, because after Shutdown() the Server is not usable
|
||||
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BindPort))
|
||||
web.httpServer = &http.Server{
|
||||
ErrorLog: web.errLogger,
|
||||
Addr: address,
|
||||
}
|
||||
err := web.httpServer.ListenAndServe()
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
}
|
||||
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
|
||||
}
|
||||
}
|
||||
|
||||
// Close - stop HTTP server, possibly waiting for all active connections to be closed
|
||||
func (web *Web) Close() {
|
||||
log.Info("Stopping HTTP server...")
|
||||
web.httpsServer.cond.L.Lock()
|
||||
web.httpsServer.shutdown = true
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
if web.httpsServer.server != nil {
|
||||
_ = web.httpsServer.server.Shutdown(context.TODO())
|
||||
}
|
||||
if web.httpServer != nil {
|
||||
_ = web.httpServer.Shutdown(context.TODO())
|
||||
}
|
||||
|
||||
log.Info("Stopped HTTP server")
|
||||
}
|
||||
|
||||
func (web *Web) tlsServerLoop() {
|
||||
for {
|
||||
web.httpsServer.cond.L.Lock()
|
||||
if web.httpsServer.shutdown {
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
break
|
||||
}
|
||||
|
||||
// this mechanism doesn't let us through until all conditions are met
|
||||
for !web.httpsServer.enabled { // sleep until necessary data is supplied
|
||||
web.httpsServer.cond.Wait()
|
||||
if web.httpsServer.shutdown {
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
|
||||
// prepare HTTPS server
|
||||
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.PortHTTPS))
|
||||
web.httpsServer.server = &http.Server{
|
||||
ErrorLog: web.errLogger,
|
||||
Addr: address,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{web.httpsServer.cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
RootCAs: Context.tlsRoots,
|
||||
CipherSuites: Context.tlsCiphers,
|
||||
},
|
||||
}
|
||||
|
||||
printHTTPAddresses("https")
|
||||
err := web.httpsServer.server.ListenAndServeTLS("", "")
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
237
internal/home/whois.go
Normal file
237
internal/home/whois.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultServer = "whois.arin.net"
|
||||
defaultPort = "43"
|
||||
maxValueLength = 250
|
||||
whoisTTL = 1 * 60 * 60 // 1 hour
|
||||
)
|
||||
|
||||
// Whois - module context
|
||||
type Whois struct {
|
||||
clients *clientsContainer
|
||||
ipChan chan string
|
||||
timeoutMsec uint
|
||||
|
||||
// Contains IP addresses of clients
|
||||
// An active IP address is resolved once again after it expires.
|
||||
// If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP.
|
||||
ipAddrs cache.Cache
|
||||
}
|
||||
|
||||
// Create module context
|
||||
func initWhois(clients *clientsContainer) *Whois {
|
||||
w := Whois{}
|
||||
w.timeoutMsec = 5000
|
||||
w.clients = clients
|
||||
|
||||
cconf := cache.Config{}
|
||||
cconf.EnableLRU = true
|
||||
cconf.MaxCount = 10000
|
||||
w.ipAddrs = cache.New(cconf)
|
||||
|
||||
w.ipChan = make(chan string, 255)
|
||||
go w.workerLoop()
|
||||
return &w
|
||||
}
|
||||
|
||||
// If the value is too large - cut it and append "..."
|
||||
func trimValue(s string) string {
|
||||
if len(s) <= maxValueLength {
|
||||
return s
|
||||
}
|
||||
return s[:maxValueLength-3] + "..."
|
||||
}
|
||||
|
||||
// Parse plain-text data from the response
|
||||
func whoisParse(data string) map[string]string {
|
||||
m := map[string]string{}
|
||||
descr := ""
|
||||
netname := ""
|
||||
for len(data) != 0 {
|
||||
ln := util.SplitNext(&data, '\n')
|
||||
if len(ln) == 0 || ln[0] == '#' || ln[0] == '%' {
|
||||
continue
|
||||
}
|
||||
|
||||
kv := strings.SplitN(ln, ":", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
k := strings.TrimSpace(kv[0])
|
||||
k = strings.ToLower(k)
|
||||
v := strings.TrimSpace(kv[1])
|
||||
|
||||
switch k {
|
||||
case "org-name":
|
||||
m["orgname"] = trimValue(v)
|
||||
case "orgname":
|
||||
fallthrough
|
||||
case "city":
|
||||
fallthrough
|
||||
case "country":
|
||||
m[k] = trimValue(v)
|
||||
|
||||
case "descr":
|
||||
if len(descr) == 0 {
|
||||
descr = v
|
||||
}
|
||||
case "netname":
|
||||
netname = v
|
||||
|
||||
case "whois": // "whois: whois.arin.net"
|
||||
m["whois"] = v
|
||||
|
||||
case "referralserver": // "ReferralServer: whois://whois.ripe.net"
|
||||
if strings.HasPrefix(v, "whois://") {
|
||||
m["whois"] = v[len("whois://"):]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// descr or netname -> orgname
|
||||
_, ok := m["orgname"]
|
||||
if !ok && len(descr) != 0 {
|
||||
m["orgname"] = trimValue(descr)
|
||||
} else if !ok && len(netname) != 0 {
|
||||
m["orgname"] = trimValue(netname)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Send request to a server and receive the response
|
||||
func (w *Whois) query(target string, serverAddr string) (string, error) {
|
||||
addr, _, _ := net.SplitHostPort(serverAddr)
|
||||
if addr == "whois.arin.net" {
|
||||
target = "n + " + target
|
||||
}
|
||||
conn, err := customDialContext(context.TODO(), "tcp", serverAddr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
|
||||
_, err = conn.Write([]byte(target + "\r\n"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadAll(conn)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// Query WHOIS servers (handle redirects)
|
||||
func (w *Whois) queryAll(target string) (string, error) {
|
||||
server := net.JoinHostPort(defaultServer, defaultPort)
|
||||
const maxRedirects = 5
|
||||
for i := 0; i != maxRedirects; i++ {
|
||||
resp, err := w.query(target, server)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
log.Debug("Whois: received response (%d bytes) from %s IP:%s", len(resp), server, target)
|
||||
|
||||
m := whoisParse(resp)
|
||||
redir, ok := m["whois"]
|
||||
if !ok {
|
||||
return resp, nil
|
||||
}
|
||||
redir = strings.ToLower(redir)
|
||||
|
||||
_, _, err = net.SplitHostPort(redir)
|
||||
if err != nil {
|
||||
server = net.JoinHostPort(redir, defaultPort)
|
||||
} else {
|
||||
server = redir
|
||||
}
|
||||
|
||||
log.Debug("Whois: redirected to %s IP:%s", redir, target)
|
||||
}
|
||||
return "", fmt.Errorf("Whois: redirect loop")
|
||||
}
|
||||
|
||||
// Request WHOIS information
|
||||
func (w *Whois) process(ip string) [][]string {
|
||||
data := [][]string{}
|
||||
resp, err := w.queryAll(ip)
|
||||
if err != nil {
|
||||
log.Debug("Whois: error: %s IP:%s", err, ip)
|
||||
return data
|
||||
}
|
||||
|
||||
log.Debug("Whois: IP:%s response: %d bytes", ip, len(resp))
|
||||
|
||||
m := whoisParse(resp)
|
||||
|
||||
keys := []string{"orgname", "country", "city"}
|
||||
for _, k := range keys {
|
||||
v, found := m[k]
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
pair := []string{k, v}
|
||||
data = append(data, pair)
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// Begin - begin requesting WHOIS info
|
||||
func (w *Whois) Begin(ip string) {
|
||||
now := uint64(time.Now().Unix())
|
||||
expire := w.ipAddrs.Get([]byte(ip))
|
||||
if len(expire) != 0 {
|
||||
exp := binary.BigEndian.Uint64(expire)
|
||||
if exp > now {
|
||||
return
|
||||
}
|
||||
// TTL expired
|
||||
}
|
||||
expire = make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(expire, now+whoisTTL)
|
||||
_ = w.ipAddrs.Set([]byte(ip), expire)
|
||||
|
||||
log.Debug("Whois: adding %s", ip)
|
||||
select {
|
||||
case w.ipChan <- ip:
|
||||
//
|
||||
default:
|
||||
log.Debug("Whois: queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// Get IP address from channel; get WHOIS info; associate info with a client
|
||||
func (w *Whois) workerLoop() {
|
||||
for {
|
||||
var ip string
|
||||
ip = <-w.ipChan
|
||||
|
||||
info := w.process(ip)
|
||||
if len(info) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
w.clients.SetWhoisInfo(ip, info)
|
||||
}
|
||||
}
|
||||
28
internal/home/whois_test.go
Normal file
28
internal/home/whois_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func prepareTestDNSServer() error {
|
||||
config.DNS.Port = 1234
|
||||
Context.dnsServer = dnsforward.NewServer(dnsforward.DNSCreateParams{})
|
||||
conf := &dnsforward.ServerConfig{}
|
||||
conf.UpstreamDNS = []string{"8.8.8.8"}
|
||||
return Context.dnsServer.Prepare(conf)
|
||||
}
|
||||
|
||||
func TestWhois(t *testing.T) {
|
||||
assert.Nil(t, prepareTestDNSServer())
|
||||
|
||||
w := Whois{timeoutMsec: 5000}
|
||||
resp, err := w.queryAll("8.8.8.8")
|
||||
assert.Nil(t, err)
|
||||
m := whoisParse(resp)
|
||||
assert.Equal(t, "Google LLC", m["orgname"])
|
||||
assert.Equal(t, "US", m["country"])
|
||||
assert.Equal(t, "Mountain View", m["city"])
|
||||
}
|
||||
Reference in New Issue
Block a user