* move ./*.go files into ./home/ directory
This commit is contained in:
482
home/clients.go
Normal file
482
home/clients.go
Normal file
@@ -0,0 +1,482 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Client information
|
||||
type Client struct {
|
||||
IP string
|
||||
MAC string
|
||||
Name string
|
||||
UseOwnSettings bool // false: use global settings
|
||||
FilteringEnabled bool
|
||||
SafeSearchEnabled bool
|
||||
SafeBrowsingEnabled bool
|
||||
ParentalEnabled bool
|
||||
}
|
||||
|
||||
type clientJSON struct {
|
||||
IP string `json:"ip"`
|
||||
MAC string `json:"mac"`
|
||||
Name string `json:"name"`
|
||||
UseGlobalSettings bool `json:"use_global_settings"`
|
||||
FilteringEnabled bool `json:"filtering_enabled"`
|
||||
ParentalEnabled bool `json:"parental_enabled"`
|
||||
SafeSearchEnabled bool `json:"safebrowsing_enabled"`
|
||||
SafeBrowsingEnabled bool `json:"safesearch_enabled"`
|
||||
}
|
||||
|
||||
type clientSource uint
|
||||
|
||||
const (
|
||||
ClientSourceHostsFile clientSource = 0 // from /etc/hosts
|
||||
ClientSourceRDNS clientSource = 1 // from rDNS
|
||||
)
|
||||
|
||||
// ClientHost information
|
||||
type ClientHost struct {
|
||||
Host string
|
||||
Source clientSource
|
||||
}
|
||||
|
||||
type clientsContainer struct {
|
||||
list map[string]*Client
|
||||
ipIndex map[string]*Client
|
||||
ipHost map[string]ClientHost // IP -> Hostname
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
var clients clientsContainer
|
||||
|
||||
// Initialize clients container
|
||||
func clientsInit() {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
}
|
||||
clients.list = make(map[string]*Client)
|
||||
clients.ipIndex = make(map[string]*Client)
|
||||
clients.ipHost = make(map[string]ClientHost)
|
||||
|
||||
clientsAddFromHostsFile()
|
||||
}
|
||||
|
||||
func clientsGetList() map[string]*Client {
|
||||
return clients.list
|
||||
}
|
||||
|
||||
func clientExists(ip string) bool {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.ipIndex[ip]
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
_, ok = clients.ipHost[ip]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Search for a client by IP
|
||||
func clientFind(ip string) (Client, bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.ipIndex[ip]
|
||||
if ok {
|
||||
return *c, true
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
if len(c.MAC) != 0 {
|
||||
mac, err := net.ParseMAC(c.MAC)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ipAddr := dhcpServer.FindIPbyMAC(mac)
|
||||
if ipAddr == nil {
|
||||
continue
|
||||
}
|
||||
if ip == ipAddr.String() {
|
||||
return *c, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Client{}, false
|
||||
}
|
||||
|
||||
// Check if Client object's fields are correct
|
||||
func clientCheck(c *Client) error {
|
||||
if len(c.Name) == 0 {
|
||||
return fmt.Errorf("Invalid Name")
|
||||
}
|
||||
|
||||
if (len(c.IP) == 0 && len(c.MAC) == 0) ||
|
||||
(len(c.IP) != 0 && len(c.MAC) != 0) {
|
||||
return fmt.Errorf("IP or MAC required")
|
||||
}
|
||||
|
||||
if len(c.IP) != 0 {
|
||||
ip := net.ParseIP(c.IP)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("Invalid IP")
|
||||
}
|
||||
c.IP = ip.String()
|
||||
} else {
|
||||
_, err := net.ParseMAC(c.MAC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Invalid MAC: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add a new client object
|
||||
// Return true: success; false: client exists.
|
||||
func clientAdd(c Client) (bool, error) {
|
||||
e := clientCheck(&c)
|
||||
if e != nil {
|
||||
return false, e
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// check Name index
|
||||
_, ok := clients.list[c.Name]
|
||||
if ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// check IP index
|
||||
if len(c.IP) != 0 {
|
||||
c2, ok := clients.ipIndex[c.IP]
|
||||
if ok {
|
||||
return false, fmt.Errorf("Another client uses the same IP address: %s", c2.Name)
|
||||
}
|
||||
}
|
||||
|
||||
clients.list[c.Name] = &c
|
||||
if len(c.IP) != 0 {
|
||||
clients.ipIndex[c.IP] = &c
|
||||
}
|
||||
|
||||
log.Tracef("'%s': '%s' | '%s' -> [%d]", c.Name, c.IP, c.MAC, len(clients.list))
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Remove a client
|
||||
func clientDel(name string) bool {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.list[name]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
delete(clients.list, name)
|
||||
delete(clients.ipIndex, c.IP)
|
||||
return true
|
||||
}
|
||||
|
||||
// Update a client
|
||||
func clientUpdate(name string, c Client) error {
|
||||
err := clientCheck(&c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
old, ok := clients.list[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("Client not found")
|
||||
}
|
||||
|
||||
// check Name index
|
||||
if old.Name != c.Name {
|
||||
_, ok = clients.list[c.Name]
|
||||
if ok {
|
||||
return fmt.Errorf("Client already exists")
|
||||
}
|
||||
}
|
||||
|
||||
// check IP index
|
||||
if old.IP != c.IP && len(c.IP) != 0 {
|
||||
c2, ok := clients.ipIndex[c.IP]
|
||||
if ok {
|
||||
return fmt.Errorf("Another client uses the same IP address: %s", c2.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// update Name index
|
||||
if old.Name != c.Name {
|
||||
delete(clients.list, old.Name)
|
||||
}
|
||||
clients.list[c.Name] = &c
|
||||
|
||||
// update IP index
|
||||
if old.IP != c.IP {
|
||||
delete(clients.ipIndex, old.IP)
|
||||
}
|
||||
if len(c.IP) != 0 {
|
||||
clients.ipIndex[c.IP] = &c
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func clientAddHost(ip, host string, source clientSource) (bool, error) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// check index
|
||||
_, ok := clients.ipHost[ip]
|
||||
if ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
clients.ipHost[ip] = ClientHost{
|
||||
Host: host,
|
||||
Source: source,
|
||||
}
|
||||
log.Tracef("'%s': '%s' -> [%d]", host, ip, len(clients.ipHost))
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Parse system 'hosts' file and fill clients array
|
||||
func clientsAddFromHostsFile() {
|
||||
hostsFn := "/etc/hosts"
|
||||
if runtime.GOOS == "windows" {
|
||||
hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
|
||||
}
|
||||
|
||||
d, e := ioutil.ReadFile(hostsFn)
|
||||
if e != nil {
|
||||
log.Info("Can't read file %s: %v", hostsFn, e)
|
||||
return
|
||||
}
|
||||
|
||||
lines := strings.Split(string(d), "\n")
|
||||
n := 0
|
||||
for _, ln := range lines {
|
||||
ln = strings.TrimSpace(ln)
|
||||
if len(ln) == 0 || ln[0] == '#' {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Fields(ln)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
ok, e := clientAddHost(fields[0], fields[1], ClientSourceHostsFile)
|
||||
if e != nil {
|
||||
log.Tracef("%s", e)
|
||||
}
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Added %d client aliases from %s", n, hostsFn)
|
||||
}
|
||||
|
||||
type clientHostJSON struct {
|
||||
IP string `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
type clientListJSON struct {
|
||||
Clients []clientJSON `json:"clients"`
|
||||
AutoClients []clientHostJSON `json:"auto_clients"`
|
||||
}
|
||||
|
||||
// respond with information about configured clients
|
||||
func handleGetClients(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
|
||||
data := clientListJSON{}
|
||||
|
||||
clients.lock.Lock()
|
||||
for _, c := range clients.list {
|
||||
cj := clientJSON{
|
||||
IP: c.IP,
|
||||
MAC: c.MAC,
|
||||
Name: c.Name,
|
||||
UseGlobalSettings: !c.UseOwnSettings,
|
||||
FilteringEnabled: c.FilteringEnabled,
|
||||
ParentalEnabled: c.ParentalEnabled,
|
||||
SafeSearchEnabled: c.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
|
||||
}
|
||||
|
||||
if len(c.MAC) != 0 {
|
||||
hwAddr, _ := net.ParseMAC(c.MAC)
|
||||
ipAddr := dhcpServer.FindIPbyMAC(hwAddr)
|
||||
if ipAddr != nil {
|
||||
cj.IP = ipAddr.String()
|
||||
}
|
||||
}
|
||||
|
||||
data.Clients = append(data.Clients, cj)
|
||||
}
|
||||
for ip, ch := range clients.ipHost {
|
||||
cj := clientHostJSON{
|
||||
IP: ip,
|
||||
Name: ch.Host,
|
||||
}
|
||||
cj.Source = "etc/hosts"
|
||||
if ch.Source == ClientSourceRDNS {
|
||||
cj.Source = "rDNS"
|
||||
}
|
||||
data.AutoClients = append(data.AutoClients, cj)
|
||||
}
|
||||
clients.lock.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
e := json.NewEncoder(w).Encode(data)
|
||||
if e != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Failed to encode to json: %v", e)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Convert JSON object to Client object
|
||||
func jsonToClient(cj clientJSON) (*Client, error) {
|
||||
c := Client{
|
||||
IP: cj.IP,
|
||||
MAC: cj.MAC,
|
||||
Name: cj.Name,
|
||||
UseOwnSettings: !cj.UseGlobalSettings,
|
||||
FilteringEnabled: cj.FilteringEnabled,
|
||||
ParentalEnabled: cj.ParentalEnabled,
|
||||
SafeSearchEnabled: cj.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// Add a new client
|
||||
func handleAddClient(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
cj := clientJSON{}
|
||||
err = json.Unmarshal(body, &cj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
c, err := jsonToClient(cj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
ok, err := clientAdd(*c)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
httpError(w, http.StatusBadRequest, "Client already exists")
|
||||
return
|
||||
}
|
||||
|
||||
_ = writeAllConfigsAndReloadDNS()
|
||||
returnOK(w)
|
||||
}
|
||||
|
||||
// Remove client
|
||||
func handleDelClient(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
cj := clientJSON{}
|
||||
err = json.Unmarshal(body, &cj)
|
||||
if err != nil || len(cj.Name) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !clientDel(cj.Name) {
|
||||
httpError(w, http.StatusBadRequest, "Client not found")
|
||||
return
|
||||
}
|
||||
|
||||
_ = writeAllConfigsAndReloadDNS()
|
||||
returnOK(w)
|
||||
}
|
||||
|
||||
type clientUpdateJSON struct {
|
||||
Name string `json:"name"`
|
||||
Data clientJSON `json:"data"`
|
||||
}
|
||||
|
||||
// Update client's properties
|
||||
func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var dj clientUpdateJSON
|
||||
err = json.Unmarshal(body, &dj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
return
|
||||
}
|
||||
if len(dj.Name) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "Invalid request")
|
||||
return
|
||||
}
|
||||
|
||||
c, err := jsonToClient(dj.Data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = clientUpdate(dj.Name, *c)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
_ = writeAllConfigsAndReloadDNS()
|
||||
returnOK(w)
|
||||
}
|
||||
|
||||
// RegisterClientsHandlers registers HTTP handlers
|
||||
func RegisterClientsHandlers() {
|
||||
http.HandleFunc("/control/clients", postInstall(optionalAuth(ensureGET(handleGetClients))))
|
||||
http.HandleFunc("/control/clients/add", postInstall(optionalAuth(ensurePOST(handleAddClient))))
|
||||
http.HandleFunc("/control/clients/delete", postInstall(optionalAuth(ensurePOST(handleDelClient))))
|
||||
http.HandleFunc("/control/clients/update", postInstall(optionalAuth(ensurePOST(handleUpdateClient))))
|
||||
}
|
||||
122
home/clients_test.go
Normal file
122
home/clients_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package home
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
var c Client
|
||||
var e error
|
||||
var b bool
|
||||
|
||||
clientsInit()
|
||||
|
||||
// add
|
||||
c = Client{
|
||||
IP: "1.1.1.1",
|
||||
Name: "client1",
|
||||
}
|
||||
b, e = clientAdd(c)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("clientAdd #1")
|
||||
}
|
||||
|
||||
// add #2
|
||||
c = Client{
|
||||
IP: "2.2.2.2",
|
||||
Name: "client2",
|
||||
}
|
||||
b, e = clientAdd(c)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("clientAdd #2")
|
||||
}
|
||||
|
||||
// failed add - name in use
|
||||
c = Client{
|
||||
IP: "1.2.3.5",
|
||||
Name: "client1",
|
||||
}
|
||||
b, _ = clientAdd(c)
|
||||
if b {
|
||||
t.Fatalf("clientAdd - name in use")
|
||||
}
|
||||
|
||||
// failed add - ip in use
|
||||
c = Client{
|
||||
IP: "2.2.2.2",
|
||||
Name: "client3",
|
||||
}
|
||||
b, e = clientAdd(c)
|
||||
if b || e == nil {
|
||||
t.Fatalf("clientAdd - ip in use")
|
||||
}
|
||||
|
||||
// get
|
||||
if clientExists("1.2.3.4") {
|
||||
t.Fatalf("clientExists")
|
||||
}
|
||||
if !clientExists("1.1.1.1") {
|
||||
t.Fatalf("clientExists #1")
|
||||
}
|
||||
if !clientExists("2.2.2.2") {
|
||||
t.Fatalf("clientExists #2")
|
||||
}
|
||||
|
||||
// failed update - no such name
|
||||
c.IP = "1.2.3.0"
|
||||
c.Name = "client3"
|
||||
if clientUpdate("client3", c) == nil {
|
||||
t.Fatalf("clientUpdate")
|
||||
}
|
||||
|
||||
// failed update - name in use
|
||||
c.IP = "1.2.3.0"
|
||||
c.Name = "client2"
|
||||
if clientUpdate("client1", c) == nil {
|
||||
t.Fatalf("clientUpdate - name in use")
|
||||
}
|
||||
|
||||
// failed update - ip in use
|
||||
c.IP = "2.2.2.2"
|
||||
c.Name = "client1"
|
||||
if clientUpdate("client1", c) == nil {
|
||||
t.Fatalf("clientUpdate - ip in use")
|
||||
}
|
||||
|
||||
// update
|
||||
c.IP = "1.1.1.2"
|
||||
c.Name = "client1"
|
||||
if clientUpdate("client1", c) != nil {
|
||||
t.Fatalf("clientUpdate")
|
||||
}
|
||||
|
||||
// get after update
|
||||
if clientExists("1.1.1.1") || !clientExists("1.1.1.2") {
|
||||
t.Fatalf("clientExists - get after update")
|
||||
}
|
||||
|
||||
// failed remove - no such name
|
||||
if clientDel("client3") {
|
||||
t.Fatalf("clientDel - no such name")
|
||||
}
|
||||
|
||||
// remove
|
||||
if !clientDel("client1") || clientExists("1.1.1.2") {
|
||||
t.Fatalf("clientDel")
|
||||
}
|
||||
|
||||
// add host client
|
||||
b, e = clientAddHost("1.1.1.1", "host", ClientSourceHostsFile)
|
||||
if !b || e != nil {
|
||||
t.Fatalf("clientAddHost")
|
||||
}
|
||||
|
||||
// failed add - ip exists
|
||||
b, e = clientAddHost("1.1.1.1", "host", ClientSourceHostsFile)
|
||||
if b || e != nil {
|
||||
t.Fatalf("clientAddHost - ip exists")
|
||||
}
|
||||
|
||||
// get
|
||||
if !clientExists("1.1.1.1") {
|
||||
t.Fatalf("clientAddHost")
|
||||
}
|
||||
}
|
||||
320
home/config.go
Normal file
320
home/config.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/file"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
dataDir = "data" // data storage
|
||||
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
|
||||
)
|
||||
|
||||
// logSettings
|
||||
type logSettings struct {
|
||||
LogFile string `yaml:"log_file"` // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
|
||||
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
|
||||
}
|
||||
|
||||
type clientObject struct {
|
||||
Name string `yaml:"name"`
|
||||
IP string `yaml:"ip"`
|
||||
MAC string `yaml:"mac"`
|
||||
UseGlobalSettings bool `yaml:"use_global_settings"`
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||
ParentalEnabled bool `yaml:"parental_enabled"`
|
||||
SafeSearchEnabled bool `yaml:"safebrowsing_enabled"`
|
||||
SafeBrowsingEnabled bool `yaml:"safesearch_enabled"`
|
||||
}
|
||||
|
||||
// configuration is loaded from YAML
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type configuration struct {
|
||||
// Raw file data to avoid re-reading of configuration file
|
||||
// It's reset after config is parsed
|
||||
fileData []byte
|
||||
|
||||
ourConfigFilename string // Config filename (can be overridden via the command line arguments)
|
||||
ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else
|
||||
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
|
||||
// runningAsService flag is set to true when options are passed from the service runner
|
||||
runningAsService bool
|
||||
disableUpdate bool // If set, don't check for updates
|
||||
|
||||
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
||||
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
||||
AuthName string `yaml:"auth_name"` // AuthName is the basic auth username
|
||||
AuthPass string `yaml:"auth_pass"` // AuthPass is the basic auth password
|
||||
Language string `yaml:"language"` // two-letter ISO 639-1 language code
|
||||
RlimitNoFile uint `yaml:"rlimit_nofile"` // Maximum number of opened fd's per process (0: default)
|
||||
|
||||
DNS dnsConfig `yaml:"dns"`
|
||||
TLS tlsConfig `yaml:"tls"`
|
||||
Filters []filter `yaml:"filters"`
|
||||
UserRules []string `yaml:"user_rules"`
|
||||
DHCP dhcpd.ServerConfig `yaml:"dhcp"`
|
||||
|
||||
// Note: this array is filled only before file read/write and then it's cleared
|
||||
Clients []clientObject `yaml:"clients"`
|
||||
|
||||
logSettings `yaml:",inline"`
|
||||
|
||||
sync.RWMutex `yaml:"-"`
|
||||
|
||||
SchemaVersion int `yaml:"schema_version"` // keeping last so that users will be less tempted to change it -- used when upgrading between versions
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type dnsConfig struct {
|
||||
BindHost string `yaml:"bind_host"`
|
||||
Port int `yaml:"port"`
|
||||
|
||||
dnsforward.FilteringConfig `yaml:",inline"`
|
||||
|
||||
UpstreamDNS []string `yaml:"upstream_dns"`
|
||||
}
|
||||
|
||||
var defaultDNS = []string{"https://dns.cloudflare.com/dns-query"}
|
||||
var defaultBootstrap = []string{"1.1.1.1"}
|
||||
|
||||
type tlsConfigSettings struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DOT/DOH/HTTPS) status
|
||||
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
|
||||
ForceHTTPS bool `yaml:"force_https" json:"force_https,omitempty"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
|
||||
PortHTTPS int `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
|
||||
PortDNSOverTLS int `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DOT will be disabled
|
||||
|
||||
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
|
||||
}
|
||||
|
||||
// field ordering is not important -- these are for API and are recalculated on each run
|
||||
type tlsConfigStatus struct {
|
||||
ValidCert bool `yaml:"-" json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates
|
||||
ValidChain bool `yaml:"-" json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA
|
||||
Subject string `yaml:"-" json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
|
||||
Issuer string `yaml:"-" json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
|
||||
NotBefore time.Time `yaml:"-" json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
|
||||
NotAfter time.Time `yaml:"-" json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
|
||||
DNSNames []string `yaml:"-" json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
|
||||
|
||||
// key status
|
||||
ValidKey bool `yaml:"-" json:"valid_key"` // ValidKey is true if the key is a valid private key
|
||||
KeyType string `yaml:"-" json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
|
||||
|
||||
// is usable? set by validator
|
||||
ValidPair bool `yaml:"-" json:"valid_pair"` // ValidPair is true if both certificate and private key are correct
|
||||
|
||||
// warnings
|
||||
WarningValidation string `yaml:"-" json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type tlsConfig struct {
|
||||
tlsConfigSettings `yaml:",inline" json:",inline"`
|
||||
tlsConfigStatus `yaml:"-" json:",inline"`
|
||||
}
|
||||
|
||||
// initialize to default values, will be changed later when reading config or parsing command line
|
||||
var config = configuration{
|
||||
ourConfigFilename: "AdGuardHome.yaml",
|
||||
BindPort: 3000,
|
||||
BindHost: "0.0.0.0",
|
||||
DNS: dnsConfig{
|
||||
BindHost: "0.0.0.0",
|
||||
Port: 53,
|
||||
FilteringConfig: dnsforward.FilteringConfig{
|
||||
ProtectionEnabled: true, // whether or not use any of dnsfilter features
|
||||
FilteringEnabled: true, // whether or not use filter lists
|
||||
BlockingMode: "nxdomain", // mode how to answer filtered requests
|
||||
BlockedResponseTTL: 10, // in seconds
|
||||
QueryLogEnabled: true,
|
||||
Ratelimit: 20,
|
||||
RefuseAny: true,
|
||||
BootstrapDNS: defaultBootstrap,
|
||||
AllServers: false,
|
||||
},
|
||||
UpstreamDNS: defaultDNS,
|
||||
},
|
||||
TLS: tlsConfig{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
PortHTTPS: 443,
|
||||
PortDNSOverTLS: 853, // needs to be passed through to dnsproxy
|
||||
},
|
||||
},
|
||||
Filters: []filter{
|
||||
{Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"},
|
||||
{Filter: dnsfilter.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"},
|
||||
{Filter: dnsfilter.Filter{ID: 3}, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"},
|
||||
{Filter: dnsfilter.Filter{ID: 4}, Enabled: false, URL: "https://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"},
|
||||
},
|
||||
DHCP: dhcpd.ServerConfig{
|
||||
LeaseDuration: 86400,
|
||||
ICMPTimeout: 1000,
|
||||
},
|
||||
SchemaVersion: currentSchemaVersion,
|
||||
}
|
||||
|
||||
// init initializes default configuration for the current OS&ARCH
|
||||
func init() {
|
||||
if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
|
||||
// Use plain DNS on MIPS, encryption is too slow
|
||||
defaultDNS = []string{"1.1.1.1", "1.0.0.1"}
|
||||
// also change the default config
|
||||
config.DNS.UpstreamDNS = defaultDNS
|
||||
}
|
||||
}
|
||||
|
||||
// getConfigFilename returns path to the current config file
|
||||
func (c *configuration) getConfigFilename() string {
|
||||
configFile, err := filepath.EvalSymlinks(config.ourConfigFilename)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
log.Error("unexpected error while config file path evaluation: %s", err)
|
||||
}
|
||||
configFile = config.ourConfigFilename
|
||||
}
|
||||
if !filepath.IsAbs(configFile) {
|
||||
configFile = filepath.Join(config.ourWorkingDir, configFile)
|
||||
}
|
||||
return configFile
|
||||
}
|
||||
|
||||
// getLogSettings reads logging settings from the config file.
|
||||
// we do it in a separate method in order to configure logger before the actual configuration is parsed and applied.
|
||||
func getLogSettings() logSettings {
|
||||
l := logSettings{}
|
||||
yamlFile, err := readConfigFile()
|
||||
if err != nil {
|
||||
return l
|
||||
}
|
||||
err = yaml.Unmarshal(yamlFile, &l)
|
||||
if err != nil {
|
||||
log.Error("Couldn't get logging settings from the configuration: %s", err)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// parseConfig loads configuration from the YAML file
|
||||
func parseConfig() error {
|
||||
configFile := config.getConfigFilename()
|
||||
log.Debug("Reading config file: %s", configFile)
|
||||
yamlFile, err := readConfigFile()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config.fileData = nil
|
||||
err = yaml.Unmarshal(yamlFile, &config)
|
||||
if err != nil {
|
||||
log.Error("Couldn't parse config file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
for _, cy := range config.Clients {
|
||||
cli := Client{
|
||||
Name: cy.Name,
|
||||
IP: cy.IP,
|
||||
MAC: cy.MAC,
|
||||
UseOwnSettings: !cy.UseGlobalSettings,
|
||||
FilteringEnabled: cy.FilteringEnabled,
|
||||
ParentalEnabled: cy.ParentalEnabled,
|
||||
SafeSearchEnabled: cy.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
|
||||
}
|
||||
_, err = clientAdd(cli)
|
||||
if err != nil {
|
||||
log.Tracef("clientAdd: %s", err)
|
||||
}
|
||||
}
|
||||
config.Clients = nil
|
||||
|
||||
// Deduplicate filters
|
||||
deduplicateFilters()
|
||||
|
||||
updateUniqueFilterID(config.Filters)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// readConfigFile reads config file contents if it exists
|
||||
func readConfigFile() ([]byte, error) {
|
||||
if len(config.fileData) != 0 {
|
||||
return config.fileData, nil
|
||||
}
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
d, err := ioutil.ReadFile(configFile)
|
||||
if err != nil {
|
||||
log.Error("Couldn't read config file %s: %s", configFile, err)
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// Saves configuration to the YAML file and also saves the user filter contents to a file
|
||||
func (c *configuration) write() error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
clientsList := clientsGetList()
|
||||
for _, cli := range clientsList {
|
||||
ip := cli.IP
|
||||
if len(cli.MAC) != 0 {
|
||||
ip = ""
|
||||
}
|
||||
cy := clientObject{
|
||||
Name: cli.Name,
|
||||
IP: ip,
|
||||
MAC: cli.MAC,
|
||||
UseGlobalSettings: !cli.UseOwnSettings,
|
||||
FilteringEnabled: cli.FilteringEnabled,
|
||||
ParentalEnabled: cli.ParentalEnabled,
|
||||
SafeSearchEnabled: cli.SafeSearchEnabled,
|
||||
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
|
||||
}
|
||||
config.Clients = append(config.Clients, cy)
|
||||
}
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
log.Debug("Writing YAML file: %s", configFile)
|
||||
yamlText, err := yaml.Marshal(&config)
|
||||
config.Clients = nil
|
||||
if err != nil {
|
||||
log.Error("Couldn't generate YAML file: %s", err)
|
||||
return err
|
||||
}
|
||||
err = file.SafeWrite(configFile, yamlText)
|
||||
if err != nil {
|
||||
log.Error("Couldn't save YAML config: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAllConfigs() error {
|
||||
err := config.write()
|
||||
if err != nil {
|
||||
log.Error("Couldn't write config: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
userFilter := userFilter()
|
||||
err = userFilter.save()
|
||||
if err != nil {
|
||||
log.Error("Couldn't save the user filter: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
1002
home/control.go
Normal file
1002
home/control.go
Normal file
File diff suppressed because it is too large
Load Diff
87
home/control_access.go
Normal file
87
home/control_access.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
type accessListJSON struct {
|
||||
AllowedClients []string `json:"allowed_clients"`
|
||||
DisallowedClients []string `json:"disallowed_clients"`
|
||||
BlockedHosts []string `json:"blocked_hosts"`
|
||||
}
|
||||
|
||||
func handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
|
||||
controlLock.Lock()
|
||||
j := accessListJSON{
|
||||
AllowedClients: config.DNS.AllowedClients,
|
||||
DisallowedClients: config.DNS.DisallowedClients,
|
||||
BlockedHosts: config.DNS.BlockedHosts,
|
||||
}
|
||||
controlLock.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(j)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func checkIPCIDRArray(src []string) error {
|
||||
for _, s := range src {
|
||||
ip := net.ParseIP(s)
|
||||
if ip != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, _, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
|
||||
j := accessListJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&j)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = checkIPCIDRArray(j.AllowedClients)
|
||||
if err == nil {
|
||||
err = checkIPCIDRArray(j.DisallowedClients)
|
||||
}
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
config.Lock()
|
||||
config.DNS.AllowedClients = j.AllowedClients
|
||||
config.DNS.DisallowedClients = j.DisallowedClients
|
||||
config.DNS.BlockedHosts = j.BlockedHosts
|
||||
config.Unlock()
|
||||
|
||||
log.Tracef("Update access lists: %d, %d, %d",
|
||||
len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts))
|
||||
|
||||
err = writeAllConfigsAndReloadDNS()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
returnOK(w)
|
||||
}
|
||||
278
home/control_install.go
Normal file
278
home/control_install.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
type firstRunData struct {
|
||||
WebPort int `json:"web_port"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
Interfaces map[string]interface{} `json:"interfaces"`
|
||||
}
|
||||
|
||||
// Get initial installation settings
|
||||
func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
data := firstRunData{}
|
||||
data.WebPort = 80
|
||||
data.DNSPort = 53
|
||||
|
||||
ifaces, err := getValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
data.Interfaces = make(map[string]interface{})
|
||||
for _, iface := range ifaces {
|
||||
data.Interfaces[iface.Name] = iface
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type checkConfigReqEnt struct {
|
||||
Port int `json:"port"`
|
||||
IP string `json:"ip"`
|
||||
Autofix bool `json:"autofix"`
|
||||
}
|
||||
type checkConfigReq struct {
|
||||
Web checkConfigReqEnt `json:"web"`
|
||||
DNS checkConfigReqEnt `json:"dns"`
|
||||
}
|
||||
|
||||
type checkConfigRespEnt struct {
|
||||
Status string `json:"status"`
|
||||
CanAutofix bool `json:"can_autofix"`
|
||||
}
|
||||
type checkConfigResp struct {
|
||||
Web checkConfigRespEnt `json:"web"`
|
||||
DNS checkConfigRespEnt `json:"dns"`
|
||||
}
|
||||
|
||||
// Check if ports are available, respond with results
|
||||
func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
reqData := checkConfigReq{}
|
||||
respData := checkConfigResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort {
|
||||
err = checkPortAvailable(reqData.Web.IP, reqData.Web.Port)
|
||||
if err != nil {
|
||||
respData.Web.Status = fmt.Sprintf("%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqData.DNS.Port != 0 {
|
||||
err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||
|
||||
if errorIsAddrInUse(err) {
|
||||
canAutofix := checkDNSStubListener()
|
||||
if canAutofix && reqData.DNS.Autofix {
|
||||
|
||||
err = disableDNSStubListener()
|
||||
if err != nil {
|
||||
log.Error("Couldn't disable DNSStubListener: %s", err)
|
||||
}
|
||||
|
||||
err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||
canAutofix = false
|
||||
}
|
||||
|
||||
respData.DNS.CanAutofix = canAutofix
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
err = checkPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
respData.DNS.Status = fmt.Sprintf("%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(respData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to marshal JSON: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if DNSStubListener is active
|
||||
func checkDNSStubListener() bool {
|
||||
cmd := exec.Command("systemctl", "is-enabled", "systemd-resolved")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
_, err := cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
log.Error("command %s has failed: %v code:%d",
|
||||
cmd.Path, err, cmd.ProcessState.ExitCode())
|
||||
return false
|
||||
}
|
||||
|
||||
cmd = exec.Command("grep", "-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
_, err = cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
log.Error("command %s has failed: %v code:%d",
|
||||
cmd.Path, err, cmd.ProcessState.ExitCode())
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Deactivate DNSStubListener
|
||||
func disableDNSStubListener() error {
|
||||
cmd := exec.Command("sed", "-r", "-i.orig", "s/#?DNSStubListener=yes/DNSStubListener=no/g", "/etc/systemd/resolved.conf")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cmd.ProcessState.ExitCode() != 0 {
|
||||
return fmt.Errorf("process %s exited with an error: %d",
|
||||
cmd.Path, cmd.ProcessState.ExitCode())
|
||||
}
|
||||
|
||||
cmd = exec.Command("systemctl", "reload-or-restart", "systemd-resolved")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
_, err = cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cmd.ProcessState.ExitCode() != 0 {
|
||||
return fmt.Errorf("process %s exited with an error: %d",
|
||||
cmd.Path, cmd.ProcessState.ExitCode())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type applyConfigReqEnt struct {
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
type applyConfigReq struct {
|
||||
Web applyConfigReqEnt `json:"web"`
|
||||
DNS applyConfigReqEnt `json:"dns"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// Copy installation parameters between two configuration objects
|
||||
func copyInstallSettings(dst *configuration, src *configuration) {
|
||||
dst.BindHost = src.BindHost
|
||||
dst.BindPort = src.BindPort
|
||||
dst.DNS.BindHost = src.DNS.BindHost
|
||||
dst.DNS.Port = src.DNS.Port
|
||||
dst.AuthName = src.AuthName
|
||||
dst.AuthPass = src.AuthPass
|
||||
}
|
||||
|
||||
// Apply new configuration, start DNS server, restart Web server
|
||||
func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
newSettings := applyConfigReq{}
|
||||
err := json.NewDecoder(r.Body).Decode(&newSettings)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse 'configure' JSON: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if newSettings.Web.Port == 0 || newSettings.DNS.Port == 0 {
|
||||
httpError(w, http.StatusBadRequest, "port value can't be 0")
|
||||
return
|
||||
}
|
||||
|
||||
restartHTTP := true
|
||||
if config.BindHost == newSettings.Web.IP && config.BindPort == newSettings.Web.Port {
|
||||
// no need to rebind
|
||||
restartHTTP = false
|
||||
}
|
||||
|
||||
// validate that hosts and ports are bindable
|
||||
if restartHTTP {
|
||||
err = checkPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s",
|
||||
net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = checkPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = checkPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var curConfig configuration
|
||||
copyInstallSettings(&curConfig, &config)
|
||||
|
||||
config.firstRun = false
|
||||
config.BindHost = newSettings.Web.IP
|
||||
config.BindPort = newSettings.Web.Port
|
||||
config.DNS.BindHost = newSettings.DNS.IP
|
||||
config.DNS.Port = newSettings.DNS.Port
|
||||
config.AuthName = newSettings.Username
|
||||
config.AuthPass = newSettings.Password
|
||||
|
||||
err = startDNSServer()
|
||||
if err != nil {
|
||||
config.firstRun = true
|
||||
copyInstallSettings(&config, &curConfig)
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = config.write()
|
||||
if err != nil {
|
||||
config.firstRun = true
|
||||
copyInstallSettings(&config, &curConfig)
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
go refreshFiltersIfNecessary(false)
|
||||
|
||||
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||
if restartHTTP {
|
||||
go func() {
|
||||
httpServer.Shutdown(context.TODO())
|
||||
}()
|
||||
}
|
||||
|
||||
returnOK(w)
|
||||
}
|
||||
|
||||
func registerInstallHandlers() {
|
||||
http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(handleInstallGetAddresses)))
|
||||
http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(handleInstallCheckConfig)))
|
||||
http.HandleFunc("/control/install/configure", preInstall(ensurePOST(handleInstallConfigure)))
|
||||
}
|
||||
153
home/control_test.go
Normal file
153
home/control_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
/* Tests performed:
|
||||
. Bad certificate
|
||||
. Bad private key
|
||||
. Valid certificate & private key */
|
||||
func TestValidateCertificates(t *testing.T) {
|
||||
var data tlsConfigStatus
|
||||
|
||||
// bad cert
|
||||
data = validateCertificates("bad cert", "", "")
|
||||
if !(data.WarningValidation != "" &&
|
||||
!data.ValidCert &&
|
||||
!data.ValidChain) {
|
||||
t.Fatalf("bad cert: validateCertificates(): %v", data)
|
||||
}
|
||||
|
||||
// bad priv key
|
||||
data = validateCertificates("", "bad priv key", "")
|
||||
if !(data.WarningValidation != "" &&
|
||||
!data.ValidKey) {
|
||||
t.Fatalf("bad priv key: validateCertificates(): %v", data)
|
||||
}
|
||||
|
||||
// valid cert & priv key
|
||||
CertificateChain := `-----BEGIN CERTIFICATE-----
|
||||
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
|
||||
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
|
||||
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
|
||||
MBMGA1UEAwwMQWRHdWFyZCBIb21lMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB
|
||||
gQCwvwUnPJiOvLcOaWmGu6Y68ksFr13nrXBcsDlhxlXy8PaohVi3XxEmt2OrVjKW
|
||||
QFw/bdV4fZ9tdWFAVRRkgeGbIZzP7YBD1Ore/O5SQ+DbCCEafvjJCcXQIrTeKFE6
|
||||
i9G3aSMHs0Pwq2LgV8U5mYotLrvyFiE8QPInJbDDMpaFYwIDAQABo1MwUTAdBgNV
|
||||
HQ4EFgQUdLUmQpEqrhn4eKO029jYd2AAZEQwHwYDVR0jBBgwFoAUdLUmQpEqrhn4
|
||||
eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8
|
||||
LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
|
||||
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
|
||||
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
|
||||
-----END CERTIFICATE-----`
|
||||
PrivateKey := `-----BEGIN PRIVATE KEY-----
|
||||
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
|
||||
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
|
||||
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
|
||||
xTmZii0uu/IWITxA8iclsMMyloVjAgMBAAECgYEAmjzoG1h27UDkIlB9BVWl95TP
|
||||
QVPLB81D267xNFDnWk1Lgr5zL/pnNjkdYjyjgpkBp1yKyE4gHV4skv5sAFWTcOCU
|
||||
QCgfPfUn/rDFcxVzAdJVWAa/CpJNaZgjTPR8NTGU+Ztod+wfBESNCP5tbnuw0GbL
|
||||
MuwdLQJGbzeJYpsNysECQQDfFHYoRNfgxHwMbX24GCoNZIgk12uDmGTA9CS5E+72
|
||||
9t3V1y4CfXxSkfhqNbd5RWrUBRLEw9BKofBS7L9NMDKDAkEAytQoIueE1vqEAaRg
|
||||
a3A1YDUekKesU5wKfKfKlXvNgB7Hwh4HuvoQS9RCvVhf/60Dvq8KSu6hSjkFRquj
|
||||
FQ5roQJBAMwKwyiCD5MfJPeZDmzcbVpiocRQ5Z4wPbffl9dRTDnIA5AciZDthlFg
|
||||
An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp
|
||||
O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
|
||||
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
|
||||
kXS9jgARhhiWXJrk
|
||||
-----END PRIVATE KEY-----`
|
||||
data = validateCertificates(CertificateChain, PrivateKey, "")
|
||||
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z")
|
||||
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z")
|
||||
if !(data.WarningValidation != "" /* self signed */ &&
|
||||
data.ValidCert &&
|
||||
!data.ValidChain &&
|
||||
data.ValidKey &&
|
||||
data.KeyType == "RSA" &&
|
||||
data.Subject == "CN=AdGuard Home,O=AdGuard Ltd" &&
|
||||
data.Issuer == "CN=AdGuard Home,O=AdGuard Ltd" &&
|
||||
data.NotBefore == notBefore &&
|
||||
data.NotAfter == notAfter &&
|
||||
// data.DNSNames[0] == &&
|
||||
data.ValidPair) {
|
||||
t.Fatalf("valid cert & priv key: validateCertificates(): %v", data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstream(t *testing.T) {
|
||||
invalidUpstreams := []string{"1.2.3.4.5",
|
||||
"123.3.7m",
|
||||
"htttps://google.com/dns-query",
|
||||
"[/host.com]tls://dns.adguard.com",
|
||||
"[host.ru]#",
|
||||
}
|
||||
|
||||
validDefaultUpstreams := []string{"1.1.1.1",
|
||||
"tls://1.1.1.1",
|
||||
"https://dns.adguard.com/dns-query",
|
||||
"sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
}
|
||||
|
||||
validUpstreams := []string{"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
}
|
||||
for _, u := range invalidUpstreams {
|
||||
_, err := validateUpstream(u)
|
||||
if err == nil {
|
||||
t.Fatalf("upstream %s is invalid but it pass through validation", u)
|
||||
}
|
||||
}
|
||||
|
||||
for _, u := range validDefaultUpstreams {
|
||||
defaultUpstream, err := validateUpstream(u)
|
||||
if err != nil {
|
||||
t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err)
|
||||
}
|
||||
if !defaultUpstream {
|
||||
t.Fatalf("upstream %s is default one!", u)
|
||||
}
|
||||
}
|
||||
|
||||
for _, u := range validUpstreams {
|
||||
defaultUpstream, err := validateUpstream(u)
|
||||
if err != nil {
|
||||
t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err)
|
||||
}
|
||||
if defaultUpstream {
|
||||
t.Fatalf("upstream %s is default one!", u)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsSet(t *testing.T) {
|
||||
// Set of valid upstreams. There is no default upstream specified
|
||||
upstreamsSet := []string{"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
}
|
||||
err := validateUpstreams(upstreamsSet)
|
||||
if err == nil {
|
||||
t.Fatalf("there is no default upstream")
|
||||
}
|
||||
|
||||
// Let's add default upstream
|
||||
upstreamsSet = append(upstreamsSet, "8.8.8.8")
|
||||
err = validateUpstreams(upstreamsSet)
|
||||
if err != nil {
|
||||
t.Fatalf("upstreams set is valid, but doesn't pass through validation cause: %s", err)
|
||||
}
|
||||
|
||||
// Let's add invalid upstream
|
||||
upstreamsSet = append(upstreamsSet, "dhcp://fake.dns")
|
||||
err = validateUpstreams(upstreamsSet)
|
||||
if err == nil {
|
||||
t.Fatalf("there is an invalid upstream in set, but it pass through validation")
|
||||
}
|
||||
}
|
||||
338
home/control_tls.go
Normal file
338
home/control_tls.go
Normal file
@@ -0,0 +1,338 @@
|
||||
// Control: TLS configuring handlers
|
||||
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
// RegisterTLSHandlers registers HTTP handlers for TLS configuration
|
||||
func RegisterTLSHandlers() {
|
||||
http.HandleFunc("/control/tls/status", postInstall(optionalAuth(ensureGET(handleTLSStatus))))
|
||||
http.HandleFunc("/control/tls/configure", postInstall(optionalAuth(ensurePOST(handleTLSConfigure))))
|
||||
http.HandleFunc("/control/tls/validate", postInstall(optionalAuth(ensurePOST(handleTLSValidate))))
|
||||
}
|
||||
|
||||
func handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
marshalTLS(w, config.TLS)
|
||||
}
|
||||
|
||||
func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
data, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// check if port is available
|
||||
// BUT: if we are already using this port, no need
|
||||
alreadyRunning := false
|
||||
if httpsServer.server != nil {
|
||||
alreadyRunning = true
|
||||
}
|
||||
if !alreadyRunning {
|
||||
err = checkPortAvailable(config.BindHost, data.PortHTTPS)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
data.tlsConfigStatus = validateCertificates(data.CertificateChain, data.PrivateKey, data.ServerName)
|
||||
marshalTLS(w, data)
|
||||
}
|
||||
|
||||
func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
data, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// check if port is available
|
||||
// BUT: if we are already using this port, no need
|
||||
alreadyRunning := false
|
||||
if httpsServer.server != nil {
|
||||
alreadyRunning = true
|
||||
}
|
||||
if !alreadyRunning {
|
||||
err = checkPortAvailable(config.BindHost, data.PortHTTPS)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
restartHTTPS := false
|
||||
data.tlsConfigStatus = validateCertificates(data.CertificateChain, data.PrivateKey, data.ServerName)
|
||||
if !reflect.DeepEqual(config.TLS.tlsConfigSettings, data.tlsConfigSettings) {
|
||||
log.Printf("tls config settings have changed, will restart HTTPS server")
|
||||
restartHTTPS = true
|
||||
}
|
||||
config.TLS = data
|
||||
err = writeAllConfigsAndReloadDNS()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err)
|
||||
return
|
||||
}
|
||||
marshalTLS(w, data)
|
||||
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||
if restartHTTPS {
|
||||
go func() {
|
||||
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
|
||||
httpsServer.cond.L.Lock()
|
||||
httpsServer.cond.Broadcast()
|
||||
if httpsServer.server != nil {
|
||||
httpsServer.server.Shutdown(context.TODO())
|
||||
}
|
||||
httpsServer.cond.L.Unlock()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func verifyCertChain(data *tlsConfigStatus, certChain string, serverName string) error {
|
||||
log.Tracef("got certificate: %s", certChain)
|
||||
|
||||
// now do a more extended validation
|
||||
var certs []*pem.Block // PEM-encoded certificates
|
||||
var skippedBytes []string // skipped bytes
|
||||
|
||||
pemblock := []byte(certChain)
|
||||
for {
|
||||
var decoded *pem.Block
|
||||
decoded, pemblock = pem.Decode(pemblock)
|
||||
if decoded == nil {
|
||||
break
|
||||
}
|
||||
if decoded.Type == "CERTIFICATE" {
|
||||
certs = append(certs, decoded)
|
||||
} else {
|
||||
skippedBytes = append(skippedBytes, decoded.Type)
|
||||
}
|
||||
}
|
||||
|
||||
var parsedCerts []*x509.Certificate
|
||||
|
||||
for _, cert := range certs {
|
||||
parsed, err := x509.ParseCertificate(cert.Bytes)
|
||||
if err != nil {
|
||||
data.WarningValidation = fmt.Sprintf("Failed to parse certificate: %s", err)
|
||||
return errors.New(data.WarningValidation)
|
||||
}
|
||||
parsedCerts = append(parsedCerts, parsed)
|
||||
}
|
||||
|
||||
if len(parsedCerts) == 0 {
|
||||
data.WarningValidation = fmt.Sprintf("You have specified an empty certificate")
|
||||
return errors.New(data.WarningValidation)
|
||||
}
|
||||
|
||||
data.ValidCert = true
|
||||
|
||||
// spew.Dump(parsedCerts)
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
DNSName: serverName,
|
||||
}
|
||||
|
||||
log.Printf("number of certs - %d", len(parsedCerts))
|
||||
if len(parsedCerts) > 1 {
|
||||
// set up an intermediate
|
||||
pool := x509.NewCertPool()
|
||||
for _, cert := range parsedCerts[1:] {
|
||||
log.Printf("got an intermediate cert")
|
||||
pool.AddCert(cert)
|
||||
}
|
||||
opts.Intermediates = pool
|
||||
}
|
||||
|
||||
// TODO: save it as a warning rather than error it out -- shouldn't be a big problem
|
||||
mainCert := parsedCerts[0]
|
||||
_, err := mainCert.Verify(opts)
|
||||
if err != nil {
|
||||
// let self-signed certs through
|
||||
data.WarningValidation = fmt.Sprintf("Your certificate does not verify: %s", err)
|
||||
} else {
|
||||
data.ValidChain = true
|
||||
}
|
||||
// spew.Dump(chains)
|
||||
|
||||
// update status
|
||||
if mainCert != nil {
|
||||
notAfter := mainCert.NotAfter
|
||||
data.Subject = mainCert.Subject.String()
|
||||
data.Issuer = mainCert.Issuer.String()
|
||||
data.NotAfter = notAfter
|
||||
data.NotBefore = mainCert.NotBefore
|
||||
data.DNSNames = mainCert.DNSNames
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePkey(data *tlsConfigStatus, pkey string) error {
|
||||
// now do a more extended validation
|
||||
var key *pem.Block // PEM-encoded certificates
|
||||
var skippedBytes []string // skipped bytes
|
||||
|
||||
// go through all pem blocks, but take first valid pem block and drop the rest
|
||||
pemblock := []byte(pkey)
|
||||
for {
|
||||
var decoded *pem.Block
|
||||
decoded, pemblock = pem.Decode(pemblock)
|
||||
if decoded == nil {
|
||||
break
|
||||
}
|
||||
if decoded.Type == "PRIVATE KEY" || strings.HasSuffix(decoded.Type, " PRIVATE KEY") {
|
||||
key = decoded
|
||||
break
|
||||
} else {
|
||||
skippedBytes = append(skippedBytes, decoded.Type)
|
||||
}
|
||||
}
|
||||
|
||||
if key == nil {
|
||||
data.WarningValidation = "No valid keys were found"
|
||||
return errors.New(data.WarningValidation)
|
||||
}
|
||||
|
||||
// parse the decoded key
|
||||
_, keytype, err := parsePrivateKey(key.Bytes)
|
||||
if err != nil {
|
||||
data.WarningValidation = fmt.Sprintf("Failed to parse private key: %s", err)
|
||||
return errors.New(data.WarningValidation)
|
||||
}
|
||||
|
||||
data.ValidKey = true
|
||||
data.KeyType = keytype
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process certificate data and its private key.
|
||||
// All parameters are optional.
|
||||
// On error, return partially set object
|
||||
// with 'WarningValidation' field containing error description.
|
||||
func validateCertificates(certChain, pkey, serverName string) tlsConfigStatus {
|
||||
var data tlsConfigStatus
|
||||
|
||||
// check only public certificate separately from the key
|
||||
if certChain != "" {
|
||||
if verifyCertChain(&data, certChain, serverName) != nil {
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
// validate private key (right now the only validation possible is just parsing it)
|
||||
if pkey != "" {
|
||||
if validatePkey(&data, pkey) != nil {
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
// if both are set, validate both in unison
|
||||
if pkey != "" && certChain != "" {
|
||||
_, err := tls.X509KeyPair([]byte(certChain), []byte(pkey))
|
||||
if err != nil {
|
||||
data.WarningValidation = fmt.Sprintf("Invalid certificate or key: %s", err)
|
||||
return data
|
||||
}
|
||||
data.ValidPair = true
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
|
||||
// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys.
|
||||
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
|
||||
func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
|
||||
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
|
||||
return key, "RSA", nil
|
||||
}
|
||||
|
||||
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
|
||||
switch key := key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return key, "RSA", nil
|
||||
case *ecdsa.PrivateKey:
|
||||
return key, "ECDSA", nil
|
||||
default:
|
||||
return nil, "", errors.New("tls: found unknown private key type in PKCS#8 wrapping")
|
||||
}
|
||||
}
|
||||
|
||||
if key, err := x509.ParseECPrivateKey(der); err == nil {
|
||||
return key, "ECDSA", nil
|
||||
}
|
||||
|
||||
return nil, "", errors.New("tls: failed to parse private key")
|
||||
}
|
||||
|
||||
// unmarshalTLS handles base64-encoded certificates transparently
|
||||
func unmarshalTLS(r *http.Request) (tlsConfig, error) {
|
||||
data := tlsConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(&data)
|
||||
if err != nil {
|
||||
return data, errorx.Decorate(err, "Failed to parse new TLS config json")
|
||||
}
|
||||
|
||||
if data.CertificateChain != "" {
|
||||
certPEM, err := base64.StdEncoding.DecodeString(data.CertificateChain)
|
||||
if err != nil {
|
||||
return data, errorx.Decorate(err, "Failed to base64-decode certificate chain")
|
||||
}
|
||||
data.CertificateChain = string(certPEM)
|
||||
}
|
||||
|
||||
if data.PrivateKey != "" {
|
||||
keyPEM, err := base64.StdEncoding.DecodeString(data.PrivateKey)
|
||||
if err != nil {
|
||||
return data, errorx.Decorate(err, "Failed to base64-decode private key")
|
||||
}
|
||||
|
||||
data.PrivateKey = string(keyPEM)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func marshalTLS(w http.ResponseWriter, data tlsConfig) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if data.CertificateChain != "" {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain))
|
||||
data.CertificateChain = encoded
|
||||
}
|
||||
|
||||
if data.PrivateKey != "" {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(data.PrivateKey))
|
||||
data.PrivateKey = encoded
|
||||
}
|
||||
|
||||
err := json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Failed to marshal json with TLS status: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
511
home/control_update.go
Normal file
511
home/control_update.go
Normal file
@@ -0,0 +1,511 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Convert version.json data to our JSON response
|
||||
func getVersionResp(data []byte) []byte {
|
||||
versionJSON := make(map[string]interface{})
|
||||
err := json.Unmarshal(data, &versionJSON)
|
||||
if err != nil {
|
||||
log.Error("version.json: %s", err)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
ret := make(map[string]interface{})
|
||||
ret["can_autoupdate"] = false
|
||||
|
||||
var ok1, ok2, ok3 bool
|
||||
ret["new_version"], ok1 = versionJSON["version"].(string)
|
||||
ret["announcement"], ok2 = versionJSON["announcement"].(string)
|
||||
ret["announcement_url"], ok3 = versionJSON["announcement_url"].(string)
|
||||
selfUpdateMinVersion, ok4 := versionJSON["selfupdate_min_version"].(string)
|
||||
if !ok1 || !ok2 || !ok3 || !ok4 {
|
||||
log.Error("version.json: invalid data")
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
_, ok := versionJSON[fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH)]
|
||||
if ok && ret["new_version"] != VersionString && VersionString >= selfUpdateMinVersion {
|
||||
ret["can_autoupdate"] = true
|
||||
}
|
||||
|
||||
d, _ := json.Marshal(ret)
|
||||
return d
|
||||
}
|
||||
|
||||
// Get the latest available version from the Internet
|
||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
|
||||
if config.disableUpdate {
|
||||
log.Tracef("New app version check is disabled by user")
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
controlLock.Lock()
|
||||
cached := now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0
|
||||
data := versionCheckJSON
|
||||
controlLock.Unlock()
|
||||
|
||||
if cached {
|
||||
// return cached copy
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(getVersionResp(data))
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Get(versionCheckURL)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
|
||||
return
|
||||
}
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// read the body entirely
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadGateway, "Couldn't read response body from %s: %s", versionCheckURL, err)
|
||||
return
|
||||
}
|
||||
|
||||
controlLock.Lock()
|
||||
versionCheckLastTime = now
|
||||
versionCheckJSON = body
|
||||
controlLock.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(getVersionResp(body))
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Copy file on disk
|
||||
func copyFile(src, dst string) error {
|
||||
d, e := ioutil.ReadFile(src)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
e = ioutil.WriteFile(dst, d, 0644)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type updateInfo struct {
|
||||
pkgURL string // URL for the new package
|
||||
pkgName string // Full path to package file
|
||||
newVer string // New version string
|
||||
updateDir string // Full path to the directory containing unpacked files from the new package
|
||||
backupDir string // Full path to backup directory
|
||||
configName string // Full path to the current configuration file
|
||||
updateConfigName string // Full path to the configuration file to check by the new binary
|
||||
curBinName string // Full path to the current executable file
|
||||
bkpBinName string // Full path to the current executable file in backup directory
|
||||
newBinName string // Full path to the new executable file
|
||||
}
|
||||
|
||||
// Fill in updateInfo object
|
||||
func getUpdateInfo(jsonData []byte) (*updateInfo, error) {
|
||||
var u updateInfo
|
||||
|
||||
workDir := config.ourWorkingDir
|
||||
|
||||
versionJSON := make(map[string]interface{})
|
||||
err := json.Unmarshal(jsonData, &versionJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("JSON parse: %s", err)
|
||||
}
|
||||
|
||||
u.pkgURL = versionJSON[fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH)].(string)
|
||||
u.newVer = versionJSON["version"].(string)
|
||||
if len(u.pkgURL) == 0 || len(u.newVer) == 0 {
|
||||
return nil, fmt.Errorf("Invalid JSON")
|
||||
}
|
||||
|
||||
if u.newVer == VersionString {
|
||||
return nil, fmt.Errorf("No need to update")
|
||||
}
|
||||
|
||||
u.updateDir = filepath.Join(workDir, fmt.Sprintf("agh-update-%s", u.newVer))
|
||||
u.backupDir = filepath.Join(workDir, fmt.Sprintf("agh-backup-%s", VersionString))
|
||||
|
||||
_, pkgFileName := filepath.Split(u.pkgURL)
|
||||
if len(pkgFileName) == 0 {
|
||||
return nil, fmt.Errorf("Invalid JSON")
|
||||
}
|
||||
u.pkgName = filepath.Join(u.updateDir, pkgFileName)
|
||||
|
||||
u.configName = config.getConfigFilename()
|
||||
u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome", "AdGuardHome.yaml")
|
||||
if strings.HasSuffix(pkgFileName, ".zip") {
|
||||
u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome.yaml")
|
||||
}
|
||||
|
||||
binName := "AdGuardHome"
|
||||
if runtime.GOOS == "windows" {
|
||||
binName = "AdGuardHome.exe"
|
||||
}
|
||||
u.curBinName = filepath.Join(workDir, binName)
|
||||
u.bkpBinName = filepath.Join(u.backupDir, binName)
|
||||
u.newBinName = filepath.Join(u.updateDir, "AdGuardHome", binName)
|
||||
if strings.HasSuffix(pkgFileName, ".zip") {
|
||||
u.newBinName = filepath.Join(u.updateDir, binName)
|
||||
}
|
||||
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// Unpack all files from .zip file to the specified directory
|
||||
// Existing files are overwritten
|
||||
// Return the list of files (not directories) written
|
||||
func zipFileUnpack(zipfile, outdir string) ([]string, error) {
|
||||
|
||||
r, err := zip.OpenReader(zipfile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("zip.OpenReader(): %s", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
var files []string
|
||||
var err2 error
|
||||
var zr io.ReadCloser
|
||||
for _, zf := range r.File {
|
||||
zr, err = zf.Open()
|
||||
if err != nil {
|
||||
err2 = fmt.Errorf("zip file Open(): %s", err)
|
||||
break
|
||||
}
|
||||
|
||||
fi := zf.FileInfo()
|
||||
if len(fi.Name()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
fn := filepath.Join(outdir, fi.Name())
|
||||
|
||||
if fi.IsDir() {
|
||||
err = os.Mkdir(fn, fi.Mode())
|
||||
if err != nil && !os.IsExist(err) {
|
||||
err2 = fmt.Errorf("os.Mkdir(): %s", err)
|
||||
break
|
||||
}
|
||||
log.Tracef("created directory %s", fn)
|
||||
continue
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
|
||||
if err != nil {
|
||||
err2 = fmt.Errorf("os.OpenFile(): %s", err)
|
||||
break
|
||||
}
|
||||
_, err = io.Copy(f, zr)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
err2 = fmt.Errorf("io.Copy(): %s", err)
|
||||
break
|
||||
}
|
||||
f.Close()
|
||||
|
||||
log.Tracef("created file %s", fn)
|
||||
files = append(files, fi.Name())
|
||||
}
|
||||
|
||||
zr.Close()
|
||||
return files, err2
|
||||
}
|
||||
|
||||
// Unpack all files from .tar.gz file to the specified directory
|
||||
// Existing files are overwritten
|
||||
// Return the list of files (not directories) written
|
||||
func targzFileUnpack(tarfile, outdir string) ([]string, error) {
|
||||
|
||||
f, err := os.Open(tarfile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("os.Open(): %s", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
gzReader, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip.NewReader(): %s", err)
|
||||
}
|
||||
|
||||
var files []string
|
||||
var err2 error
|
||||
tarReader := tar.NewReader(gzReader)
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
if err == io.EOF {
|
||||
err2 = nil
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
err2 = fmt.Errorf("tarReader.Next(): %s", err)
|
||||
break
|
||||
}
|
||||
if len(header.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
fn := filepath.Join(outdir, header.Name)
|
||||
|
||||
if header.Typeflag == tar.TypeDir {
|
||||
err = os.Mkdir(fn, os.FileMode(header.Mode&0777))
|
||||
if err != nil && !os.IsExist(err) {
|
||||
err2 = fmt.Errorf("os.Mkdir(%s): %s", fn, err)
|
||||
break
|
||||
}
|
||||
log.Tracef("created directory %s", fn)
|
||||
continue
|
||||
} else if header.Typeflag != tar.TypeReg {
|
||||
log.Tracef("%s: unknown file type %d, skipping", header.Name, header.Typeflag)
|
||||
continue
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode&0777))
|
||||
if err != nil {
|
||||
err2 = fmt.Errorf("os.OpenFile(%s): %s", fn, err)
|
||||
break
|
||||
}
|
||||
_, err = io.Copy(f, tarReader)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
err2 = fmt.Errorf("io.Copy(): %s", err)
|
||||
break
|
||||
}
|
||||
f.Close()
|
||||
|
||||
log.Tracef("created file %s", fn)
|
||||
files = append(files, header.Name)
|
||||
}
|
||||
|
||||
gzReader.Close()
|
||||
return files, err2
|
||||
}
|
||||
|
||||
func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly, useDstNameOnly bool) error {
|
||||
for _, f := range files {
|
||||
_, name := filepath.Split(f)
|
||||
if name == "AdGuardHome" || name == "AdGuardHome.exe" || name == "AdGuardHome.yaml" {
|
||||
continue
|
||||
}
|
||||
|
||||
src := filepath.Join(srcdir, f)
|
||||
if useSrcNameOnly {
|
||||
src = filepath.Join(srcdir, name)
|
||||
}
|
||||
|
||||
dst := filepath.Join(dstdir, f)
|
||||
if useDstNameOnly {
|
||||
dst = filepath.Join(dstdir, name)
|
||||
}
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracef("Copied: %s -> %s", src, dst)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Download package file and save it to disk
|
||||
func getPackageFile(u *updateInfo) error {
|
||||
resp, err := client.Get(u.pkgURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("HTTP request failed: %s", err)
|
||||
}
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
log.Tracef("Reading HTTP body")
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ioutil.ReadAll() failed: %s", err)
|
||||
}
|
||||
|
||||
log.Tracef("Saving package to file")
|
||||
err = ioutil.WriteFile(u.pkgName, body, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ioutil.WriteFile() failed: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Perform an update procedure
|
||||
func doUpdate(u *updateInfo) error {
|
||||
log.Info("Updating from %s to %s. URL:%s Package:%s",
|
||||
VersionString, u.newVer, u.pkgURL, u.pkgName)
|
||||
|
||||
_ = os.Mkdir(u.updateDir, 0755)
|
||||
|
||||
var err error
|
||||
err = getPackageFile(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracef("Unpacking the package")
|
||||
_, file := filepath.Split(u.pkgName)
|
||||
var files []string
|
||||
if strings.HasSuffix(file, ".zip") {
|
||||
files, err = zipFileUnpack(u.pkgName, u.updateDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("zipFileUnpack() failed: %s", err)
|
||||
}
|
||||
} else if strings.HasSuffix(file, ".tar.gz") {
|
||||
files, err = targzFileUnpack(u.pkgName, u.updateDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("targzFileUnpack() failed: %s", err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("Unknown package extension")
|
||||
}
|
||||
|
||||
log.Tracef("Checking configuration")
|
||||
err = copyFile(u.configName, u.updateConfigName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("copyFile() failed: %s", err)
|
||||
}
|
||||
cmd := exec.Command(u.newBinName, "--check-config")
|
||||
err = cmd.Run()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode())
|
||||
}
|
||||
|
||||
log.Tracef("Backing up the current configuration")
|
||||
_ = os.Mkdir(u.backupDir, 0755)
|
||||
err = copyFile(u.configName, filepath.Join(u.backupDir, "AdGuardHome.yaml"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("copyFile() failed: %s", err)
|
||||
}
|
||||
|
||||
// ./README.md -> backup/README.md
|
||||
err = copySupportingFiles(files, config.ourWorkingDir, u.backupDir, true, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s",
|
||||
config.ourWorkingDir, u.backupDir, err)
|
||||
}
|
||||
|
||||
// update/[AdGuardHome/]README.md -> ./README.md
|
||||
err = copySupportingFiles(files, u.updateDir, config.ourWorkingDir, false, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s",
|
||||
u.updateDir, config.ourWorkingDir, err)
|
||||
}
|
||||
|
||||
log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName)
|
||||
err = os.Rename(u.curBinName, u.bkpBinName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
// rename fails with "File in use" error
|
||||
err = copyFile(u.newBinName, u.curBinName)
|
||||
} else {
|
||||
err = os.Rename(u.newBinName, u.curBinName)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Tracef("Renamed: %s -> %s", u.newBinName, u.curBinName)
|
||||
|
||||
_ = os.Remove(u.pkgName)
|
||||
_ = os.RemoveAll(u.updateDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Complete an update procedure
|
||||
func finishUpdate(u *updateInfo) {
|
||||
log.Info("Stopping all tasks")
|
||||
cleanup()
|
||||
stopHTTPServer()
|
||||
cleanupAlways()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
|
||||
if config.runningAsService {
|
||||
// Note:
|
||||
// we can't restart the service via "kardianos/service" package - it kills the process first
|
||||
// we can't start a new instance - Windows doesn't allow it
|
||||
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("exec.Command() failed: %s", err)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
cmd := exec.Command(u.curBinName, os.Args[1:]...)
|
||||
log.Info("Restarting: %v", cmd.Args)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("exec.Command() failed: %s", err)
|
||||
}
|
||||
os.Exit(0)
|
||||
|
||||
} else {
|
||||
|
||||
log.Info("Restarting: %v", os.Args)
|
||||
err := syscall.Exec(u.curBinName, os.Args, os.Environ())
|
||||
if err != nil {
|
||||
log.Fatalf("syscall.Exec() failed: %s", err)
|
||||
}
|
||||
// Unreachable code
|
||||
}
|
||||
}
|
||||
|
||||
// Perform an update procedure to the latest available version
|
||||
func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
|
||||
if len(versionCheckJSON) == 0 {
|
||||
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := getUpdateInfo(versionCheckJSON)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = doUpdate(u)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
returnOK(w)
|
||||
|
||||
time.Sleep(time.Second) // wait (hopefully) until response is sent (not sure whether it's really necessary)
|
||||
go finishUpdate(u)
|
||||
}
|
||||
54
home/control_update_test.go
Normal file
54
home/control_update_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// +build ignore
|
||||
|
||||
package home
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDoUpdate(t *testing.T) {
|
||||
config.DNS.Port = 0
|
||||
config.ourWorkingDir = "."
|
||||
u := updateInfo{
|
||||
pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.95/AdGuardHome_v0.95_linux_amd64.tar.gz",
|
||||
pkgName: "./AdGuardHome_v0.95_linux_amd64.tar.gz",
|
||||
newVer: "v0.95",
|
||||
updateDir: "./agh-update-v0.95",
|
||||
backupDir: "./agh-backup-v0.94",
|
||||
configName: "./AdGuardHome.yaml",
|
||||
updateConfigName: "./agh-update-v0.95/AdGuardHome/AdGuardHome.yaml",
|
||||
curBinName: "./AdGuardHome",
|
||||
bkpBinName: "./agh-backup-v0.94/AdGuardHome",
|
||||
newBinName: "./agh-update-v0.95/AdGuardHome/AdGuardHome",
|
||||
}
|
||||
e := doUpdate(&u)
|
||||
if e != nil {
|
||||
t.Fatalf("FAILED: %s", e)
|
||||
}
|
||||
os.RemoveAll(u.backupDir)
|
||||
}
|
||||
|
||||
func TestTargzFileUnpack(t *testing.T) {
|
||||
fn := "./dist/AdGuardHome_v0.95_linux_amd64.tar.gz"
|
||||
outdir := "./test-unpack"
|
||||
_ = os.Mkdir(outdir, 0755)
|
||||
files, e := targzFileUnpack(fn, outdir)
|
||||
if e != nil {
|
||||
t.Fatalf("FAILED: %s", e)
|
||||
}
|
||||
t.Logf("%v", files)
|
||||
os.RemoveAll(outdir)
|
||||
}
|
||||
|
||||
func TestZipFileUnpack(t *testing.T) {
|
||||
fn := "./dist/AdGuardHome_v0.95_Windows_amd64.zip"
|
||||
outdir := "./test-unpack"
|
||||
_ = os.Mkdir(outdir, 0755)
|
||||
files, e := zipFileUnpack(fn, outdir)
|
||||
if e != nil {
|
||||
t.Fatalf("FAILED: %s", e)
|
||||
}
|
||||
t.Logf("%v", files)
|
||||
os.RemoveAll(outdir)
|
||||
}
|
||||
460
home/dhcp.go
Normal file
460
home/dhcp.go
Normal file
@@ -0,0 +1,460 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||
"github.com/AdguardTeam/golibs/file"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
var dhcpServer = dhcpd.Server{}
|
||||
|
||||
// []dhcpd.Lease -> JSON
|
||||
func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]string {
|
||||
leases := []map[string]string{}
|
||||
for _, l := range inputLeases {
|
||||
lease := map[string]string{
|
||||
"mac": l.HWAddr.String(),
|
||||
"ip": l.IP.String(),
|
||||
"hostname": l.Hostname,
|
||||
}
|
||||
|
||||
if includeExpires {
|
||||
lease["expires"] = l.Expiry.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
leases = append(leases, lease)
|
||||
}
|
||||
return leases
|
||||
}
|
||||
|
||||
func handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
leases := convertLeases(dhcpServer.Leases(), true)
|
||||
staticLeases := convertLeases(dhcpServer.StaticLeases(), false)
|
||||
status := map[string]interface{}{
|
||||
"config": config.DHCP,
|
||||
"leases": leases,
|
||||
"static_leases": staticLeases,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(status)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to marshal DHCP status json: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type leaseJSON struct {
|
||||
HWAddr string `json:"mac"`
|
||||
IP string `json:"ip"`
|
||||
Hostname string `json:"hostname"`
|
||||
}
|
||||
|
||||
type dhcpServerConfigJSON struct {
|
||||
dhcpd.ServerConfig `json:",inline"`
|
||||
StaticLeases []leaseJSON `json:"static_leases"`
|
||||
}
|
||||
|
||||
func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
newconfig := dhcpServerConfigJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&newconfig)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse new DHCP config json: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dhcpServer.CheckConfig(newconfig.ServerConfig)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dhcpServer.Stop()
|
||||
if err != nil {
|
||||
log.Error("failed to stop the DHCP server: %s", err)
|
||||
}
|
||||
|
||||
err = dhcpServer.Init(newconfig.ServerConfig)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if newconfig.Enabled {
|
||||
|
||||
staticIP, err := hasStaticIP(newconfig.InterfaceName)
|
||||
if !staticIP && err == nil {
|
||||
err = setStaticIP(newconfig.InterfaceName)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Failed to configure static IP: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = dhcpServer.Start()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to start DHCP server: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
config.DHCP = newconfig.ServerConfig
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
response := map[string]interface{}{}
|
||||
|
||||
ifaces, err := getValidNetInterfaces()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagLoopback != 0 {
|
||||
// it's a loopback, skip it
|
||||
continue
|
||||
}
|
||||
if iface.Flags&net.FlagBroadcast == 0 {
|
||||
// this interface doesn't support broadcast, skip it
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
jsonIface := netInterface{
|
||||
Name: iface.Name,
|
||||
MTU: iface.MTU,
|
||||
HardwareAddr: iface.HardwareAddr.String(),
|
||||
}
|
||||
|
||||
if iface.Flags != 0 {
|
||||
jsonIface.Flags = iface.Flags.String()
|
||||
}
|
||||
// we don't want link-local addresses in json, so skip them
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok {
|
||||
// not an IPNet, should not happen
|
||||
httpError(w, http.StatusInternalServerError, "SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
|
||||
return
|
||||
}
|
||||
// ignore link-local
|
||||
if ipnet.IP.IsLinkLocalUnicast() {
|
||||
continue
|
||||
}
|
||||
jsonIface.Addresses = append(jsonIface.Addresses, ipnet.IP.String())
|
||||
}
|
||||
if len(jsonIface.Addresses) != 0 {
|
||||
response[iface.Name] = jsonIface
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
err = json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Failed to marshal json with available interfaces: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Perform the following tasks:
|
||||
// . Search for another DHCP server running
|
||||
// . Check if a static IP is configured for the network interface
|
||||
// Respond with results
|
||||
func handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errorText := fmt.Sprintf("failed to read request body: %s", err)
|
||||
log.Error(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
interfaceName := strings.TrimSpace(string(body))
|
||||
if interfaceName == "" {
|
||||
errorText := fmt.Sprintf("empty interface name specified")
|
||||
log.Error(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
found, err := dhcpd.CheckIfOtherDHCPServersPresent(interfaceName)
|
||||
|
||||
othSrv := map[string]interface{}{}
|
||||
foundVal := "no"
|
||||
if found {
|
||||
foundVal = "yes"
|
||||
} else if err != nil {
|
||||
foundVal = "error"
|
||||
othSrv["error"] = err.Error()
|
||||
}
|
||||
othSrv["found"] = foundVal
|
||||
|
||||
staticIP := map[string]interface{}{}
|
||||
isStaticIP, err := hasStaticIP(interfaceName)
|
||||
staticIPStatus := "yes"
|
||||
if err != nil {
|
||||
staticIPStatus = "error"
|
||||
staticIP["error"] = err.Error()
|
||||
} else if !isStaticIP {
|
||||
staticIPStatus = "no"
|
||||
staticIP["ip"] = getFullIP(interfaceName)
|
||||
}
|
||||
staticIP["static"] = staticIPStatus
|
||||
|
||||
result := map[string]interface{}{}
|
||||
result["other_server"] = othSrv
|
||||
result["static_ip"] = staticIP
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(result)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Failed to marshal DHCP found json: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if network interface has a static IP configured
|
||||
func hasStaticIP(ifaceName string) (bool, error) {
|
||||
if runtime.GOOS == "windows" {
|
||||
return false, errors.New("Can't detect static IP: not supported on Windows")
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
lines := strings.Split(string(body), "\n")
|
||||
nameLine := fmt.Sprintf("interface %s", ifaceName)
|
||||
withinInterfaceCtx := false
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if withinInterfaceCtx && len(line) == 0 {
|
||||
// an empty line resets our state
|
||||
withinInterfaceCtx = false
|
||||
}
|
||||
|
||||
if len(line) == 0 || line[0] == '#' {
|
||||
continue
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if !withinInterfaceCtx {
|
||||
if line == nameLine {
|
||||
// we found our interface
|
||||
withinInterfaceCtx = true
|
||||
}
|
||||
|
||||
} else {
|
||||
if strings.HasPrefix(line, "interface ") {
|
||||
// we found another interface - reset our state
|
||||
withinInterfaceCtx = false
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "static ip_address=") {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Get IP address with netmask
|
||||
func getFullIP(ifaceName string) string {
|
||||
cmd := exec.Command("ip", "-oneline", "-family", "inet", "address", "show", ifaceName)
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
d, err := cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
fields := strings.Fields(string(d))
|
||||
if len(fields) < 4 {
|
||||
return ""
|
||||
}
|
||||
_, _, err = net.ParseCIDR(fields[3])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fields[3]
|
||||
}
|
||||
|
||||
// Get gateway IP address
|
||||
func getGatewayIP(ifaceName string) string {
|
||||
cmd := exec.Command("ip", "route", "show", "dev", ifaceName)
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
d, err := cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
fields := strings.Fields(string(d))
|
||||
if len(fields) < 3 || fields[0] != "default" {
|
||||
return ""
|
||||
}
|
||||
|
||||
ip := net.ParseIP(fields[2])
|
||||
if ip == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fields[2]
|
||||
}
|
||||
|
||||
// Set a static IP for network interface
|
||||
func setStaticIP(ifaceName string) error {
|
||||
ip := getFullIP(ifaceName)
|
||||
if len(ip) == 0 {
|
||||
return errors.New("Can't get IP address")
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ip4, _, err := net.ParseCIDR(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n",
|
||||
ifaceName, ip)
|
||||
body = append(body, []byte(add)...)
|
||||
|
||||
gatewayIP := getGatewayIP(ifaceName)
|
||||
if len(gatewayIP) != 0 {
|
||||
add = fmt.Sprintf("static routers=%s\n",
|
||||
gatewayIP)
|
||||
body = append(body, []byte(add)...)
|
||||
}
|
||||
|
||||
add = fmt.Sprintf("static domain_name_servers=%s\n\n",
|
||||
ip4)
|
||||
body = append(body, []byte(add)...)
|
||||
|
||||
err = file.SafeWrite("/etc/dhcpcd.conf", body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
|
||||
lj := leaseJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&lj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := parseIPv4(lj.IP)
|
||||
if ip == nil {
|
||||
httpError(w, http.StatusBadRequest, "invalid IP")
|
||||
return
|
||||
}
|
||||
|
||||
mac, _ := net.ParseMAC(lj.HWAddr)
|
||||
|
||||
lease := dhcpd.Lease{
|
||||
IP: ip,
|
||||
HWAddr: mac,
|
||||
Hostname: lj.Hostname,
|
||||
}
|
||||
err = dhcpServer.AddStaticLease(lease)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
returnOK(w)
|
||||
}
|
||||
|
||||
func handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
log.Tracef("%s %v", r.Method, r.URL)
|
||||
|
||||
lj := leaseJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&lj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := parseIPv4(lj.IP)
|
||||
if ip == nil {
|
||||
httpError(w, http.StatusBadRequest, "invalid IP")
|
||||
return
|
||||
}
|
||||
|
||||
mac, _ := net.ParseMAC(lj.HWAddr)
|
||||
|
||||
lease := dhcpd.Lease{
|
||||
IP: ip,
|
||||
HWAddr: mac,
|
||||
Hostname: lj.Hostname,
|
||||
}
|
||||
err = dhcpServer.RemoveStaticLease(lease)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
returnOK(w)
|
||||
}
|
||||
|
||||
func startDHCPServer() error {
|
||||
if !config.DHCP.Enabled {
|
||||
// not enabled, don't do anything
|
||||
return nil
|
||||
}
|
||||
|
||||
err := dhcpServer.Init(config.DHCP)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't init DHCP server")
|
||||
}
|
||||
|
||||
err = dhcpServer.Start()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start DHCP server")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func stopDHCPServer() error {
|
||||
if !config.DHCP.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := dhcpServer.Stop()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't stop DHCP server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
274
home/dns.go
Normal file
274
home/dns.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var dnsServer *dnsforward.Server
|
||||
|
||||
const (
|
||||
rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
|
||||
)
|
||||
|
||||
type dnsContext struct {
|
||||
rdnsChannel chan string // pass data from DNS request handling thread to rDNS thread
|
||||
// contains IP addresses of clients to be resolved by rDNS
|
||||
// if IP address couldn't be resolved, it stays here forever to prevent further attempts to resolve the same IP
|
||||
rdnsIP map[string]bool
|
||||
rdnsLock sync.Mutex // synchronize access to rdnsIP
|
||||
upstream upstream.Upstream // Upstream object for our own DNS server
|
||||
}
|
||||
|
||||
var dnsctx dnsContext
|
||||
|
||||
// initDNSServer creates an instance of the dnsforward.Server
|
||||
// Please note that we must do it even if we don't start it
|
||||
// so that we had access to the query log and the stats
|
||||
func initDNSServer(baseDir string) {
|
||||
err := os.MkdirAll(baseDir, 0755)
|
||||
if err != nil {
|
||||
log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err)
|
||||
}
|
||||
|
||||
dnsServer = dnsforward.NewServer(baseDir)
|
||||
|
||||
bindhost := config.DNS.BindHost
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
bindhost = "127.0.0.1"
|
||||
}
|
||||
resolverAddress := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
|
||||
opts := upstream.Options{
|
||||
Timeout: rdnsTimeout,
|
||||
}
|
||||
dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
|
||||
if err != nil {
|
||||
log.Error("upstream.AddressToUpstream: %s", err)
|
||||
return
|
||||
}
|
||||
dnsctx.rdnsIP = make(map[string]bool)
|
||||
dnsctx.rdnsChannel = make(chan string, 256)
|
||||
go asyncRDNSLoop()
|
||||
}
|
||||
|
||||
func isRunning() bool {
|
||||
return dnsServer != nil && dnsServer.IsRunning()
|
||||
}
|
||||
|
||||
func beginAsyncRDNS(ip string) {
|
||||
if clientExists(ip) {
|
||||
return
|
||||
}
|
||||
|
||||
// add IP to rdnsIP, if not exists
|
||||
dnsctx.rdnsLock.Lock()
|
||||
defer dnsctx.rdnsLock.Unlock()
|
||||
_, ok := dnsctx.rdnsIP[ip]
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
dnsctx.rdnsIP[ip] = true
|
||||
|
||||
log.Tracef("Adding %s for rDNS resolve", ip)
|
||||
select {
|
||||
case dnsctx.rdnsChannel <- ip:
|
||||
//
|
||||
default:
|
||||
log.Tracef("rDNS queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// Use rDNS to get hostname by IP address
|
||||
func resolveRDNS(ip string) string {
|
||||
log.Tracef("Resolving host for %s", ip)
|
||||
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{
|
||||
Qtype: dns.TypePTR,
|
||||
Qclass: dns.ClassINET,
|
||||
},
|
||||
}
|
||||
var err error
|
||||
req.Question[0].Name, err = dns.ReverseAddr(ip)
|
||||
if err != nil {
|
||||
log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
resp, err := dnsctx.upstream.Exchange(&req)
|
||||
if err != nil {
|
||||
log.Error("Error while making an rDNS lookup for %s: %s", ip, err)
|
||||
return ""
|
||||
}
|
||||
if len(resp.Answer) != 1 {
|
||||
log.Debug("No answer for rDNS lookup of %s", ip)
|
||||
return ""
|
||||
}
|
||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||
if !ok {
|
||||
log.Error("not a PTR response for %s", ip)
|
||||
return ""
|
||||
}
|
||||
|
||||
log.Tracef("PTR response for %s: %s", ip, ptr.String())
|
||||
if strings.HasSuffix(ptr.Ptr, ".") {
|
||||
ptr.Ptr = ptr.Ptr[:len(ptr.Ptr)-1]
|
||||
}
|
||||
|
||||
return ptr.Ptr
|
||||
}
|
||||
|
||||
// Wait for a signal and then synchronously resolve hostname by IP address
|
||||
// Add the hostname:IP pair to "Clients" array
|
||||
func asyncRDNSLoop() {
|
||||
for {
|
||||
var ip string
|
||||
ip = <-dnsctx.rdnsChannel
|
||||
|
||||
host := resolveRDNS(ip)
|
||||
if len(host) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
dnsctx.rdnsLock.Lock()
|
||||
delete(dnsctx.rdnsIP, ip)
|
||||
dnsctx.rdnsLock.Unlock()
|
||||
|
||||
_, _ = clientAddHost(ip, host, ClientSourceRDNS)
|
||||
}
|
||||
}
|
||||
|
||||
func onDNSRequest(d *proxy.DNSContext) {
|
||||
qType := d.Req.Question[0].Qtype
|
||||
if qType != dns.TypeA && qType != dns.TypeAAAA {
|
||||
return
|
||||
}
|
||||
|
||||
ip := dnsforward.GetIPString(d.Addr)
|
||||
if ip == "" {
|
||||
// This would be quite weird if we get here
|
||||
return
|
||||
}
|
||||
|
||||
beginAsyncRDNS(ip)
|
||||
}
|
||||
|
||||
func generateServerConfig() dnsforward.ServerConfig {
|
||||
filters := []dnsfilter.Filter{}
|
||||
userFilter := userFilter()
|
||||
filters = append(filters, dnsfilter.Filter{
|
||||
ID: userFilter.ID,
|
||||
Data: userFilter.Data,
|
||||
})
|
||||
for _, filter := range config.Filters {
|
||||
filters = append(filters, dnsfilter.Filter{
|
||||
ID: filter.ID,
|
||||
Data: filter.Data,
|
||||
})
|
||||
}
|
||||
|
||||
newconfig := dnsforward.ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
|
||||
TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
|
||||
FilteringConfig: config.DNS.FilteringConfig,
|
||||
Filters: filters,
|
||||
}
|
||||
bindhost := config.DNS.BindHost
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
bindhost = "127.0.0.1"
|
||||
}
|
||||
newconfig.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
|
||||
|
||||
if config.TLS.Enabled {
|
||||
newconfig.TLSConfig = config.TLS.TLSConfig
|
||||
if config.TLS.PortDNSOverTLS != 0 {
|
||||
newconfig.TLSListenAddr = &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.TLS.PortDNSOverTLS}
|
||||
}
|
||||
}
|
||||
|
||||
upstreamConfig, err := proxy.ParseUpstreamsConfig(config.DNS.UpstreamDNS, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
|
||||
if err != nil {
|
||||
log.Error("Couldn't get upstreams configuration cause: %s", err)
|
||||
}
|
||||
newconfig.Upstreams = upstreamConfig.Upstreams
|
||||
newconfig.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams
|
||||
newconfig.AllServers = config.DNS.AllServers
|
||||
newconfig.FilterHandler = applyClientSettings
|
||||
newconfig.OnDNSRequest = onDNSRequest
|
||||
return newconfig
|
||||
}
|
||||
|
||||
// If a client has his own settings, apply them
|
||||
func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||
c, ok := clientFind(clientAddr)
|
||||
if !ok || !c.UseOwnSettings {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("Using settings for client with IP %s", clientAddr)
|
||||
setts.FilteringEnabled = c.FilteringEnabled
|
||||
setts.SafeSearchEnabled = c.SafeSearchEnabled
|
||||
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
|
||||
setts.ParentalEnabled = c.ParentalEnabled
|
||||
}
|
||||
|
||||
func startDNSServer() error {
|
||||
if isRunning() {
|
||||
return fmt.Errorf("unable to start forwarding DNS server: Already running")
|
||||
}
|
||||
|
||||
newconfig := generateServerConfig()
|
||||
err := dnsServer.Start(&newconfig)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||
}
|
||||
|
||||
top := dnsServer.GetStatsTop()
|
||||
for k := range top.Clients {
|
||||
beginAsyncRDNS(k)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func reconfigureDNSServer() error {
|
||||
if !isRunning() {
|
||||
return fmt.Errorf("Refusing to reconfigure forwarding DNS server: not running")
|
||||
}
|
||||
|
||||
config := generateServerConfig()
|
||||
err := dnsServer.Reconfigure(&config)
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stopDNSServer() error {
|
||||
if !isRunning() {
|
||||
return fmt.Errorf("Refusing to stop forwarding DNS server: not running")
|
||||
}
|
||||
|
||||
err := dnsServer.Stop()
|
||||
if err != nil {
|
||||
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
13
home/dns_test.go
Normal file
13
home/dns_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveRDNS(t *testing.T) {
|
||||
config.DNS.BindHost = "1.1.1.1"
|
||||
initDNSServer(".")
|
||||
if r := resolveRDNS("1.1.1.1"); r != "one.one.one.one" {
|
||||
t.Errorf("resolveRDNS(): %s", r)
|
||||
}
|
||||
}
|
||||
406
home/filter.go
Normal file
406
home/filter.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/golibs/file"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
var (
|
||||
nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
|
||||
filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
|
||||
)
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type filter struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
RulesCount int `json:"rulesCount" yaml:"-"`
|
||||
LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"-"`
|
||||
checksum uint32 // checksum of the file data
|
||||
|
||||
dnsfilter.Filter `yaml:",inline"`
|
||||
}
|
||||
|
||||
// Creates a helper object for working with the user rules
|
||||
func userFilter() filter {
|
||||
f := filter{
|
||||
// User filter always has constant ID=0
|
||||
Enabled: true,
|
||||
}
|
||||
f.Filter.Data = []byte(strings.Join(config.UserRules, "\n"))
|
||||
return f
|
||||
}
|
||||
|
||||
// Enable or disable a filter
|
||||
func filterEnable(url string, enable bool) bool {
|
||||
r := false
|
||||
config.Lock()
|
||||
for i := range config.Filters {
|
||||
filter := &config.Filters[i] // otherwise we will be operating on a copy
|
||||
if filter.URL == url {
|
||||
filter.Enabled = enable
|
||||
if enable {
|
||||
e := filter.load()
|
||||
if e != nil {
|
||||
// This isn't a fatal error,
|
||||
// because it may occur when someone removes the file from disk.
|
||||
// In this case the periodic update task will try to download the file.
|
||||
filter.LastUpdated = time.Time{}
|
||||
log.Tracef("%s filter load: %v", url, e)
|
||||
}
|
||||
} else {
|
||||
filter.unload()
|
||||
}
|
||||
r = true
|
||||
break
|
||||
}
|
||||
}
|
||||
config.Unlock()
|
||||
return r
|
||||
}
|
||||
|
||||
// Return TRUE if a filter with this URL exists
|
||||
func filterExists(url string) bool {
|
||||
r := false
|
||||
config.RLock()
|
||||
for i := range config.Filters {
|
||||
if config.Filters[i].URL == url {
|
||||
r = true
|
||||
break
|
||||
}
|
||||
}
|
||||
config.RUnlock()
|
||||
return r
|
||||
}
|
||||
|
||||
// Add a filter
|
||||
// Return FALSE if a filter with this URL exists
|
||||
func filterAdd(f filter) bool {
|
||||
config.Lock()
|
||||
|
||||
// Check for duplicates
|
||||
for i := range config.Filters {
|
||||
if config.Filters[i].URL == f.URL {
|
||||
config.Unlock()
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
config.Filters = append(config.Filters, f)
|
||||
config.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
// Load filters from the disk
|
||||
// And if any filter has zero ID, assign a new one
|
||||
func loadFilters() {
|
||||
for i := range config.Filters {
|
||||
filter := &config.Filters[i] // otherwise we're operating on a copy
|
||||
if filter.ID == 0 {
|
||||
filter.ID = assignUniqueFilterID()
|
||||
}
|
||||
|
||||
if !filter.Enabled {
|
||||
// No need to load a filter that is not enabled
|
||||
continue
|
||||
}
|
||||
|
||||
err := filter.load()
|
||||
if err != nil {
|
||||
// This is okay for the first start, the filter will be loaded later
|
||||
log.Debug("Couldn't load filter %d contents due to %s", filter.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func deduplicateFilters() {
|
||||
// Deduplicate filters
|
||||
i := 0 // output index, used for deletion later
|
||||
urls := map[string]bool{}
|
||||
for _, filter := range config.Filters {
|
||||
if _, ok := urls[filter.URL]; !ok {
|
||||
// we didn't see it before, keep it
|
||||
urls[filter.URL] = true // remember the URL
|
||||
config.Filters[i] = filter
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// all entries we want to keep are at front, delete the rest
|
||||
config.Filters = config.Filters[:i]
|
||||
}
|
||||
|
||||
// Set the next filter ID to max(filter.ID) + 1
|
||||
func updateUniqueFilterID(filters []filter) {
|
||||
for _, filter := range filters {
|
||||
if nextFilterID < filter.ID {
|
||||
nextFilterID = filter.ID + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assignUniqueFilterID() int64 {
|
||||
value := nextFilterID
|
||||
nextFilterID++
|
||||
return value
|
||||
}
|
||||
|
||||
// Sets up a timer that will be checking for filters updates periodically
|
||||
func periodicallyRefreshFilters() {
|
||||
for range time.Tick(time.Minute) {
|
||||
refreshFiltersIfNecessary(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Checks filters updates if necessary
|
||||
// If force is true, it ignores the filter.LastUpdated field value
|
||||
//
|
||||
// Algorithm:
|
||||
// . Get the list of filters to be updated
|
||||
// . For each filter run the download and checksum check operation
|
||||
// . If filter data hasn't changed, set new update time
|
||||
// . If filter data has changed, parse it, save it on disk, set new update time
|
||||
// . Apply changes to the current configuration
|
||||
// . Restart server
|
||||
func refreshFiltersIfNecessary(force bool) int {
|
||||
var updateFilters []filter
|
||||
|
||||
if config.firstRun {
|
||||
return 0
|
||||
}
|
||||
|
||||
config.RLock()
|
||||
for i := range config.Filters {
|
||||
f := &config.Filters[i] // otherwise we will be operating on a copy
|
||||
|
||||
if !f.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if !force && time.Since(f.LastUpdated) <= updatePeriod {
|
||||
continue
|
||||
}
|
||||
|
||||
var uf filter
|
||||
uf.ID = f.ID
|
||||
uf.URL = f.URL
|
||||
uf.Name = f.Name
|
||||
uf.checksum = f.checksum
|
||||
updateFilters = append(updateFilters, uf)
|
||||
}
|
||||
config.RUnlock()
|
||||
|
||||
updateCount := 0
|
||||
for i := range updateFilters {
|
||||
uf := &updateFilters[i]
|
||||
updated, err := uf.update()
|
||||
if err != nil {
|
||||
log.Printf("Failed to update filter %s: %s\n", uf.URL, err)
|
||||
continue
|
||||
}
|
||||
if updated {
|
||||
// Saving it to the filters dir now
|
||||
err = uf.save()
|
||||
if err != nil {
|
||||
log.Printf("Failed to save the updated filter %d: %s", uf.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
} else {
|
||||
mtime := time.Now()
|
||||
e := os.Chtimes(uf.Path(), mtime, mtime)
|
||||
if e != nil {
|
||||
log.Error("os.Chtimes(): %v", e)
|
||||
}
|
||||
uf.LastUpdated = mtime
|
||||
}
|
||||
|
||||
config.Lock()
|
||||
for k := range config.Filters {
|
||||
f := &config.Filters[k]
|
||||
if f.ID != uf.ID || f.URL != uf.URL {
|
||||
continue
|
||||
}
|
||||
f.LastUpdated = uf.LastUpdated
|
||||
if !updated {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info("Updated filter #%d. Rules: %d -> %d",
|
||||
f.ID, f.RulesCount, uf.RulesCount)
|
||||
f.Name = uf.Name
|
||||
f.Data = uf.Data
|
||||
f.RulesCount = uf.RulesCount
|
||||
f.checksum = uf.checksum
|
||||
updateCount++
|
||||
}
|
||||
config.Unlock()
|
||||
}
|
||||
|
||||
if updateCount > 0 && isRunning() {
|
||||
err := reconfigureDNSServer()
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("SHOULD NOT HAPPEN: cannot reconfigure DNS server with the new filters: %s", err)
|
||||
panic(msg)
|
||||
}
|
||||
}
|
||||
return updateCount
|
||||
}
|
||||
|
||||
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
|
||||
func parseFilterContents(contents []byte) (int, string) {
|
||||
lines := strings.Split(string(contents), "\n")
|
||||
rulesCount := 0
|
||||
name := ""
|
||||
seenTitle := false
|
||||
|
||||
// Count lines in the filter
|
||||
for _, line := range lines {
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if line[0] == '!' {
|
||||
m := filterTitleRegexp.FindAllStringSubmatch(line, -1)
|
||||
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
|
||||
name = m[0][1]
|
||||
seenTitle = true
|
||||
}
|
||||
} else {
|
||||
rulesCount++
|
||||
}
|
||||
}
|
||||
|
||||
return rulesCount, name
|
||||
}
|
||||
|
||||
// Perform upgrade on a filter
|
||||
func (filter *filter) update() (bool, error) {
|
||||
log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
|
||||
|
||||
resp, err := client.Get(filter.URL)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
|
||||
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := strings.ToLower(resp.Header.Get("content-type"))
|
||||
if !strings.HasPrefix(contentType, "text/plain") {
|
||||
log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL)
|
||||
return false, fmt.Errorf("non-text response %s", contentType)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Check if the filter has been really changed
|
||||
checksum := crc32.ChecksumIEEE(body)
|
||||
if filter.checksum == checksum {
|
||||
log.Tracef("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Extract filter name and count number of rules
|
||||
rulesCount, filterName := parseFilterContents(body)
|
||||
log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount)
|
||||
if filterName != "" {
|
||||
filter.Name = filterName
|
||||
}
|
||||
filter.RulesCount = rulesCount
|
||||
filter.Data = body
|
||||
filter.checksum = checksum
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// saves filter contents to the file in dataDir
|
||||
func (filter *filter) save() error {
|
||||
filterFilePath := filter.Path()
|
||||
log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
|
||||
err := file.SafeWrite(filterFilePath, filter.Data)
|
||||
|
||||
// update LastUpdated field after saving the file
|
||||
filter.LastUpdated = filter.LastTimeUpdated()
|
||||
return err
|
||||
}
|
||||
|
||||
// loads filter contents from the file in dataDir
|
||||
func (filter *filter) load() error {
|
||||
filterFilePath := filter.Path()
|
||||
log.Tracef("Loading filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
|
||||
if _, err := os.Stat(filterFilePath); os.IsNotExist(err) {
|
||||
// do nothing, file doesn't exist
|
||||
return err
|
||||
}
|
||||
|
||||
filterFileContents, err := ioutil.ReadFile(filterFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracef("File %s, id %d, length %d", filterFilePath, filter.ID, len(filterFileContents))
|
||||
rulesCount, _ := parseFilterContents(filterFileContents)
|
||||
|
||||
filter.RulesCount = rulesCount
|
||||
filter.Data = filterFileContents
|
||||
filter.checksum = crc32.ChecksumIEEE(filterFileContents)
|
||||
filter.LastUpdated = filter.LastTimeUpdated()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear filter rules
|
||||
func (filter *filter) unload() {
|
||||
filter.Data = nil
|
||||
filter.RulesCount = 0
|
||||
}
|
||||
|
||||
// Path to the filter contents
|
||||
func (filter *filter) Path() string {
|
||||
return filepath.Join(config.ourWorkingDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
|
||||
}
|
||||
|
||||
// LastTimeUpdated returns the time when the filter was last time updated
|
||||
func (filter *filter) LastTimeUpdated() time.Time {
|
||||
filterFilePath := filter.Path()
|
||||
s, err := os.Stat(filterFilePath)
|
||||
if os.IsNotExist(err) {
|
||||
// if the filter file does not exist, return 0001-01-01
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// if the filter file does not exist, return 0001-01-01
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// filter file modified time
|
||||
return s.ModTime()
|
||||
}
|
||||
405
home/helpers.go
Normal file
405
home/helpers.go
Normal file
@@ -0,0 +1,405 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
// ----------------------------------
|
||||
// helper functions for HTTP handlers
|
||||
// ----------------------------------
|
||||
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != method {
|
||||
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if method == "POST" || method == "PUT" || method == "DELETE" {
|
||||
controlLock.Lock()
|
||||
defer controlLock.Unlock()
|
||||
}
|
||||
|
||||
handler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return ensure("POST", handler)
|
||||
}
|
||||
|
||||
func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return ensure("GET", handler)
|
||||
}
|
||||
|
||||
// Bridge between http.Handler object and Go function
|
||||
type httpHandler struct {
|
||||
handler func(http.ResponseWriter, *http.Request)
|
||||
}
|
||||
|
||||
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.handler(w, r)
|
||||
}
|
||||
|
||||
func ensureGETHandler(handler func(http.ResponseWriter, *http.Request)) http.Handler {
|
||||
h := httpHandler{}
|
||||
h.handler = ensure("GET", handler)
|
||||
return &h
|
||||
}
|
||||
|
||||
func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if config.AuthName == "" || config.AuthPass == "" {
|
||||
handler(w, r)
|
||||
return
|
||||
}
|
||||
user, pass, ok := r.BasicAuth()
|
||||
if !ok || user != config.AuthName || pass != config.AuthPass {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="dnsfilter"`)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("Unauthorised.\n"))
|
||||
return
|
||||
}
|
||||
handler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
type authHandler struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
optionalAuth(a.handler.ServeHTTP)(w, r)
|
||||
}
|
||||
|
||||
func optionalAuthHandler(handler http.Handler) http.Handler {
|
||||
return &authHandler{handler}
|
||||
}
|
||||
|
||||
// -------------------
|
||||
// first run / install
|
||||
// -------------------
|
||||
func detectFirstRun() bool {
|
||||
configfile := config.ourConfigFilename
|
||||
if !filepath.IsAbs(configfile) {
|
||||
configfile = filepath.Join(config.ourWorkingDir, config.ourConfigFilename)
|
||||
}
|
||||
_, err := os.Stat(configfile)
|
||||
if !os.IsNotExist(err) {
|
||||
// do nothing, file exists
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// preInstall lets the handler run only if firstRun is true, no redirects
|
||||
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !config.firstRun {
|
||||
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
handler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary
|
||||
type preInstallHandlerStruct struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
preInstall(p.handler.ServeHTTP)(w, r)
|
||||
}
|
||||
|
||||
// preInstallHandler returns http.Handler interface for preInstall wrapper
|
||||
func preInstallHandler(handler http.Handler) http.Handler {
|
||||
return &preInstallHandlerStruct{handler}
|
||||
}
|
||||
|
||||
// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise
|
||||
// it also enforces HTTPS if it is enabled and configured
|
||||
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if config.firstRun &&
|
||||
!strings.HasPrefix(r.URL.Path, "/install.") &&
|
||||
r.URL.Path != "/favicon.png" {
|
||||
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable
|
||||
return
|
||||
}
|
||||
// enforce https?
|
||||
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && httpsServer.server != nil {
|
||||
// yes, and we want host from host:port
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
// no port in host
|
||||
host = r.Host
|
||||
}
|
||||
// construct new URL to redirect to
|
||||
newURL := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)),
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
}
|
||||
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
handler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
type postInstallHandlerStruct struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
postInstall(p.handler.ServeHTTP)(w, r)
|
||||
}
|
||||
|
||||
func postInstallHandler(handler http.Handler) http.Handler {
|
||||
return &postInstallHandlerStruct{handler}
|
||||
}
|
||||
|
||||
// -------------------------------------------------
|
||||
// helper functions for parsing parameters from body
|
||||
// -------------------------------------------------
|
||||
func parseParametersFromBody(r io.Reader) (map[string]string, error) {
|
||||
parameters := map[string]string{}
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if len(line) == 0 {
|
||||
// skip empty lines
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return parameters, errors.New("Got invalid request body")
|
||||
}
|
||||
parameters[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
return parameters, nil
|
||||
}
|
||||
|
||||
// ------------------
|
||||
// network interfaces
|
||||
// ------------------
|
||||
type netInterface struct {
|
||||
Name string `json:"name"`
|
||||
MTU int `json:"mtu"`
|
||||
HardwareAddr string `json:"hardware_address"`
|
||||
Addresses []string `json:"ip_addresses"`
|
||||
Flags string `json:"flags"`
|
||||
}
|
||||
|
||||
// getValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP
|
||||
// invalid interface is a ppp interface or the one that doesn't allow broadcasts
|
||||
func getValidNetInterfaces() ([]net.Interface, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err)
|
||||
}
|
||||
|
||||
netIfaces := []net.Interface{}
|
||||
|
||||
for i := range ifaces {
|
||||
if ifaces[i].Flags&net.FlagPointToPoint != 0 {
|
||||
// this interface is ppp, we're not interested in this one
|
||||
continue
|
||||
}
|
||||
|
||||
iface := ifaces[i]
|
||||
netIfaces = append(netIfaces, iface)
|
||||
}
|
||||
|
||||
return netIfaces, nil
|
||||
}
|
||||
|
||||
// getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only
|
||||
// we do not return link-local addresses here
|
||||
func getValidNetInterfacesForWeb() ([]netInterface, error) {
|
||||
ifaces, err := getValidNetInterfaces()
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't get interfaces")
|
||||
}
|
||||
if len(ifaces) == 0 {
|
||||
return nil, errors.New("couldn't find any legible interface")
|
||||
}
|
||||
|
||||
var netInterfaces []netInterface
|
||||
|
||||
for _, iface := range ifaces {
|
||||
addrs, e := iface.Addrs()
|
||||
if e != nil {
|
||||
return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name)
|
||||
}
|
||||
|
||||
netIface := netInterface{
|
||||
Name: iface.Name,
|
||||
MTU: iface.MTU,
|
||||
HardwareAddr: iface.HardwareAddr.String(),
|
||||
}
|
||||
|
||||
if iface.Flags != 0 {
|
||||
netIface.Flags = iface.Flags.String()
|
||||
}
|
||||
|
||||
// we don't want link-local addresses in json, so skip them
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok {
|
||||
// not an IPNet, should not happen
|
||||
return nil, fmt.Errorf("SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
|
||||
}
|
||||
// ignore link-local
|
||||
if ipnet.IP.IsLinkLocalUnicast() {
|
||||
continue
|
||||
}
|
||||
netIface.Addresses = append(netIface.Addresses, ipnet.IP.String())
|
||||
}
|
||||
if len(netIface.Addresses) != 0 {
|
||||
netInterfaces = append(netInterfaces, netIface)
|
||||
}
|
||||
}
|
||||
|
||||
return netInterfaces, nil
|
||||
}
|
||||
|
||||
// checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily
|
||||
func checkPortAvailable(host string, port int) error {
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ln.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkPacketPortAvailable(host string, port int) error {
|
||||
ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ln.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Connect to a remote server resolving hostname using our own DNS server
|
||||
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log.Tracef("network:%v addr:%v", network, addr)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Minute * 5,
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
return con, err
|
||||
}
|
||||
|
||||
bindhost := config.DNS.BindHost
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
bindhost = "127.0.0.1"
|
||||
}
|
||||
resolverAddr := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
|
||||
r := upstream.NewResolver(resolverAddr, 30*time.Second)
|
||||
addrs, e := r.LookupIPAddr(ctx, host)
|
||||
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
firstErr = nil
|
||||
for _, a := range addrs {
|
||||
addr = net.JoinHostPort(a.String(), port)
|
||||
con, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return con, err
|
||||
}
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
// check if error is "address already in use"
|
||||
func errorIsAddrInUse(err error) bool {
|
||||
errOpError, ok := err.(*net.OpError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
errSyscallError, ok := errOpError.Err.(*os.SyscallError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
errErrno, ok := errSyscallError.Err.(syscall.Errno)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
const WSAEADDRINUSE = 10048
|
||||
return errErrno == WSAEADDRINUSE
|
||||
}
|
||||
|
||||
return errErrno == syscall.EADDRINUSE
|
||||
}
|
||||
|
||||
// ---------------------
|
||||
// debug logging helpers
|
||||
// ---------------------
|
||||
func _Func() string {
|
||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||
runtime.Callers(2, pc)
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
return path.Base(f.Name())
|
||||
}
|
||||
|
||||
// Parse input string and return IPv4 address
|
||||
func parseIPv4(s string) net.IP {
|
||||
ip := net.ParseIP(s)
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
v4InV6Prefix := []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}
|
||||
if !bytes.Equal(ip[0:12], v4InV6Prefix) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ip.To4()
|
||||
}
|
||||
25
home/helpers_test.go
Normal file
25
home/helpers_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
func TestGetValidNetInterfacesForWeb(t *testing.T) {
|
||||
ifaces, err := getValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot get net interfaces: %s", err)
|
||||
}
|
||||
if len(ifaces) == 0 {
|
||||
t.Fatalf("No net interfaces found")
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
if len(iface.Addresses) == 0 {
|
||||
t.Fatalf("No addresses found for %s", iface.Name)
|
||||
}
|
||||
|
||||
log.Printf("%v", iface)
|
||||
}
|
||||
}
|
||||
542
home/home.go
Normal file
542
home/home.go
Normal file
@@ -0,0 +1,542 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/gobuffalo/packr"
|
||||
)
|
||||
|
||||
// VersionString will be set through ldflags, contains current version
|
||||
var VersionString = "undefined"
|
||||
|
||||
// updateChannel can be set via ldflags
|
||||
var updateChannel = "release"
|
||||
var versionCheckURL = "https://static.adguard.com/adguardhome/" + updateChannel + "/version.json"
|
||||
|
||||
const versionCheckPeriod = time.Hour * 8
|
||||
|
||||
var httpServer *http.Server
|
||||
var httpsServer struct {
|
||||
server *http.Server
|
||||
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
|
||||
sync.Mutex // protects config.TLS
|
||||
shutdown bool // if TRUE, don't restart the server
|
||||
}
|
||||
var pidFileName string // PID file name. Empty if no PID file was created.
|
||||
|
||||
const (
|
||||
// Used in config to indicate that syslog or eventlog (win) should be used for logger output
|
||||
configSyslog = "syslog"
|
||||
)
|
||||
|
||||
// main is the entry point
|
||||
func Main() {
|
||||
// config can be specified, which reads options from there, but other command line flags have to override config values
|
||||
// therefore, we must do it manually instead of using a lib
|
||||
args := loadOptions()
|
||||
|
||||
if args.serviceControlAction != "" {
|
||||
handleServiceControlAction(args.serviceControlAction)
|
||||
return
|
||||
}
|
||||
|
||||
// run the protection
|
||||
run(args)
|
||||
}
|
||||
|
||||
// run initializes configuration and runs the AdGuard Home
|
||||
// run is a blocking method and it won't exit until the service is stopped!
|
||||
func run(args options) {
|
||||
// config file path can be overridden by command-line arguments:
|
||||
if args.configFilename != "" {
|
||||
config.ourConfigFilename = args.configFilename
|
||||
}
|
||||
|
||||
// configure working dir and config path
|
||||
initWorkingDir(args)
|
||||
|
||||
// configure log level and output
|
||||
configureLogger(args)
|
||||
|
||||
// enable TLS 1.3
|
||||
enableTLS13()
|
||||
|
||||
// print the first message after logger is configured
|
||||
log.Printf("AdGuard Home, version %s, channel %s\n", VersionString, updateChannel)
|
||||
log.Debug("Current working directory is %s", config.ourWorkingDir)
|
||||
if args.runningAsService {
|
||||
log.Info("AdGuard Home is running as a service")
|
||||
}
|
||||
config.runningAsService = args.runningAsService
|
||||
config.disableUpdate = args.disableUpdate
|
||||
|
||||
config.firstRun = detectFirstRun()
|
||||
if config.firstRun {
|
||||
requireAdminRights()
|
||||
}
|
||||
|
||||
signalChannel := make(chan os.Signal)
|
||||
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
|
||||
go func() {
|
||||
<-signalChannel
|
||||
cleanup()
|
||||
cleanupAlways()
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
clientsInit()
|
||||
|
||||
if !config.firstRun {
|
||||
// Do the upgrade if necessary
|
||||
err := upgradeConfig()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = parseConfig()
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if args.checkConfig {
|
||||
log.Info("Configuration file is OK")
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
|
||||
config.RlimitNoFile != 0 {
|
||||
setRlimit(config.RlimitNoFile)
|
||||
}
|
||||
|
||||
// override bind host/port from the console
|
||||
if args.bindHost != "" {
|
||||
config.BindHost = args.bindHost
|
||||
}
|
||||
if args.bindPort != 0 {
|
||||
config.BindPort = args.bindPort
|
||||
}
|
||||
|
||||
loadFilters()
|
||||
|
||||
if !config.firstRun {
|
||||
// Save the updated config
|
||||
err := config.write()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Init the DNS server instance before registering HTTP handlers
|
||||
dnsBaseDir := filepath.Join(config.ourWorkingDir, dataDir)
|
||||
initDNSServer(dnsBaseDir)
|
||||
|
||||
if !config.firstRun {
|
||||
err := startDNSServer()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = startDHCPServer()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
|
||||
pidFileName = args.pidFile
|
||||
}
|
||||
|
||||
// Update filters we've just loaded right away, don't wait for periodic update timer
|
||||
go func() {
|
||||
refreshFiltersIfNecessary(false)
|
||||
}()
|
||||
// Schedule automatic filters updates
|
||||
go periodicallyRefreshFilters()
|
||||
|
||||
// Initialize and run the admin Web interface
|
||||
box := packr.NewBox("../build/static")
|
||||
|
||||
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
|
||||
http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
|
||||
registerControlHandlers()
|
||||
|
||||
// add handlers for /install paths, we only need them when we're not configured yet
|
||||
if config.firstRun {
|
||||
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
|
||||
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
|
||||
registerInstallHandlers()
|
||||
}
|
||||
|
||||
httpsServer.cond = sync.NewCond(&httpsServer.Mutex)
|
||||
|
||||
// for https, we have a separate goroutine loop
|
||||
go httpServerLoop()
|
||||
|
||||
// this loop is used as an ability to change listening host and/or port
|
||||
for !httpsServer.shutdown {
|
||||
printHTTPAddresses("http")
|
||||
|
||||
// we need to have new instance, because after Shutdown() the Server is not usable
|
||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
||||
httpServer = &http.Server{
|
||||
Addr: address,
|
||||
}
|
||||
err := httpServer.ListenAndServe()
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
}
|
||||
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
|
||||
}
|
||||
|
||||
// wait indefinitely for other go-routines to complete their job
|
||||
select {}
|
||||
}
|
||||
|
||||
func httpServerLoop() {
|
||||
for !httpsServer.shutdown {
|
||||
httpsServer.cond.L.Lock()
|
||||
// this mechanism doesn't let us through until all conditions are met
|
||||
for config.TLS.Enabled == false ||
|
||||
config.TLS.PortHTTPS == 0 ||
|
||||
config.TLS.PrivateKey == "" ||
|
||||
config.TLS.CertificateChain == "" { // sleep until necessary data is supplied
|
||||
httpsServer.cond.Wait()
|
||||
}
|
||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
|
||||
// validate current TLS config and update warnings (it could have been loaded from file)
|
||||
data := validateCertificates(config.TLS.CertificateChain, config.TLS.PrivateKey, config.TLS.ServerName)
|
||||
if !data.ValidPair {
|
||||
cleanupAlways()
|
||||
log.Fatal(data.WarningValidation)
|
||||
}
|
||||
config.Lock()
|
||||
config.TLS.tlsConfigStatus = data // update warnings
|
||||
config.Unlock()
|
||||
|
||||
// prepare certs for HTTPS server
|
||||
// important -- they have to be copies, otherwise changing the contents in config.TLS will break encryption for in-flight requests
|
||||
certchain := make([]byte, len(config.TLS.CertificateChain))
|
||||
copy(certchain, []byte(config.TLS.CertificateChain))
|
||||
privatekey := make([]byte, len(config.TLS.PrivateKey))
|
||||
copy(privatekey, []byte(config.TLS.PrivateKey))
|
||||
cert, err := tls.X509KeyPair(certchain, privatekey)
|
||||
if err != nil {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
}
|
||||
httpsServer.cond.L.Unlock()
|
||||
|
||||
// prepare HTTPS server
|
||||
httpsServer.server = &http.Server{
|
||||
Addr: address,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
|
||||
printHTTPAddresses("https")
|
||||
err = httpsServer.server.ListenAndServeTLS("", "")
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the current user has root (administrator) rights
|
||||
// and if not, ask and try to run as root
|
||||
func requireAdminRights() {
|
||||
admin, _ := haveAdminRights()
|
||||
if admin {
|
||||
return
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
|
||||
|
||||
} else {
|
||||
log.Error("This is the first launch of AdGuard Home. You must run it as root.")
|
||||
|
||||
_, _ = io.WriteString(os.Stdout, "Do you want to start AdGuard Home as root user? [y/n] ")
|
||||
stdin := bufio.NewReader(os.Stdin)
|
||||
buf, _ := stdin.ReadString('\n')
|
||||
buf = strings.TrimSpace(buf)
|
||||
if buf != "y" {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cmd := exec.Command("sudo", os.Args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
_ = cmd.Run()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Write PID to a file
|
||||
func writePIDFile(fn string) bool {
|
||||
data := fmt.Sprintf("%d", os.Getpid())
|
||||
err := ioutil.WriteFile(fn, []byte(data), 0644)
|
||||
if err != nil {
|
||||
log.Error("Couldn't write PID to file %s: %v", fn, err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// initWorkingDir initializes the ourWorkingDir
|
||||
// if no command-line arguments specified, we use the directory where our binary file is located
|
||||
func initWorkingDir(args options) {
|
||||
exec, err := os.Executable()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if args.workDir != "" {
|
||||
// If there is a custom config file, use it's directory as our working dir
|
||||
config.ourWorkingDir = args.workDir
|
||||
} else {
|
||||
config.ourWorkingDir = filepath.Dir(exec)
|
||||
}
|
||||
}
|
||||
|
||||
// configureLogger configures logger level and output
|
||||
func configureLogger(args options) {
|
||||
ls := getLogSettings()
|
||||
|
||||
// command-line arguments can override config settings
|
||||
if args.verbose {
|
||||
ls.Verbose = true
|
||||
}
|
||||
if args.logFile != "" {
|
||||
ls.LogFile = args.logFile
|
||||
}
|
||||
|
||||
level := log.INFO
|
||||
if ls.Verbose {
|
||||
level = log.DEBUG
|
||||
}
|
||||
log.SetLevel(level)
|
||||
|
||||
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if nothing else is configured
|
||||
// Otherwise, we'll simply loose the log output
|
||||
ls.LogFile = configSyslog
|
||||
}
|
||||
|
||||
if ls.LogFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if ls.LogFile == configSyslog {
|
||||
// Use syslog where it is possible and eventlog on Windows
|
||||
err := configureSyslog()
|
||||
if err != nil {
|
||||
log.Fatalf("cannot initialize syslog: %s", err)
|
||||
}
|
||||
} else {
|
||||
logFilePath := filepath.Join(config.ourWorkingDir, ls.LogFile)
|
||||
if filepath.IsAbs(ls.LogFile) {
|
||||
logFilePath = ls.LogFile
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot create a log file: %s", err)
|
||||
}
|
||||
log.SetOutput(file)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO after GO 1.13 release TLS 1.3 will be enabled by default. Remove this afterward
|
||||
func enableTLS13() {
|
||||
err := os.Setenv("GODEBUG", os.Getenv("GODEBUG")+",tls13=1")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to enable TLS 1.3: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func cleanup() {
|
||||
log.Info("Stopping AdGuard Home")
|
||||
|
||||
err := stopDNSServer()
|
||||
if err != nil {
|
||||
log.Error("Couldn't stop DNS server: %s", err)
|
||||
}
|
||||
err = stopDHCPServer()
|
||||
if err != nil {
|
||||
log.Error("Couldn't stop DHCP server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop HTTP server, possibly waiting for all active connections to be closed
|
||||
func stopHTTPServer() {
|
||||
httpsServer.shutdown = true
|
||||
if httpsServer.server != nil {
|
||||
httpsServer.server.Shutdown(context.TODO())
|
||||
}
|
||||
httpServer.Shutdown(context.TODO())
|
||||
}
|
||||
|
||||
// This function is called before application exits
|
||||
func cleanupAlways() {
|
||||
if len(pidFileName) != 0 {
|
||||
os.Remove(pidFileName)
|
||||
}
|
||||
log.Info("Stopped")
|
||||
}
|
||||
|
||||
// command-line arguments
|
||||
type options struct {
|
||||
verbose bool // is verbose logging enabled
|
||||
configFilename string // path to the config file
|
||||
workDir string // path to the working directory where we will store the filters data and the querylog
|
||||
bindHost string // host address to bind HTTP server on
|
||||
bindPort int // port to serve HTTP pages on
|
||||
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
|
||||
pidFile string // File name to save PID to
|
||||
checkConfig bool // Check configuration and exit
|
||||
disableUpdate bool // If set, don't check for updates
|
||||
|
||||
// service control action (see service.ControlAction array + "status" command)
|
||||
serviceControlAction string
|
||||
|
||||
// runningAsService flag is set to true when options are passed from the service runner
|
||||
runningAsService bool
|
||||
}
|
||||
|
||||
// loadOptions reads command line arguments and initializes configuration
|
||||
func loadOptions() options {
|
||||
o := options{}
|
||||
|
||||
var printHelp func()
|
||||
var opts = []struct {
|
||||
longName string
|
||||
shortName string
|
||||
description string
|
||||
callbackWithValue func(value string)
|
||||
callbackNoValue func()
|
||||
}{
|
||||
{"config", "c", "Path to the config file", func(value string) { o.configFilename = value }, nil},
|
||||
{"work-dir", "w", "Path to the working directory", func(value string) { o.workDir = value }, nil},
|
||||
{"host", "h", "Host address to bind HTTP server on", func(value string) { o.bindHost = value }, nil},
|
||||
{"port", "p", "Port to serve HTTP pages on", func(value string) {
|
||||
v, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
panic("Got port that is not a number")
|
||||
}
|
||||
o.bindPort = v
|
||||
}, nil},
|
||||
{"service", "s", "Service control action: status, install, uninstall, start, stop, restart", func(value string) {
|
||||
o.serviceControlAction = value
|
||||
}, nil},
|
||||
{"logfile", "l", "Path to log file. If empty: write to stdout; if 'syslog': write to system log", func(value string) {
|
||||
o.logFile = value
|
||||
}, nil},
|
||||
{"pidfile", "", "Path to a file where PID is stored", func(value string) { o.pidFile = value }, nil},
|
||||
{"check-config", "", "Check configuration and exit", nil, func() { o.checkConfig = true }},
|
||||
{"no-check-update", "", "Don't check for updates", nil, func() { o.disableUpdate = true }},
|
||||
{"verbose", "v", "Enable verbose output", nil, func() { o.verbose = true }},
|
||||
{"help", "", "Print this help", nil, func() {
|
||||
printHelp()
|
||||
os.Exit(64)
|
||||
}},
|
||||
}
|
||||
printHelp = func() {
|
||||
fmt.Printf("Usage:\n\n")
|
||||
fmt.Printf("%s [options]\n\n", os.Args[0])
|
||||
fmt.Printf("Options:\n")
|
||||
for _, opt := range opts {
|
||||
val := ""
|
||||
if opt.callbackWithValue != nil {
|
||||
val = " VALUE"
|
||||
}
|
||||
if opt.shortName != "" {
|
||||
fmt.Printf(" -%s, %-30s %s\n", opt.shortName, "--"+opt.longName+val, opt.description)
|
||||
} else {
|
||||
fmt.Printf(" %-34s %s\n", "--"+opt.longName+val, opt.description)
|
||||
}
|
||||
}
|
||||
}
|
||||
for i := 1; i < len(os.Args); i++ {
|
||||
v := os.Args[i]
|
||||
knownParam := false
|
||||
for _, opt := range opts {
|
||||
if v == "--"+opt.longName || (opt.shortName != "" && v == "-"+opt.shortName) {
|
||||
if opt.callbackWithValue != nil {
|
||||
if i+1 >= len(os.Args) {
|
||||
log.Error("Got %s without argument\n", v)
|
||||
os.Exit(64)
|
||||
}
|
||||
i++
|
||||
opt.callbackWithValue(os.Args[i])
|
||||
} else if opt.callbackNoValue != nil {
|
||||
opt.callbackNoValue()
|
||||
}
|
||||
knownParam = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !knownParam {
|
||||
log.Error("unknown option %v\n", v)
|
||||
printHelp()
|
||||
os.Exit(64)
|
||||
}
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
||||
|
||||
// prints IP addresses which user can use to open the admin interface
|
||||
// proto is either "http" or "https"
|
||||
func printHTTPAddresses(proto string) {
|
||||
var address string
|
||||
|
||||
if proto == "https" && config.TLS.ServerName != "" {
|
||||
if config.TLS.PortHTTPS == 443 {
|
||||
log.Printf("Go to https://%s", config.TLS.ServerName)
|
||||
} else {
|
||||
log.Printf("Go to https://%s:%d", config.TLS.ServerName, config.TLS.PortHTTPS)
|
||||
}
|
||||
} else if config.BindHost == "0.0.0.0" {
|
||||
log.Println("AdGuard Home is available on the following addresses:")
|
||||
ifaces, err := getValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
// That's weird, but we'll ignore it
|
||||
address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
return
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
address = net.JoinHostPort(iface.Addresses[0], strconv.Itoa(config.BindPort))
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
}
|
||||
} else {
|
||||
address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
}
|
||||
}
|
||||
72
home/i18n.go
Normal file
72
home/i18n.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// --------------------
|
||||
// internationalization
|
||||
// --------------------
|
||||
var allowedLanguages = map[string]bool{
|
||||
"en": true,
|
||||
"ru": true,
|
||||
"vi": true,
|
||||
"es": true,
|
||||
"fr": true,
|
||||
"ja": true,
|
||||
"sv": true,
|
||||
"pt-br": true,
|
||||
"zh-tw": true,
|
||||
"bg": true,
|
||||
"zh-cn": true,
|
||||
}
|
||||
|
||||
func isLanguageAllowed(language string) bool {
|
||||
l := strings.ToLower(language)
|
||||
return allowedLanguages[l]
|
||||
}
|
||||
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
log.Printf("config.Language is %s", config.Language)
|
||||
_, err := fmt.Fprintf(w, "%s\n", config.Language)
|
||||
if err != nil {
|
||||
errorText := fmt.Sprintf("Unable to write response json: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errorText := fmt.Sprintf("failed to read request body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
language := strings.TrimSpace(string(body))
|
||||
if language == "" {
|
||||
errorText := fmt.Sprintf("empty language specified")
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !isLanguageAllowed(language) {
|
||||
errorText := fmt.Sprintf("unknown language specified: %s", language)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
config.Language = language
|
||||
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
27
home/os_unix.go
Normal file
27
home/os_unix.go
Normal file
@@ -0,0 +1,27 @@
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package home
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Set user-specified limit of how many fd's we can use
|
||||
// https://github.com/AdguardTeam/AdGuardHome/issues/659
|
||||
func setRlimit(val uint) {
|
||||
var rlim syscall.Rlimit
|
||||
rlim.Max = uint64(val)
|
||||
rlim.Cur = uint64(val)
|
||||
err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlim)
|
||||
if err != nil {
|
||||
log.Error("Setrlimit() failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the current user has root (administrator) rights
|
||||
func haveAdminRights() (bool, error) {
|
||||
return os.Getuid() == 0, nil
|
||||
}
|
||||
28
home/os_windows.go
Normal file
28
home/os_windows.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package home
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
// Set user-specified limit of how many fd's we can use
|
||||
func setRlimit(val uint) {
|
||||
}
|
||||
|
||||
func haveAdminRights() (bool, error) {
|
||||
var token windows.Token
|
||||
h, _ := windows.GetCurrentProcess()
|
||||
err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
info := make([]byte, 4)
|
||||
var returnedLen uint32
|
||||
err = windows.GetTokenInformation(token, windows.TokenElevation, &info[0], uint32(len(info)), &returnedLen)
|
||||
token.Close()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if info[0] == 0 {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
185
home/service.go
Normal file
185
home/service.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
const (
|
||||
launchdStdoutPath = "/var/log/AdGuardHome.stdout.log"
|
||||
launchdStderrPath = "/var/log/AdGuardHome.stderr.log"
|
||||
serviceName = "AdGuardHome"
|
||||
serviceDisplayName = "AdGuard Home service"
|
||||
serviceDescription = "AdGuard Home: Network-level blocker"
|
||||
)
|
||||
|
||||
// Represents the program that will be launched by a service or daemon
|
||||
type program struct {
|
||||
}
|
||||
|
||||
// Start should quickly start the program
|
||||
func (p *program) Start(s service.Service) error {
|
||||
// Start should not block. Do the actual work async.
|
||||
args := options{runningAsService: true}
|
||||
go run(args)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the program
|
||||
func (p *program) Stop(s service.Service) error {
|
||||
// Stop should not block. Return with a few seconds.
|
||||
cleanup()
|
||||
cleanupAlways()
|
||||
os.Exit(0)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleServiceControlAction one of the possible control actions:
|
||||
// install -- installs a service/daemon
|
||||
// uninstall -- uninstalls it
|
||||
// status -- prints the service status
|
||||
// start -- starts the previously installed service
|
||||
// stop -- stops the previously installed service
|
||||
// restart - restarts the previously installed service
|
||||
// run - this is a special command that is not supposed to be used directly
|
||||
// it is specified when we register a service, and it indicates to the app
|
||||
// that it is being run as a service/daemon.
|
||||
func handleServiceControlAction(action string) {
|
||||
log.Printf("Service control action: %s", action)
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatal("Unable to find the path to the current directory")
|
||||
}
|
||||
svcConfig := &service.Config{
|
||||
Name: serviceName,
|
||||
DisplayName: serviceDisplayName,
|
||||
Description: serviceDescription,
|
||||
WorkingDirectory: pwd,
|
||||
Arguments: []string{"-s", "run"},
|
||||
}
|
||||
configureService(svcConfig)
|
||||
prg := &program{}
|
||||
s, err := service.New(prg, svcConfig)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if action == "status" {
|
||||
status, errSt := s.Status()
|
||||
if errSt != nil {
|
||||
log.Fatalf("failed to get service status: %s", errSt)
|
||||
}
|
||||
|
||||
switch status {
|
||||
case service.StatusUnknown:
|
||||
log.Printf("Service status is unknown")
|
||||
case service.StatusStopped:
|
||||
log.Printf("Service is stopped")
|
||||
case service.StatusRunning:
|
||||
log.Printf("Service is running")
|
||||
}
|
||||
} else if action == "run" {
|
||||
err = s.Run()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to run service: %s", err)
|
||||
}
|
||||
} else {
|
||||
if action == "uninstall" {
|
||||
// In case of Windows and Linux when a running service is being uninstalled,
|
||||
// it is just marked for deletion but not stopped
|
||||
// So we explicitly stop it here
|
||||
_ = s.Stop()
|
||||
}
|
||||
|
||||
err = service.Control(s, action)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Printf("Action %s has been done successfully on %s", action, service.ChosenSystem().String())
|
||||
|
||||
if action == "install" {
|
||||
// Start automatically after install
|
||||
err = service.Control(s, "start")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to start the service: %s", err)
|
||||
}
|
||||
log.Printf("Service has been started")
|
||||
|
||||
if detectFirstRun() {
|
||||
log.Printf(`Almost ready!
|
||||
AdGuard Home is successfully installed and will automatically start on boot.
|
||||
There are a few more things that must be configured before you can use it.
|
||||
Click on the link below and follow the Installation Wizard steps to finish setup.`)
|
||||
printHTTPAddresses("http")
|
||||
}
|
||||
|
||||
} else if action == "uninstall" {
|
||||
cleanupService()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// configureService defines additional settings of the service
|
||||
func configureService(c *service.Config) {
|
||||
c.Option = service.KeyValue{}
|
||||
|
||||
// OS X
|
||||
// Redefines the launchd config file template
|
||||
// The purpose is to enable stdout/stderr redirect by default
|
||||
c.Option["LaunchdConfig"] = launchdConfig
|
||||
// This key is used to start the job as soon as it has been loaded. For daemons this means execution at boot time, for agents execution at login.
|
||||
c.Option["RunAtLoad"] = true
|
||||
|
||||
// POSIX
|
||||
// Redirect StdErr & StdOut to files.
|
||||
c.Option["LogOutput"] = true
|
||||
}
|
||||
|
||||
// cleanupService called on the service uninstall, cleans up additional files if needed
|
||||
func cleanupService() {
|
||||
if runtime.GOOS == "darwin" {
|
||||
// Removing log files on cleanup and ignore errors
|
||||
err := os.Remove(launchdStdoutPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
log.Printf("cannot remove %s", launchdStdoutPath)
|
||||
}
|
||||
err = os.Remove(launchdStderrPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
log.Printf("cannot remove %s", launchdStderrPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Basically the same template as the one defined in github.com/kardianos/service
|
||||
// but with two additional keys - StandardOutPath and StandardErrorPath
|
||||
var launchdConfig = `<?xml version='1.0' encoding='UTF-8'?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple Computer//DTD PLIST 1.0//EN"
|
||||
"http://www.apple.com/DTDs/PropertyList-1.0.dtd" >
|
||||
<plist version='1.0'>
|
||||
<dict>
|
||||
<key>Label</key><string>{{html .Name}}</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>{{html .Path}}</string>
|
||||
{{range .Config.Arguments}}
|
||||
<string>{{html .}}</string>
|
||||
{{end}}
|
||||
</array>
|
||||
{{if .UserName}}<key>UserName</key><string>{{html .UserName}}</string>{{end}}
|
||||
{{if .ChRoot}}<key>RootDirectory</key><string>{{html .ChRoot}}</string>{{end}}
|
||||
{{if .WorkingDirectory}}<key>WorkingDirectory</key><string>{{html .WorkingDirectory}}</string>{{end}}
|
||||
<key>SessionCreate</key><{{bool .SessionCreate}}/>
|
||||
<key>KeepAlive</key><{{bool .KeepAlive}}/>
|
||||
<key>RunAtLoad</key><{{bool .RunAtLoad}}/>
|
||||
<key>Disabled</key><false/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>` + launchdStdoutPath + `</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>` + launchdStderrPath + `</string>
|
||||
</dict>
|
||||
</plist>
|
||||
`
|
||||
18
home/syslog_others.go
Normal file
18
home/syslog_others.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// +build !windows,!nacl,!plan9
|
||||
|
||||
package home
|
||||
|
||||
import (
|
||||
"log"
|
||||
"log/syslog"
|
||||
)
|
||||
|
||||
// configureSyslog reroutes standard logger output to syslog
|
||||
func configureSyslog() error {
|
||||
w, err := syslog.New(syslog.LOG_NOTICE|syslog.LOG_USER, serviceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.SetOutput(w)
|
||||
return nil
|
||||
}
|
||||
39
home/syslog_windows.go
Normal file
39
home/syslog_windows.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc/eventlog"
|
||||
)
|
||||
|
||||
type eventLogWriter struct {
|
||||
el *eventlog.Log
|
||||
}
|
||||
|
||||
// Write sends a log message to the Event Log.
|
||||
func (w *eventLogWriter) Write(b []byte) (int, error) {
|
||||
return len(b), w.el.Info(1, string(b))
|
||||
}
|
||||
|
||||
func configureSyslog() error {
|
||||
// Note that the eventlog src is the same as the service name
|
||||
// Otherwise, we will get "the description for event id cannot be found" warning in every log record
|
||||
|
||||
// Continue if we receive "registry key already exists" or if we get
|
||||
// ERROR_ACCESS_DENIED so that we can log without administrative permissions
|
||||
// for pre-existing eventlog sources.
|
||||
if err := eventlog.InstallAsEventCreate(serviceName, eventlog.Info|eventlog.Warning|eventlog.Error); err != nil {
|
||||
if !strings.Contains(err.Error(), "registry key already exists") && err != windows.ERROR_ACCESS_DENIED {
|
||||
return err
|
||||
}
|
||||
}
|
||||
el, err := eventlog.Open(serviceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.SetOutput(&eventLogWriter{el: el})
|
||||
return nil
|
||||
}
|
||||
194
home/upgrade.go
Normal file
194
home/upgrade.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/golibs/file"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const currentSchemaVersion = 3 // used for upgrading from old configs to new config
|
||||
|
||||
// Performs necessary upgrade operations if needed
|
||||
func upgradeConfig() error {
|
||||
// read a config file into an interface map, so we can manipulate values without losing any
|
||||
diskConfig := map[string]interface{}{}
|
||||
body, err := readConfigFile()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(body, &diskConfig)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't parse config file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
schemaVersionInterface, ok := diskConfig["schema_version"]
|
||||
log.Tracef("got schema version %v", schemaVersionInterface)
|
||||
if !ok {
|
||||
// no schema version, set it to 0
|
||||
schemaVersionInterface = 0
|
||||
}
|
||||
|
||||
schemaVersion, ok := schemaVersionInterface.(int)
|
||||
if !ok {
|
||||
err = fmt.Errorf("configuration file contains non-integer schema_version, abort")
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if schemaVersion == currentSchemaVersion {
|
||||
// do nothing
|
||||
return nil
|
||||
}
|
||||
|
||||
return upgradeConfigSchema(schemaVersion, &diskConfig)
|
||||
}
|
||||
|
||||
// Upgrade from oldVersion to newVersion
|
||||
func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) error {
|
||||
switch oldVersion {
|
||||
case 0:
|
||||
err := upgradeSchema0to3(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case 1:
|
||||
err := upgradeSchema1to3(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case 2:
|
||||
err := upgradeSchema2to3(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
err := fmt.Errorf("configuration file contains unknown schema_version, abort")
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
body, err := yaml.Marshal(diskConfig)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't generate YAML file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
config.fileData = body
|
||||
err = file.SafeWrite(configFile, body)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't save YAML config: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// The first schema upgrade:
|
||||
// No more "dnsfilter.txt", filters are now kept in data/filters/
|
||||
func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("%s(): called", _Func())
|
||||
|
||||
dnsFilterPath := filepath.Join(config.ourWorkingDir, "dnsfilter.txt")
|
||||
if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) {
|
||||
log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath)
|
||||
err = os.Remove(dnsFilterPath)
|
||||
if err != nil {
|
||||
log.Printf("Cannot remove %s due to %s", dnsFilterPath, err)
|
||||
// not fatal, move on
|
||||
}
|
||||
}
|
||||
|
||||
(*diskConfig)["schema_version"] = 1
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Second schema upgrade:
|
||||
// coredns is now dns in config
|
||||
// delete 'Corefile', since we don't use that anymore
|
||||
func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("%s(): called", _Func())
|
||||
|
||||
coreFilePath := filepath.Join(config.ourWorkingDir, "Corefile")
|
||||
if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) {
|
||||
log.Printf("Deleting %s as we don't need it anymore", coreFilePath)
|
||||
err = os.Remove(coreFilePath)
|
||||
if err != nil {
|
||||
log.Printf("Cannot remove %s due to %s", coreFilePath, err)
|
||||
// not fatal, move on
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := (*diskConfig)["dns"]; !ok {
|
||||
(*diskConfig)["dns"] = (*diskConfig)["coredns"]
|
||||
delete((*diskConfig), "coredns")
|
||||
}
|
||||
(*diskConfig)["schema_version"] = 2
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Third schema upgrade:
|
||||
// Bootstrap DNS becomes an array
|
||||
func upgradeSchema2to3(diskConfig *map[string]interface{}) error {
|
||||
log.Printf("%s(): called", _Func())
|
||||
|
||||
// Let's read dns configuration from diskConfig
|
||||
dnsConfig, ok := (*diskConfig)["dns"]
|
||||
if !ok {
|
||||
return fmt.Errorf("no DNS configuration in config file")
|
||||
}
|
||||
|
||||
// Convert interface{} to map[string]interface{}
|
||||
newDNSConfig := make(map[string]interface{})
|
||||
|
||||
switch v := dnsConfig.(type) {
|
||||
case map[interface{}]interface{}:
|
||||
for k, v := range v {
|
||||
newDNSConfig[fmt.Sprint(k)] = v
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("DNS configuration is not a map")
|
||||
}
|
||||
|
||||
// Replace bootstrap_dns value filed with new array contains old bootstrap_dns inside
|
||||
if bootstrapDNS, ok := (newDNSConfig)["bootstrap_dns"]; ok {
|
||||
newBootstrapConfig := []string{fmt.Sprint(bootstrapDNS)}
|
||||
(newDNSConfig)["bootstrap_dns"] = newBootstrapConfig
|
||||
(*diskConfig)["dns"] = newDNSConfig
|
||||
} else {
|
||||
return fmt.Errorf("no bootstrap DNS in DNS config")
|
||||
}
|
||||
|
||||
// Bump schema version
|
||||
(*diskConfig)["schema_version"] = 3
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// jump three schemas at once -- this time we just do it sequentially
|
||||
func upgradeSchema0to3(diskConfig *map[string]interface{}) error {
|
||||
err := upgradeSchema0to1(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return upgradeSchema1to3(diskConfig)
|
||||
}
|
||||
|
||||
// jump two schemas at once -- this time we just do it sequentially
|
||||
func upgradeSchema1to3(diskConfig *map[string]interface{}) error {
|
||||
err := upgradeSchema1to2(diskConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return upgradeSchema2to3(diskConfig)
|
||||
}
|
||||
230
home/upgrade_test.go
Normal file
230
home/upgrade_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUpgrade1to2(t *testing.T) {
|
||||
// let's create test config for 1 schema version
|
||||
diskConfig := createTestDiskConfig(1)
|
||||
|
||||
// update config
|
||||
err := upgradeSchema1to2(&diskConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Can't upgrade schema version from 1 to 2")
|
||||
}
|
||||
|
||||
// ensure that schema version was bumped
|
||||
compareSchemaVersion(t, diskConfig["schema_version"], 2)
|
||||
|
||||
// old coredns entry should be removed
|
||||
_, ok := diskConfig["coredns"]
|
||||
if ok {
|
||||
t.Fatalf("Core DNS config was not removed after upgrade schema version from 1 to 2")
|
||||
}
|
||||
|
||||
// pull out new dns config
|
||||
dnsMap, ok := diskConfig["dns"]
|
||||
if !ok {
|
||||
t.Fatalf("No DNS config after upgrade schema version from 1 to 2")
|
||||
}
|
||||
|
||||
// cast dns configurations to maps and compare them
|
||||
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(1))
|
||||
newDNSConfig := castInterfaceToMap(t, dnsMap)
|
||||
compareConfigs(t, &oldDNSConfig, &newDNSConfig)
|
||||
|
||||
// exclude dns config and schema version from disk config comparison
|
||||
oldExcludedEntries := []string{"coredns", "schema_version"}
|
||||
newExcludedEntries := []string{"dns", "schema_version"}
|
||||
oldDiskConfig := createTestDiskConfig(1)
|
||||
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, oldExcludedEntries, newExcludedEntries)
|
||||
}
|
||||
|
||||
func TestUpgrade2to3(t *testing.T) {
|
||||
// let's create test config
|
||||
diskConfig := createTestDiskConfig(2)
|
||||
|
||||
// upgrade schema from 2 to 3
|
||||
err := upgradeSchema2to3(&diskConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Can't update schema version from 2 to 3: %s", err)
|
||||
}
|
||||
|
||||
// check new schema version
|
||||
compareSchemaVersion(t, diskConfig["schema_version"], 3)
|
||||
|
||||
// pull out new dns configuration
|
||||
dnsMap, ok := diskConfig["dns"]
|
||||
if !ok {
|
||||
t.Fatalf("No dns config in new configuration")
|
||||
}
|
||||
|
||||
// cast dns configuration to map
|
||||
newDNSConfig := castInterfaceToMap(t, dnsMap)
|
||||
|
||||
// check if bootstrap DNS becomes an array
|
||||
bootstrapDNS := newDNSConfig["bootstrap_dns"]
|
||||
switch v := bootstrapDNS.(type) {
|
||||
case []string:
|
||||
if len(v) != 1 {
|
||||
t.Fatalf("Wrong count of bootsrap DNS servers: %d", len(v))
|
||||
}
|
||||
|
||||
if v[0] != "8.8.8.8:53" {
|
||||
t.Fatalf("Bootsrap DNS server is not 8.8.8.8:53 : %s", v[0])
|
||||
}
|
||||
default:
|
||||
t.Fatalf("Wrong type for bootsrap DNS: %T", v)
|
||||
}
|
||||
|
||||
// exclude bootstrap DNS from DNS configs comparison
|
||||
excludedEntries := []string{"bootstrap_dns"}
|
||||
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(2))
|
||||
compareConfigsWithoutEntries(t, &oldDNSConfig, &newDNSConfig, excludedEntries, excludedEntries)
|
||||
|
||||
// excluded dns config and schema version from disk config comparison
|
||||
excludedEntries = []string{"dns", "schema_version"}
|
||||
oldDiskConfig := createTestDiskConfig(2)
|
||||
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, excludedEntries, excludedEntries)
|
||||
}
|
||||
|
||||
func castInterfaceToMap(t *testing.T, oldConfig interface{}) (newConfig map[string]interface{}) {
|
||||
newConfig = make(map[string]interface{})
|
||||
switch v := oldConfig.(type) {
|
||||
case map[interface{}]interface{}:
|
||||
for key, value := range v {
|
||||
newConfig[fmt.Sprint(key)] = value
|
||||
}
|
||||
case map[string]interface{}:
|
||||
for key, value := range v {
|
||||
newConfig[key] = value
|
||||
}
|
||||
default:
|
||||
t.Fatalf("DNS configuration is not a map")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// compareConfigsWithoutEntry removes entries from configs and returns result of compareConfigs
|
||||
func compareConfigsWithoutEntries(t *testing.T, oldConfig, newConfig *map[string]interface{}, oldKey, newKey []string) {
|
||||
for _, k := range oldKey {
|
||||
delete(*oldConfig, k)
|
||||
}
|
||||
for _, k := range newKey {
|
||||
delete(*newConfig, k)
|
||||
}
|
||||
compareConfigs(t, oldConfig, newConfig)
|
||||
}
|
||||
|
||||
// compares configs before and after schema upgrade
|
||||
func compareConfigs(t *testing.T, oldConfig, newConfig *map[string]interface{}) {
|
||||
if len(*oldConfig) != len(*newConfig) {
|
||||
t.Fatalf("wrong config entries count! Before upgrade: %d; After upgrade: %d", len(*oldConfig), len(*oldConfig))
|
||||
}
|
||||
|
||||
// Check old and new entries
|
||||
for k, v := range *newConfig {
|
||||
switch value := v.(type) {
|
||||
case string:
|
||||
if value != (*oldConfig)[k] {
|
||||
t.Fatalf("wrong value for string %s. Before update: %s; After update: %s", k, (*oldConfig)[k], value)
|
||||
}
|
||||
case int:
|
||||
if value != (*oldConfig)[k] {
|
||||
t.Fatalf("wrong value for int %s. Before update: %d; After update: %d", k, (*oldConfig)[k], value)
|
||||
}
|
||||
case []string:
|
||||
for i, line := range value {
|
||||
if len((*oldConfig)[k].([]string)) != len(value) {
|
||||
t.Fatalf("wrong array length for %s. Before update: %d; After update: %d", k, len((*oldConfig)[k].([]string)), len(value))
|
||||
}
|
||||
if (*oldConfig)[k].([]string)[i] != line {
|
||||
t.Fatalf("wrong data for string array %s. Before update: %s; After update: %s", k, (*oldConfig)[k].([]string)[i], line)
|
||||
}
|
||||
}
|
||||
case bool:
|
||||
if v != (*oldConfig)[k].(bool) {
|
||||
t.Fatalf("wrong boolean value for %s", k)
|
||||
}
|
||||
case []filter:
|
||||
if len((*oldConfig)[k].([]filter)) != len(value) {
|
||||
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filter)), len(value))
|
||||
}
|
||||
for i, newFilter := range value {
|
||||
oldFilter := (*oldConfig)[k].([]filter)[i]
|
||||
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RulesCount != newFilter.RulesCount {
|
||||
t.Fatalf("old filter %s not equals new filter %s", oldFilter.Name, newFilter.Name)
|
||||
}
|
||||
}
|
||||
default:
|
||||
t.Fatalf("uknown data type for %s: %T", k, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compareSchemaVersion check if newSchemaVersion equals schemaVersion
|
||||
func compareSchemaVersion(t *testing.T, newSchemaVersion interface{}, schemaVersion int) {
|
||||
switch v := newSchemaVersion.(type) {
|
||||
case int:
|
||||
if v != schemaVersion {
|
||||
t.Fatalf("Wrong schema version in new config file")
|
||||
}
|
||||
default:
|
||||
t.Fatalf("Schema version is not an integer after update")
|
||||
}
|
||||
}
|
||||
|
||||
func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{}) {
|
||||
diskConfig = make(map[string]interface{})
|
||||
diskConfig["language"] = "en"
|
||||
diskConfig["filters"] = []filter{
|
||||
{
|
||||
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
|
||||
Name: "Latvian filter",
|
||||
RulesCount: 100,
|
||||
},
|
||||
{
|
||||
URL: "https://easylist.to/easylistgermany/easylistgermany.txt",
|
||||
Name: "Germany filter",
|
||||
RulesCount: 200,
|
||||
},
|
||||
}
|
||||
diskConfig["user_rules"] = []string{}
|
||||
diskConfig["schema_version"] = schemaVersion
|
||||
diskConfig["bind_host"] = "0.0.0.0"
|
||||
diskConfig["bind_port"] = 80
|
||||
diskConfig["auth_name"] = "name"
|
||||
diskConfig["auth_pass"] = "pass"
|
||||
dnsConfig := createTestDNSConfig(schemaVersion)
|
||||
if schemaVersion > 1 {
|
||||
diskConfig["dns"] = dnsConfig
|
||||
} else {
|
||||
diskConfig["coredns"] = dnsConfig
|
||||
}
|
||||
return diskConfig
|
||||
}
|
||||
|
||||
func createTestDNSConfig(schemaVersion int) map[interface{}]interface{} {
|
||||
dnsConfig := make(map[interface{}]interface{})
|
||||
dnsConfig["port"] = 53
|
||||
dnsConfig["blocked_response_ttl"] = 10
|
||||
dnsConfig["querylog_enabled"] = true
|
||||
dnsConfig["ratelimit"] = 20
|
||||
dnsConfig["bootstrap_dns"] = "8.8.8.8:53"
|
||||
if schemaVersion > 2 {
|
||||
dnsConfig["bootstrap_dns"] = []string{"8.8.8.8:53"}
|
||||
}
|
||||
dnsConfig["parental_sensitivity"] = 13
|
||||
dnsConfig["ratelimit_whitelist"] = []string{}
|
||||
dnsConfig["upstream_dns"] = []string{"tls://1.1.1.1", "tls://1.0.0.1", "8.8.8.8"}
|
||||
dnsConfig["filtering_enabled"] = true
|
||||
dnsConfig["refuse_any"] = true
|
||||
dnsConfig["parental_enabled"] = true
|
||||
dnsConfig["bind_host"] = "0.0.0.0"
|
||||
dnsConfig["protection_enabled"] = true
|
||||
dnsConfig["safesearch_enabled"] = true
|
||||
dnsConfig["safebrowsing_enabled"] = true
|
||||
return dnsConfig
|
||||
}
|
||||
Reference in New Issue
Block a user