* move ./*.go files into ./home/ directory

This commit is contained in:
Simon Zolin
2019-06-10 11:33:19 +03:00
parent 9fe34818e3
commit dc682763ff
27 changed files with 37 additions and 28 deletions

482
home/clients.go Normal file
View File

@@ -0,0 +1,482 @@
package home
import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"runtime"
"strings"
"sync"
"github.com/AdguardTeam/golibs/log"
)
// Client information
type Client struct {
IP string
MAC string
Name string
UseOwnSettings bool // false: use global settings
FilteringEnabled bool
SafeSearchEnabled bool
SafeBrowsingEnabled bool
ParentalEnabled bool
}
type clientJSON struct {
IP string `json:"ip"`
MAC string `json:"mac"`
Name string `json:"name"`
UseGlobalSettings bool `json:"use_global_settings"`
FilteringEnabled bool `json:"filtering_enabled"`
ParentalEnabled bool `json:"parental_enabled"`
SafeSearchEnabled bool `json:"safebrowsing_enabled"`
SafeBrowsingEnabled bool `json:"safesearch_enabled"`
}
type clientSource uint
const (
ClientSourceHostsFile clientSource = 0 // from /etc/hosts
ClientSourceRDNS clientSource = 1 // from rDNS
)
// ClientHost information
type ClientHost struct {
Host string
Source clientSource
}
type clientsContainer struct {
list map[string]*Client
ipIndex map[string]*Client
ipHost map[string]ClientHost // IP -> Hostname
lock sync.Mutex
}
var clients clientsContainer
// Initialize clients container
func clientsInit() {
if clients.list != nil {
log.Fatal("clients.list != nil")
}
clients.list = make(map[string]*Client)
clients.ipIndex = make(map[string]*Client)
clients.ipHost = make(map[string]ClientHost)
clientsAddFromHostsFile()
}
func clientsGetList() map[string]*Client {
return clients.list
}
func clientExists(ip string) bool {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.ipIndex[ip]
if ok {
return true
}
_, ok = clients.ipHost[ip]
return ok
}
// Search for a client by IP
func clientFind(ip string) (Client, bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.ipIndex[ip]
if ok {
return *c, true
}
for _, c = range clients.list {
if len(c.MAC) != 0 {
mac, err := net.ParseMAC(c.MAC)
if err != nil {
continue
}
ipAddr := dhcpServer.FindIPbyMAC(mac)
if ipAddr == nil {
continue
}
if ip == ipAddr.String() {
return *c, true
}
}
}
return Client{}, false
}
// Check if Client object's fields are correct
func clientCheck(c *Client) error {
if len(c.Name) == 0 {
return fmt.Errorf("Invalid Name")
}
if (len(c.IP) == 0 && len(c.MAC) == 0) ||
(len(c.IP) != 0 && len(c.MAC) != 0) {
return fmt.Errorf("IP or MAC required")
}
if len(c.IP) != 0 {
ip := net.ParseIP(c.IP)
if ip == nil {
return fmt.Errorf("Invalid IP")
}
c.IP = ip.String()
} else {
_, err := net.ParseMAC(c.MAC)
if err != nil {
return fmt.Errorf("Invalid MAC: %s", err)
}
}
return nil
}
// Add a new client object
// Return true: success; false: client exists.
func clientAdd(c Client) (bool, error) {
e := clientCheck(&c)
if e != nil {
return false, e
}
clients.lock.Lock()
defer clients.lock.Unlock()
// check Name index
_, ok := clients.list[c.Name]
if ok {
return false, nil
}
// check IP index
if len(c.IP) != 0 {
c2, ok := clients.ipIndex[c.IP]
if ok {
return false, fmt.Errorf("Another client uses the same IP address: %s", c2.Name)
}
}
clients.list[c.Name] = &c
if len(c.IP) != 0 {
clients.ipIndex[c.IP] = &c
}
log.Tracef("'%s': '%s' | '%s' -> [%d]", c.Name, c.IP, c.MAC, len(clients.list))
return true, nil
}
// Remove a client
func clientDel(name string) bool {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.list[name]
if !ok {
return false
}
delete(clients.list, name)
delete(clients.ipIndex, c.IP)
return true
}
// Update a client
func clientUpdate(name string, c Client) error {
err := clientCheck(&c)
if err != nil {
return err
}
clients.lock.Lock()
defer clients.lock.Unlock()
old, ok := clients.list[name]
if !ok {
return fmt.Errorf("Client not found")
}
// check Name index
if old.Name != c.Name {
_, ok = clients.list[c.Name]
if ok {
return fmt.Errorf("Client already exists")
}
}
// check IP index
if old.IP != c.IP && len(c.IP) != 0 {
c2, ok := clients.ipIndex[c.IP]
if ok {
return fmt.Errorf("Another client uses the same IP address: %s", c2.Name)
}
}
// update Name index
if old.Name != c.Name {
delete(clients.list, old.Name)
}
clients.list[c.Name] = &c
// update IP index
if old.IP != c.IP {
delete(clients.ipIndex, old.IP)
}
if len(c.IP) != 0 {
clients.ipIndex[c.IP] = &c
}
return nil
}
func clientAddHost(ip, host string, source clientSource) (bool, error) {
clients.lock.Lock()
defer clients.lock.Unlock()
// check index
_, ok := clients.ipHost[ip]
if ok {
return false, nil
}
clients.ipHost[ip] = ClientHost{
Host: host,
Source: source,
}
log.Tracef("'%s': '%s' -> [%d]", host, ip, len(clients.ipHost))
return true, nil
}
// Parse system 'hosts' file and fill clients array
func clientsAddFromHostsFile() {
hostsFn := "/etc/hosts"
if runtime.GOOS == "windows" {
hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
}
d, e := ioutil.ReadFile(hostsFn)
if e != nil {
log.Info("Can't read file %s: %v", hostsFn, e)
return
}
lines := strings.Split(string(d), "\n")
n := 0
for _, ln := range lines {
ln = strings.TrimSpace(ln)
if len(ln) == 0 || ln[0] == '#' {
continue
}
fields := strings.Fields(ln)
if len(fields) < 2 {
continue
}
ok, e := clientAddHost(fields[0], fields[1], ClientSourceHostsFile)
if e != nil {
log.Tracef("%s", e)
}
if ok {
n++
}
}
log.Info("Added %d client aliases from %s", n, hostsFn)
}
type clientHostJSON struct {
IP string `json:"ip"`
Name string `json:"name"`
Source string `json:"source"`
}
type clientListJSON struct {
Clients []clientJSON `json:"clients"`
AutoClients []clientHostJSON `json:"auto_clients"`
}
// respond with information about configured clients
func handleGetClients(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
data := clientListJSON{}
clients.lock.Lock()
for _, c := range clients.list {
cj := clientJSON{
IP: c.IP,
MAC: c.MAC,
Name: c.Name,
UseGlobalSettings: !c.UseOwnSettings,
FilteringEnabled: c.FilteringEnabled,
ParentalEnabled: c.ParentalEnabled,
SafeSearchEnabled: c.SafeSearchEnabled,
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
}
if len(c.MAC) != 0 {
hwAddr, _ := net.ParseMAC(c.MAC)
ipAddr := dhcpServer.FindIPbyMAC(hwAddr)
if ipAddr != nil {
cj.IP = ipAddr.String()
}
}
data.Clients = append(data.Clients, cj)
}
for ip, ch := range clients.ipHost {
cj := clientHostJSON{
IP: ip,
Name: ch.Host,
}
cj.Source = "etc/hosts"
if ch.Source == ClientSourceRDNS {
cj.Source = "rDNS"
}
data.AutoClients = append(data.AutoClients, cj)
}
clients.lock.Unlock()
w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data)
if e != nil {
httpError(w, http.StatusInternalServerError, "Failed to encode to json: %v", e)
return
}
}
// Convert JSON object to Client object
func jsonToClient(cj clientJSON) (*Client, error) {
c := Client{
IP: cj.IP,
MAC: cj.MAC,
Name: cj.Name,
UseOwnSettings: !cj.UseGlobalSettings,
FilteringEnabled: cj.FilteringEnabled,
ParentalEnabled: cj.ParentalEnabled,
SafeSearchEnabled: cj.SafeSearchEnabled,
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
}
return &c, nil
}
// Add a new client
func handleAddClient(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
return
}
cj := clientJSON{}
err = json.Unmarshal(body, &cj)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
return
}
c, err := jsonToClient(cj)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
ok, err := clientAdd(*c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
if !ok {
httpError(w, http.StatusBadRequest, "Client already exists")
return
}
_ = writeAllConfigsAndReloadDNS()
returnOK(w)
}
// Remove client
func handleDelClient(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
return
}
cj := clientJSON{}
err = json.Unmarshal(body, &cj)
if err != nil || len(cj.Name) == 0 {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
return
}
if !clientDel(cj.Name) {
httpError(w, http.StatusBadRequest, "Client not found")
return
}
_ = writeAllConfigsAndReloadDNS()
returnOK(w)
}
type clientUpdateJSON struct {
Name string `json:"name"`
Data clientJSON `json:"data"`
}
// Update client's properties
func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
return
}
var dj clientUpdateJSON
err = json.Unmarshal(body, &dj)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
return
}
if len(dj.Name) == 0 {
httpError(w, http.StatusBadRequest, "Invalid request")
return
}
c, err := jsonToClient(dj.Data)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
err = clientUpdate(dj.Name, *c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
_ = writeAllConfigsAndReloadDNS()
returnOK(w)
}
// RegisterClientsHandlers registers HTTP handlers
func RegisterClientsHandlers() {
http.HandleFunc("/control/clients", postInstall(optionalAuth(ensureGET(handleGetClients))))
http.HandleFunc("/control/clients/add", postInstall(optionalAuth(ensurePOST(handleAddClient))))
http.HandleFunc("/control/clients/delete", postInstall(optionalAuth(ensurePOST(handleDelClient))))
http.HandleFunc("/control/clients/update", postInstall(optionalAuth(ensurePOST(handleUpdateClient))))
}

122
home/clients_test.go Normal file
View File

@@ -0,0 +1,122 @@
package home
import "testing"
func TestClients(t *testing.T) {
var c Client
var e error
var b bool
clientsInit()
// add
c = Client{
IP: "1.1.1.1",
Name: "client1",
}
b, e = clientAdd(c)
if !b || e != nil {
t.Fatalf("clientAdd #1")
}
// add #2
c = Client{
IP: "2.2.2.2",
Name: "client2",
}
b, e = clientAdd(c)
if !b || e != nil {
t.Fatalf("clientAdd #2")
}
// failed add - name in use
c = Client{
IP: "1.2.3.5",
Name: "client1",
}
b, _ = clientAdd(c)
if b {
t.Fatalf("clientAdd - name in use")
}
// failed add - ip in use
c = Client{
IP: "2.2.2.2",
Name: "client3",
}
b, e = clientAdd(c)
if b || e == nil {
t.Fatalf("clientAdd - ip in use")
}
// get
if clientExists("1.2.3.4") {
t.Fatalf("clientExists")
}
if !clientExists("1.1.1.1") {
t.Fatalf("clientExists #1")
}
if !clientExists("2.2.2.2") {
t.Fatalf("clientExists #2")
}
// failed update - no such name
c.IP = "1.2.3.0"
c.Name = "client3"
if clientUpdate("client3", c) == nil {
t.Fatalf("clientUpdate")
}
// failed update - name in use
c.IP = "1.2.3.0"
c.Name = "client2"
if clientUpdate("client1", c) == nil {
t.Fatalf("clientUpdate - name in use")
}
// failed update - ip in use
c.IP = "2.2.2.2"
c.Name = "client1"
if clientUpdate("client1", c) == nil {
t.Fatalf("clientUpdate - ip in use")
}
// update
c.IP = "1.1.1.2"
c.Name = "client1"
if clientUpdate("client1", c) != nil {
t.Fatalf("clientUpdate")
}
// get after update
if clientExists("1.1.1.1") || !clientExists("1.1.1.2") {
t.Fatalf("clientExists - get after update")
}
// failed remove - no such name
if clientDel("client3") {
t.Fatalf("clientDel - no such name")
}
// remove
if !clientDel("client1") || clientExists("1.1.1.2") {
t.Fatalf("clientDel")
}
// add host client
b, e = clientAddHost("1.1.1.1", "host", ClientSourceHostsFile)
if !b || e != nil {
t.Fatalf("clientAddHost")
}
// failed add - ip exists
b, e = clientAddHost("1.1.1.1", "host", ClientSourceHostsFile)
if b || e != nil {
t.Fatalf("clientAddHost - ip exists")
}
// get
if !clientExists("1.1.1.1") {
t.Fatalf("clientAddHost")
}
}

320
home/config.go Normal file
View File

@@ -0,0 +1,320 @@
package home
import (
"io/ioutil"
"os"
"path/filepath"
"runtime"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
yaml "gopkg.in/yaml.v2"
)
const (
dataDir = "data" // data storage
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
)
// logSettings
type logSettings struct {
LogFile string `yaml:"log_file"` // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
}
type clientObject struct {
Name string `yaml:"name"`
IP string `yaml:"ip"`
MAC string `yaml:"mac"`
UseGlobalSettings bool `yaml:"use_global_settings"`
FilteringEnabled bool `yaml:"filtering_enabled"`
ParentalEnabled bool `yaml:"parental_enabled"`
SafeSearchEnabled bool `yaml:"safebrowsing_enabled"`
SafeBrowsingEnabled bool `yaml:"safesearch_enabled"`
}
// configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here
type configuration struct {
// Raw file data to avoid re-reading of configuration file
// It's reset after config is parsed
fileData []byte
ourConfigFilename string // Config filename (can be overridden via the command line arguments)
ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
disableUpdate bool // If set, don't check for updates
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
AuthName string `yaml:"auth_name"` // AuthName is the basic auth username
AuthPass string `yaml:"auth_pass"` // AuthPass is the basic auth password
Language string `yaml:"language"` // two-letter ISO 639-1 language code
RlimitNoFile uint `yaml:"rlimit_nofile"` // Maximum number of opened fd's per process (0: default)
DNS dnsConfig `yaml:"dns"`
TLS tlsConfig `yaml:"tls"`
Filters []filter `yaml:"filters"`
UserRules []string `yaml:"user_rules"`
DHCP dhcpd.ServerConfig `yaml:"dhcp"`
// Note: this array is filled only before file read/write and then it's cleared
Clients []clientObject `yaml:"clients"`
logSettings `yaml:",inline"`
sync.RWMutex `yaml:"-"`
SchemaVersion int `yaml:"schema_version"` // keeping last so that users will be less tempted to change it -- used when upgrading between versions
}
// field ordering is important -- yaml fields will mirror ordering from here
type dnsConfig struct {
BindHost string `yaml:"bind_host"`
Port int `yaml:"port"`
dnsforward.FilteringConfig `yaml:",inline"`
UpstreamDNS []string `yaml:"upstream_dns"`
}
var defaultDNS = []string{"https://dns.cloudflare.com/dns-query"}
var defaultBootstrap = []string{"1.1.1.1"}
type tlsConfigSettings struct {
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DOT/DOH/HTTPS) status
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
ForceHTTPS bool `yaml:"force_https" json:"force_https,omitempty"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
PortHTTPS int `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
PortDNSOverTLS int `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DOT will be disabled
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
}
// field ordering is not important -- these are for API and are recalculated on each run
type tlsConfigStatus struct {
ValidCert bool `yaml:"-" json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates
ValidChain bool `yaml:"-" json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA
Subject string `yaml:"-" json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
Issuer string `yaml:"-" json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
NotBefore time.Time `yaml:"-" json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
NotAfter time.Time `yaml:"-" json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
DNSNames []string `yaml:"-" json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
// key status
ValidKey bool `yaml:"-" json:"valid_key"` // ValidKey is true if the key is a valid private key
KeyType string `yaml:"-" json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
// is usable? set by validator
ValidPair bool `yaml:"-" json:"valid_pair"` // ValidPair is true if both certificate and private key are correct
// warnings
WarningValidation string `yaml:"-" json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description
}
// field ordering is important -- yaml fields will mirror ordering from here
type tlsConfig struct {
tlsConfigSettings `yaml:",inline" json:",inline"`
tlsConfigStatus `yaml:"-" json:",inline"`
}
// initialize to default values, will be changed later when reading config or parsing command line
var config = configuration{
ourConfigFilename: "AdGuardHome.yaml",
BindPort: 3000,
BindHost: "0.0.0.0",
DNS: dnsConfig{
BindHost: "0.0.0.0",
Port: 53,
FilteringConfig: dnsforward.FilteringConfig{
ProtectionEnabled: true, // whether or not use any of dnsfilter features
FilteringEnabled: true, // whether or not use filter lists
BlockingMode: "nxdomain", // mode how to answer filtered requests
BlockedResponseTTL: 10, // in seconds
QueryLogEnabled: true,
Ratelimit: 20,
RefuseAny: true,
BootstrapDNS: defaultBootstrap,
AllServers: false,
},
UpstreamDNS: defaultDNS,
},
TLS: tlsConfig{
tlsConfigSettings: tlsConfigSettings{
PortHTTPS: 443,
PortDNSOverTLS: 853, // needs to be passed through to dnsproxy
},
},
Filters: []filter{
{Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"},
{Filter: dnsfilter.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"},
{Filter: dnsfilter.Filter{ID: 3}, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"},
{Filter: dnsfilter.Filter{ID: 4}, Enabled: false, URL: "https://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"},
},
DHCP: dhcpd.ServerConfig{
LeaseDuration: 86400,
ICMPTimeout: 1000,
},
SchemaVersion: currentSchemaVersion,
}
// init initializes default configuration for the current OS&ARCH
func init() {
if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
// Use plain DNS on MIPS, encryption is too slow
defaultDNS = []string{"1.1.1.1", "1.0.0.1"}
// also change the default config
config.DNS.UpstreamDNS = defaultDNS
}
}
// getConfigFilename returns path to the current config file
func (c *configuration) getConfigFilename() string {
configFile, err := filepath.EvalSymlinks(config.ourConfigFilename)
if err != nil {
if !os.IsNotExist(err) {
log.Error("unexpected error while config file path evaluation: %s", err)
}
configFile = config.ourConfigFilename
}
if !filepath.IsAbs(configFile) {
configFile = filepath.Join(config.ourWorkingDir, configFile)
}
return configFile
}
// getLogSettings reads logging settings from the config file.
// we do it in a separate method in order to configure logger before the actual configuration is parsed and applied.
func getLogSettings() logSettings {
l := logSettings{}
yamlFile, err := readConfigFile()
if err != nil {
return l
}
err = yaml.Unmarshal(yamlFile, &l)
if err != nil {
log.Error("Couldn't get logging settings from the configuration: %s", err)
}
return l
}
// parseConfig loads configuration from the YAML file
func parseConfig() error {
configFile := config.getConfigFilename()
log.Debug("Reading config file: %s", configFile)
yamlFile, err := readConfigFile()
if err != nil {
return err
}
config.fileData = nil
err = yaml.Unmarshal(yamlFile, &config)
if err != nil {
log.Error("Couldn't parse config file: %s", err)
return err
}
for _, cy := range config.Clients {
cli := Client{
Name: cy.Name,
IP: cy.IP,
MAC: cy.MAC,
UseOwnSettings: !cy.UseGlobalSettings,
FilteringEnabled: cy.FilteringEnabled,
ParentalEnabled: cy.ParentalEnabled,
SafeSearchEnabled: cy.SafeSearchEnabled,
SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
}
_, err = clientAdd(cli)
if err != nil {
log.Tracef("clientAdd: %s", err)
}
}
config.Clients = nil
// Deduplicate filters
deduplicateFilters()
updateUniqueFilterID(config.Filters)
return nil
}
// readConfigFile reads config file contents if it exists
func readConfigFile() ([]byte, error) {
if len(config.fileData) != 0 {
return config.fileData, nil
}
configFile := config.getConfigFilename()
d, err := ioutil.ReadFile(configFile)
if err != nil {
log.Error("Couldn't read config file %s: %s", configFile, err)
return nil, err
}
return d, nil
}
// Saves configuration to the YAML file and also saves the user filter contents to a file
func (c *configuration) write() error {
c.Lock()
defer c.Unlock()
clientsList := clientsGetList()
for _, cli := range clientsList {
ip := cli.IP
if len(cli.MAC) != 0 {
ip = ""
}
cy := clientObject{
Name: cli.Name,
IP: ip,
MAC: cli.MAC,
UseGlobalSettings: !cli.UseOwnSettings,
FilteringEnabled: cli.FilteringEnabled,
ParentalEnabled: cli.ParentalEnabled,
SafeSearchEnabled: cli.SafeSearchEnabled,
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
}
config.Clients = append(config.Clients, cy)
}
configFile := config.getConfigFilename()
log.Debug("Writing YAML file: %s", configFile)
yamlText, err := yaml.Marshal(&config)
config.Clients = nil
if err != nil {
log.Error("Couldn't generate YAML file: %s", err)
return err
}
err = file.SafeWrite(configFile, yamlText)
if err != nil {
log.Error("Couldn't save YAML config: %s", err)
return err
}
return nil
}
func writeAllConfigs() error {
err := config.write()
if err != nil {
log.Error("Couldn't write config: %s", err)
return err
}
userFilter := userFilter()
err = userFilter.save()
if err != nil {
log.Error("Couldn't save the user filter: %s", err)
return err
}
return nil
}

1002
home/control.go Normal file

File diff suppressed because it is too large Load Diff

87
home/control_access.go Normal file
View File

@@ -0,0 +1,87 @@
package home
import (
"encoding/json"
"net"
"net/http"
"github.com/AdguardTeam/golibs/log"
)
type accessListJSON struct {
AllowedClients []string `json:"allowed_clients"`
DisallowedClients []string `json:"disallowed_clients"`
BlockedHosts []string `json:"blocked_hosts"`
}
func handleAccessList(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
controlLock.Lock()
j := accessListJSON{
AllowedClients: config.DNS.AllowedClients,
DisallowedClients: config.DNS.DisallowedClients,
BlockedHosts: config.DNS.BlockedHosts,
}
controlLock.Unlock()
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(j)
if err != nil {
httpError(w, http.StatusInternalServerError, "json.Encode: %s", err)
return
}
}
func checkIPCIDRArray(src []string) error {
for _, s := range src {
ip := net.ParseIP(s)
if ip != nil {
continue
}
_, _, err := net.ParseCIDR(s)
if err != nil {
return err
}
}
return nil
}
func handleAccessSet(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
j := accessListJSON{}
err := json.NewDecoder(r.Body).Decode(&j)
if err != nil {
httpError(w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
err = checkIPCIDRArray(j.AllowedClients)
if err == nil {
err = checkIPCIDRArray(j.DisallowedClients)
}
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
config.Lock()
config.DNS.AllowedClients = j.AllowedClients
config.DNS.DisallowedClients = j.DisallowedClients
config.DNS.BlockedHosts = j.BlockedHosts
config.Unlock()
log.Tracef("Update access lists: %d, %d, %d",
len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts))
err = writeAllConfigsAndReloadDNS()
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
returnOK(w)
}

278
home/control_install.go Normal file
View File

@@ -0,0 +1,278 @@
package home
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"os/exec"
"strconv"
"github.com/AdguardTeam/golibs/log"
)
type firstRunData struct {
WebPort int `json:"web_port"`
DNSPort int `json:"dns_port"`
Interfaces map[string]interface{} `json:"interfaces"`
}
// Get initial installation settings
func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
data := firstRunData{}
data.WebPort = 80
data.DNSPort = 53
ifaces, err := getValidNetInterfacesForWeb()
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return
}
data.Interfaces = make(map[string]interface{})
for _, iface := range ifaces {
data.Interfaces[iface.Name] = iface
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err)
return
}
}
type checkConfigReqEnt struct {
Port int `json:"port"`
IP string `json:"ip"`
Autofix bool `json:"autofix"`
}
type checkConfigReq struct {
Web checkConfigReqEnt `json:"web"`
DNS checkConfigReqEnt `json:"dns"`
}
type checkConfigRespEnt struct {
Status string `json:"status"`
CanAutofix bool `json:"can_autofix"`
}
type checkConfigResp struct {
Web checkConfigRespEnt `json:"web"`
DNS checkConfigRespEnt `json:"dns"`
}
// Check if ports are available, respond with results
func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
reqData := checkConfigReq{}
respData := checkConfigResp{}
err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
return
}
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort {
err = checkPortAvailable(reqData.Web.IP, reqData.Web.Port)
if err != nil {
respData.Web.Status = fmt.Sprintf("%v", err)
}
}
if reqData.DNS.Port != 0 {
err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
if errorIsAddrInUse(err) {
canAutofix := checkDNSStubListener()
if canAutofix && reqData.DNS.Autofix {
err = disableDNSStubListener()
if err != nil {
log.Error("Couldn't disable DNSStubListener: %s", err)
}
err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
canAutofix = false
}
respData.DNS.CanAutofix = canAutofix
}
if err == nil {
err = checkPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
}
if err != nil {
respData.DNS.Status = fmt.Sprintf("%v", err)
}
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(respData)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal JSON: %s", err)
return
}
}
// Check if DNSStubListener is active
func checkDNSStubListener() bool {
cmd := exec.Command("systemctl", "is-enabled", "systemd-resolved")
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
_, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Error("command %s has failed: %v code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
return false
}
cmd = exec.Command("grep", "-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf")
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
_, err = cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Error("command %s has failed: %v code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
return false
}
return true
}
// Deactivate DNSStubListener
func disableDNSStubListener() error {
cmd := exec.Command("sed", "-r", "-i.orig", "s/#?DNSStubListener=yes/DNSStubListener=no/g", "/etc/systemd/resolved.conf")
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
_, err := cmd.Output()
if err != nil {
return err
}
if cmd.ProcessState.ExitCode() != 0 {
return fmt.Errorf("process %s exited with an error: %d",
cmd.Path, cmd.ProcessState.ExitCode())
}
cmd = exec.Command("systemctl", "reload-or-restart", "systemd-resolved")
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
_, err = cmd.Output()
if err != nil {
return err
}
if cmd.ProcessState.ExitCode() != 0 {
return fmt.Errorf("process %s exited with an error: %d",
cmd.Path, cmd.ProcessState.ExitCode())
}
return nil
}
type applyConfigReqEnt struct {
IP string `json:"ip"`
Port int `json:"port"`
}
type applyConfigReq struct {
Web applyConfigReqEnt `json:"web"`
DNS applyConfigReqEnt `json:"dns"`
Username string `json:"username"`
Password string `json:"password"`
}
// Copy installation parameters between two configuration objects
func copyInstallSettings(dst *configuration, src *configuration) {
dst.BindHost = src.BindHost
dst.BindPort = src.BindPort
dst.DNS.BindHost = src.DNS.BindHost
dst.DNS.Port = src.DNS.Port
dst.AuthName = src.AuthName
dst.AuthPass = src.AuthPass
}
// Apply new configuration, start DNS server, restart Web server
func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
newSettings := applyConfigReq{}
err := json.NewDecoder(r.Body).Decode(&newSettings)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'configure' JSON: %s", err)
return
}
if newSettings.Web.Port == 0 || newSettings.DNS.Port == 0 {
httpError(w, http.StatusBadRequest, "port value can't be 0")
return
}
restartHTTP := true
if config.BindHost == newSettings.Web.IP && config.BindPort == newSettings.Web.Port {
// no need to rebind
restartHTTP = false
}
// validate that hosts and ports are bindable
if restartHTTP {
err = checkPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s",
net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err)
return
}
}
err = checkPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
err = checkPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
var curConfig configuration
copyInstallSettings(&curConfig, &config)
config.firstRun = false
config.BindHost = newSettings.Web.IP
config.BindPort = newSettings.Web.Port
config.DNS.BindHost = newSettings.DNS.IP
config.DNS.Port = newSettings.DNS.Port
config.AuthName = newSettings.Username
config.AuthPass = newSettings.Password
err = startDNSServer()
if err != nil {
config.firstRun = true
copyInstallSettings(&config, &curConfig)
httpError(w, http.StatusInternalServerError, "Couldn't start DNS server: %s", err)
return
}
err = config.write()
if err != nil {
config.firstRun = true
copyInstallSettings(&config, &curConfig)
httpError(w, http.StatusInternalServerError, "Couldn't write config: %s", err)
return
}
go refreshFiltersIfNecessary(false)
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
if restartHTTP {
go func() {
httpServer.Shutdown(context.TODO())
}()
}
returnOK(w)
}
func registerInstallHandlers() {
http.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(handleInstallGetAddresses)))
http.HandleFunc("/control/install/check_config", preInstall(ensurePOST(handleInstallCheckConfig)))
http.HandleFunc("/control/install/configure", preInstall(ensurePOST(handleInstallConfigure)))
}

153
home/control_test.go Normal file
View File

@@ -0,0 +1,153 @@
package home
import (
"testing"
"time"
)
/* Tests performed:
. Bad certificate
. Bad private key
. Valid certificate & private key */
func TestValidateCertificates(t *testing.T) {
var data tlsConfigStatus
// bad cert
data = validateCertificates("bad cert", "", "")
if !(data.WarningValidation != "" &&
!data.ValidCert &&
!data.ValidChain) {
t.Fatalf("bad cert: validateCertificates(): %v", data)
}
// bad priv key
data = validateCertificates("", "bad priv key", "")
if !(data.WarningValidation != "" &&
!data.ValidKey) {
t.Fatalf("bad priv key: validateCertificates(): %v", data)
}
// valid cert & priv key
CertificateChain := `-----BEGIN CERTIFICATE-----
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
MBMGA1UEAwwMQWRHdWFyZCBIb21lMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB
gQCwvwUnPJiOvLcOaWmGu6Y68ksFr13nrXBcsDlhxlXy8PaohVi3XxEmt2OrVjKW
QFw/bdV4fZ9tdWFAVRRkgeGbIZzP7YBD1Ore/O5SQ+DbCCEafvjJCcXQIrTeKFE6
i9G3aSMHs0Pwq2LgV8U5mYotLrvyFiE8QPInJbDDMpaFYwIDAQABo1MwUTAdBgNV
HQ4EFgQUdLUmQpEqrhn4eKO029jYd2AAZEQwHwYDVR0jBBgwFoAUdLUmQpEqrhn4
eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8
LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
-----END CERTIFICATE-----`
PrivateKey := `-----BEGIN PRIVATE KEY-----
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
xTmZii0uu/IWITxA8iclsMMyloVjAgMBAAECgYEAmjzoG1h27UDkIlB9BVWl95TP
QVPLB81D267xNFDnWk1Lgr5zL/pnNjkdYjyjgpkBp1yKyE4gHV4skv5sAFWTcOCU
QCgfPfUn/rDFcxVzAdJVWAa/CpJNaZgjTPR8NTGU+Ztod+wfBESNCP5tbnuw0GbL
MuwdLQJGbzeJYpsNysECQQDfFHYoRNfgxHwMbX24GCoNZIgk12uDmGTA9CS5E+72
9t3V1y4CfXxSkfhqNbd5RWrUBRLEw9BKofBS7L9NMDKDAkEAytQoIueE1vqEAaRg
a3A1YDUekKesU5wKfKfKlXvNgB7Hwh4HuvoQS9RCvVhf/60Dvq8KSu6hSjkFRquj
FQ5roQJBAMwKwyiCD5MfJPeZDmzcbVpiocRQ5Z4wPbffl9dRTDnIA5AciZDthlFg
An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp
O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
kXS9jgARhhiWXJrk
-----END PRIVATE KEY-----`
data = validateCertificates(CertificateChain, PrivateKey, "")
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z")
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z")
if !(data.WarningValidation != "" /* self signed */ &&
data.ValidCert &&
!data.ValidChain &&
data.ValidKey &&
data.KeyType == "RSA" &&
data.Subject == "CN=AdGuard Home,O=AdGuard Ltd" &&
data.Issuer == "CN=AdGuard Home,O=AdGuard Ltd" &&
data.NotBefore == notBefore &&
data.NotAfter == notAfter &&
// data.DNSNames[0] == &&
data.ValidPair) {
t.Fatalf("valid cert & priv key: validateCertificates(): %v", data)
}
}
func TestValidateUpstream(t *testing.T) {
invalidUpstreams := []string{"1.2.3.4.5",
"123.3.7m",
"htttps://google.com/dns-query",
"[/host.com]tls://dns.adguard.com",
"[host.ru]#",
}
validDefaultUpstreams := []string{"1.1.1.1",
"tls://1.1.1.1",
"https://dns.adguard.com/dns-query",
"sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
}
validUpstreams := []string{"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
}
for _, u := range invalidUpstreams {
_, err := validateUpstream(u)
if err == nil {
t.Fatalf("upstream %s is invalid but it pass through validation", u)
}
}
for _, u := range validDefaultUpstreams {
defaultUpstream, err := validateUpstream(u)
if err != nil {
t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err)
}
if !defaultUpstream {
t.Fatalf("upstream %s is default one!", u)
}
}
for _, u := range validUpstreams {
defaultUpstream, err := validateUpstream(u)
if err != nil {
t.Fatalf("upstream %s is valid but it doen't pass through validation cause: %s", u, err)
}
if defaultUpstream {
t.Fatalf("upstream %s is default one!", u)
}
}
}
func TestValidateUpstreamsSet(t *testing.T) {
// Set of valid upstreams. There is no default upstream specified
upstreamsSet := []string{"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
}
err := validateUpstreams(upstreamsSet)
if err == nil {
t.Fatalf("there is no default upstream")
}
// Let's add default upstream
upstreamsSet = append(upstreamsSet, "8.8.8.8")
err = validateUpstreams(upstreamsSet)
if err != nil {
t.Fatalf("upstreams set is valid, but doesn't pass through validation cause: %s", err)
}
// Let's add invalid upstream
upstreamsSet = append(upstreamsSet, "dhcp://fake.dns")
err = validateUpstreams(upstreamsSet)
if err == nil {
t.Fatalf("there is an invalid upstream in set, but it pass through validation")
}
}

338
home/control_tls.go Normal file
View File

@@ -0,0 +1,338 @@
// Control: TLS configuring handlers
package home
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net/http"
"reflect"
"strings"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
// RegisterTLSHandlers registers HTTP handlers for TLS configuration
func RegisterTLSHandlers() {
http.HandleFunc("/control/tls/status", postInstall(optionalAuth(ensureGET(handleTLSStatus))))
http.HandleFunc("/control/tls/configure", postInstall(optionalAuth(ensurePOST(handleTLSConfigure))))
http.HandleFunc("/control/tls/validate", postInstall(optionalAuth(ensurePOST(handleTLSValidate))))
}
func handleTLSStatus(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
marshalTLS(w, config.TLS)
}
func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
data, err := unmarshalTLS(r)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return
}
// check if port is available
// BUT: if we are already using this port, no need
alreadyRunning := false
if httpsServer.server != nil {
alreadyRunning = true
}
if !alreadyRunning {
err = checkPortAvailable(config.BindHost, data.PortHTTPS)
if err != nil {
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
return
}
}
data.tlsConfigStatus = validateCertificates(data.CertificateChain, data.PrivateKey, data.ServerName)
marshalTLS(w, data)
}
func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
data, err := unmarshalTLS(r)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return
}
// check if port is available
// BUT: if we are already using this port, no need
alreadyRunning := false
if httpsServer.server != nil {
alreadyRunning = true
}
if !alreadyRunning {
err = checkPortAvailable(config.BindHost, data.PortHTTPS)
if err != nil {
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
return
}
}
restartHTTPS := false
data.tlsConfigStatus = validateCertificates(data.CertificateChain, data.PrivateKey, data.ServerName)
if !reflect.DeepEqual(config.TLS.tlsConfigSettings, data.tlsConfigSettings) {
log.Printf("tls config settings have changed, will restart HTTPS server")
restartHTTPS = true
}
config.TLS = data
err = writeAllConfigsAndReloadDNS()
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err)
return
}
marshalTLS(w, data)
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
if restartHTTPS {
go func() {
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
httpsServer.cond.L.Lock()
httpsServer.cond.Broadcast()
if httpsServer.server != nil {
httpsServer.server.Shutdown(context.TODO())
}
httpsServer.cond.L.Unlock()
}()
}
}
func verifyCertChain(data *tlsConfigStatus, certChain string, serverName string) error {
log.Tracef("got certificate: %s", certChain)
// now do a more extended validation
var certs []*pem.Block // PEM-encoded certificates
var skippedBytes []string // skipped bytes
pemblock := []byte(certChain)
for {
var decoded *pem.Block
decoded, pemblock = pem.Decode(pemblock)
if decoded == nil {
break
}
if decoded.Type == "CERTIFICATE" {
certs = append(certs, decoded)
} else {
skippedBytes = append(skippedBytes, decoded.Type)
}
}
var parsedCerts []*x509.Certificate
for _, cert := range certs {
parsed, err := x509.ParseCertificate(cert.Bytes)
if err != nil {
data.WarningValidation = fmt.Sprintf("Failed to parse certificate: %s", err)
return errors.New(data.WarningValidation)
}
parsedCerts = append(parsedCerts, parsed)
}
if len(parsedCerts) == 0 {
data.WarningValidation = fmt.Sprintf("You have specified an empty certificate")
return errors.New(data.WarningValidation)
}
data.ValidCert = true
// spew.Dump(parsedCerts)
opts := x509.VerifyOptions{
DNSName: serverName,
}
log.Printf("number of certs - %d", len(parsedCerts))
if len(parsedCerts) > 1 {
// set up an intermediate
pool := x509.NewCertPool()
for _, cert := range parsedCerts[1:] {
log.Printf("got an intermediate cert")
pool.AddCert(cert)
}
opts.Intermediates = pool
}
// TODO: save it as a warning rather than error it out -- shouldn't be a big problem
mainCert := parsedCerts[0]
_, err := mainCert.Verify(opts)
if err != nil {
// let self-signed certs through
data.WarningValidation = fmt.Sprintf("Your certificate does not verify: %s", err)
} else {
data.ValidChain = true
}
// spew.Dump(chains)
// update status
if mainCert != nil {
notAfter := mainCert.NotAfter
data.Subject = mainCert.Subject.String()
data.Issuer = mainCert.Issuer.String()
data.NotAfter = notAfter
data.NotBefore = mainCert.NotBefore
data.DNSNames = mainCert.DNSNames
}
return nil
}
func validatePkey(data *tlsConfigStatus, pkey string) error {
// now do a more extended validation
var key *pem.Block // PEM-encoded certificates
var skippedBytes []string // skipped bytes
// go through all pem blocks, but take first valid pem block and drop the rest
pemblock := []byte(pkey)
for {
var decoded *pem.Block
decoded, pemblock = pem.Decode(pemblock)
if decoded == nil {
break
}
if decoded.Type == "PRIVATE KEY" || strings.HasSuffix(decoded.Type, " PRIVATE KEY") {
key = decoded
break
} else {
skippedBytes = append(skippedBytes, decoded.Type)
}
}
if key == nil {
data.WarningValidation = "No valid keys were found"
return errors.New(data.WarningValidation)
}
// parse the decoded key
_, keytype, err := parsePrivateKey(key.Bytes)
if err != nil {
data.WarningValidation = fmt.Sprintf("Failed to parse private key: %s", err)
return errors.New(data.WarningValidation)
}
data.ValidKey = true
data.KeyType = keytype
return nil
}
// Process certificate data and its private key.
// All parameters are optional.
// On error, return partially set object
// with 'WarningValidation' field containing error description.
func validateCertificates(certChain, pkey, serverName string) tlsConfigStatus {
var data tlsConfigStatus
// check only public certificate separately from the key
if certChain != "" {
if verifyCertChain(&data, certChain, serverName) != nil {
return data
}
}
// validate private key (right now the only validation possible is just parsing it)
if pkey != "" {
if validatePkey(&data, pkey) != nil {
return data
}
}
// if both are set, validate both in unison
if pkey != "" && certChain != "" {
_, err := tls.X509KeyPair([]byte(certChain), []byte(pkey))
if err != nil {
data.WarningValidation = fmt.Sprintf("Invalid certificate or key: %s", err)
return data
}
data.ValidPair = true
}
return data
}
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys.
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
return key, "RSA", nil
}
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
switch key := key.(type) {
case *rsa.PrivateKey:
return key, "RSA", nil
case *ecdsa.PrivateKey:
return key, "ECDSA", nil
default:
return nil, "", errors.New("tls: found unknown private key type in PKCS#8 wrapping")
}
}
if key, err := x509.ParseECPrivateKey(der); err == nil {
return key, "ECDSA", nil
}
return nil, "", errors.New("tls: failed to parse private key")
}
// unmarshalTLS handles base64-encoded certificates transparently
func unmarshalTLS(r *http.Request) (tlsConfig, error) {
data := tlsConfig{}
err := json.NewDecoder(r.Body).Decode(&data)
if err != nil {
return data, errorx.Decorate(err, "Failed to parse new TLS config json")
}
if data.CertificateChain != "" {
certPEM, err := base64.StdEncoding.DecodeString(data.CertificateChain)
if err != nil {
return data, errorx.Decorate(err, "Failed to base64-decode certificate chain")
}
data.CertificateChain = string(certPEM)
}
if data.PrivateKey != "" {
keyPEM, err := base64.StdEncoding.DecodeString(data.PrivateKey)
if err != nil {
return data, errorx.Decorate(err, "Failed to base64-decode private key")
}
data.PrivateKey = string(keyPEM)
}
return data, nil
}
func marshalTLS(w http.ResponseWriter, data tlsConfig) {
w.Header().Set("Content-Type", "application/json")
if data.CertificateChain != "" {
encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain))
data.CertificateChain = encoded
}
if data.PrivateKey != "" {
encoded := base64.StdEncoding.EncodeToString([]byte(data.PrivateKey))
data.PrivateKey = encoded
}
err := json.NewEncoder(w).Encode(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Failed to marshal json with TLS status: %s", err)
return
}
}

511
home/control_update.go Normal file
View File

@@ -0,0 +1,511 @@
package home
import (
"archive/tar"
"archive/zip"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
"github.com/AdguardTeam/golibs/log"
)
// Convert version.json data to our JSON response
func getVersionResp(data []byte) []byte {
versionJSON := make(map[string]interface{})
err := json.Unmarshal(data, &versionJSON)
if err != nil {
log.Error("version.json: %s", err)
return []byte{}
}
ret := make(map[string]interface{})
ret["can_autoupdate"] = false
var ok1, ok2, ok3 bool
ret["new_version"], ok1 = versionJSON["version"].(string)
ret["announcement"], ok2 = versionJSON["announcement"].(string)
ret["announcement_url"], ok3 = versionJSON["announcement_url"].(string)
selfUpdateMinVersion, ok4 := versionJSON["selfupdate_min_version"].(string)
if !ok1 || !ok2 || !ok3 || !ok4 {
log.Error("version.json: invalid data")
return []byte{}
}
_, ok := versionJSON[fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH)]
if ok && ret["new_version"] != VersionString && VersionString >= selfUpdateMinVersion {
ret["can_autoupdate"] = true
}
d, _ := json.Marshal(ret)
return d
}
// Get the latest available version from the Internet
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
if config.disableUpdate {
log.Tracef("New app version check is disabled by user")
return
}
now := time.Now()
controlLock.Lock()
cached := now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0
data := versionCheckJSON
controlLock.Unlock()
if cached {
// return cached copy
w.Header().Set("Content-Type", "application/json")
w.Write(getVersionResp(data))
return
}
resp, err := client.Get(versionCheckURL)
if err != nil {
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
return
}
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
// read the body entirely
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
httpError(w, http.StatusBadGateway, "Couldn't read response body from %s: %s", versionCheckURL, err)
return
}
controlLock.Lock()
versionCheckLastTime = now
versionCheckJSON = body
controlLock.Unlock()
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(getVersionResp(body))
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
// Copy file on disk
func copyFile(src, dst string) error {
d, e := ioutil.ReadFile(src)
if e != nil {
return e
}
e = ioutil.WriteFile(dst, d, 0644)
if e != nil {
return e
}
return nil
}
type updateInfo struct {
pkgURL string // URL for the new package
pkgName string // Full path to package file
newVer string // New version string
updateDir string // Full path to the directory containing unpacked files from the new package
backupDir string // Full path to backup directory
configName string // Full path to the current configuration file
updateConfigName string // Full path to the configuration file to check by the new binary
curBinName string // Full path to the current executable file
bkpBinName string // Full path to the current executable file in backup directory
newBinName string // Full path to the new executable file
}
// Fill in updateInfo object
func getUpdateInfo(jsonData []byte) (*updateInfo, error) {
var u updateInfo
workDir := config.ourWorkingDir
versionJSON := make(map[string]interface{})
err := json.Unmarshal(jsonData, &versionJSON)
if err != nil {
return nil, fmt.Errorf("JSON parse: %s", err)
}
u.pkgURL = versionJSON[fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH)].(string)
u.newVer = versionJSON["version"].(string)
if len(u.pkgURL) == 0 || len(u.newVer) == 0 {
return nil, fmt.Errorf("Invalid JSON")
}
if u.newVer == VersionString {
return nil, fmt.Errorf("No need to update")
}
u.updateDir = filepath.Join(workDir, fmt.Sprintf("agh-update-%s", u.newVer))
u.backupDir = filepath.Join(workDir, fmt.Sprintf("agh-backup-%s", VersionString))
_, pkgFileName := filepath.Split(u.pkgURL)
if len(pkgFileName) == 0 {
return nil, fmt.Errorf("Invalid JSON")
}
u.pkgName = filepath.Join(u.updateDir, pkgFileName)
u.configName = config.getConfigFilename()
u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome", "AdGuardHome.yaml")
if strings.HasSuffix(pkgFileName, ".zip") {
u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome.yaml")
}
binName := "AdGuardHome"
if runtime.GOOS == "windows" {
binName = "AdGuardHome.exe"
}
u.curBinName = filepath.Join(workDir, binName)
u.bkpBinName = filepath.Join(u.backupDir, binName)
u.newBinName = filepath.Join(u.updateDir, "AdGuardHome", binName)
if strings.HasSuffix(pkgFileName, ".zip") {
u.newBinName = filepath.Join(u.updateDir, binName)
}
return &u, nil
}
// Unpack all files from .zip file to the specified directory
// Existing files are overwritten
// Return the list of files (not directories) written
func zipFileUnpack(zipfile, outdir string) ([]string, error) {
r, err := zip.OpenReader(zipfile)
if err != nil {
return nil, fmt.Errorf("zip.OpenReader(): %s", err)
}
defer r.Close()
var files []string
var err2 error
var zr io.ReadCloser
for _, zf := range r.File {
zr, err = zf.Open()
if err != nil {
err2 = fmt.Errorf("zip file Open(): %s", err)
break
}
fi := zf.FileInfo()
if len(fi.Name()) == 0 {
continue
}
fn := filepath.Join(outdir, fi.Name())
if fi.IsDir() {
err = os.Mkdir(fn, fi.Mode())
if err != nil && !os.IsExist(err) {
err2 = fmt.Errorf("os.Mkdir(): %s", err)
break
}
log.Tracef("created directory %s", fn)
continue
}
f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
if err != nil {
err2 = fmt.Errorf("os.OpenFile(): %s", err)
break
}
_, err = io.Copy(f, zr)
if err != nil {
f.Close()
err2 = fmt.Errorf("io.Copy(): %s", err)
break
}
f.Close()
log.Tracef("created file %s", fn)
files = append(files, fi.Name())
}
zr.Close()
return files, err2
}
// Unpack all files from .tar.gz file to the specified directory
// Existing files are overwritten
// Return the list of files (not directories) written
func targzFileUnpack(tarfile, outdir string) ([]string, error) {
f, err := os.Open(tarfile)
if err != nil {
return nil, fmt.Errorf("os.Open(): %s", err)
}
defer f.Close()
gzReader, err := gzip.NewReader(f)
if err != nil {
return nil, fmt.Errorf("gzip.NewReader(): %s", err)
}
var files []string
var err2 error
tarReader := tar.NewReader(gzReader)
for {
header, err := tarReader.Next()
if err == io.EOF {
err2 = nil
break
}
if err != nil {
err2 = fmt.Errorf("tarReader.Next(): %s", err)
break
}
if len(header.Name) == 0 {
continue
}
fn := filepath.Join(outdir, header.Name)
if header.Typeflag == tar.TypeDir {
err = os.Mkdir(fn, os.FileMode(header.Mode&0777))
if err != nil && !os.IsExist(err) {
err2 = fmt.Errorf("os.Mkdir(%s): %s", fn, err)
break
}
log.Tracef("created directory %s", fn)
continue
} else if header.Typeflag != tar.TypeReg {
log.Tracef("%s: unknown file type %d, skipping", header.Name, header.Typeflag)
continue
}
f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode&0777))
if err != nil {
err2 = fmt.Errorf("os.OpenFile(%s): %s", fn, err)
break
}
_, err = io.Copy(f, tarReader)
if err != nil {
f.Close()
err2 = fmt.Errorf("io.Copy(): %s", err)
break
}
f.Close()
log.Tracef("created file %s", fn)
files = append(files, header.Name)
}
gzReader.Close()
return files, err2
}
func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly, useDstNameOnly bool) error {
for _, f := range files {
_, name := filepath.Split(f)
if name == "AdGuardHome" || name == "AdGuardHome.exe" || name == "AdGuardHome.yaml" {
continue
}
src := filepath.Join(srcdir, f)
if useSrcNameOnly {
src = filepath.Join(srcdir, name)
}
dst := filepath.Join(dstdir, f)
if useDstNameOnly {
dst = filepath.Join(dstdir, name)
}
err := copyFile(src, dst)
if err != nil && !os.IsNotExist(err) {
return err
}
log.Tracef("Copied: %s -> %s", src, dst)
}
return nil
}
// Download package file and save it to disk
func getPackageFile(u *updateInfo) error {
resp, err := client.Get(u.pkgURL)
if err != nil {
return fmt.Errorf("HTTP request failed: %s", err)
}
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
log.Tracef("Reading HTTP body")
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("ioutil.ReadAll() failed: %s", err)
}
log.Tracef("Saving package to file")
err = ioutil.WriteFile(u.pkgName, body, 0644)
if err != nil {
return fmt.Errorf("ioutil.WriteFile() failed: %s", err)
}
return nil
}
// Perform an update procedure
func doUpdate(u *updateInfo) error {
log.Info("Updating from %s to %s. URL:%s Package:%s",
VersionString, u.newVer, u.pkgURL, u.pkgName)
_ = os.Mkdir(u.updateDir, 0755)
var err error
err = getPackageFile(u)
if err != nil {
return err
}
log.Tracef("Unpacking the package")
_, file := filepath.Split(u.pkgName)
var files []string
if strings.HasSuffix(file, ".zip") {
files, err = zipFileUnpack(u.pkgName, u.updateDir)
if err != nil {
return fmt.Errorf("zipFileUnpack() failed: %s", err)
}
} else if strings.HasSuffix(file, ".tar.gz") {
files, err = targzFileUnpack(u.pkgName, u.updateDir)
if err != nil {
return fmt.Errorf("targzFileUnpack() failed: %s", err)
}
} else {
return fmt.Errorf("Unknown package extension")
}
log.Tracef("Checking configuration")
err = copyFile(u.configName, u.updateConfigName)
if err != nil {
return fmt.Errorf("copyFile() failed: %s", err)
}
cmd := exec.Command(u.newBinName, "--check-config")
err = cmd.Run()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode())
}
log.Tracef("Backing up the current configuration")
_ = os.Mkdir(u.backupDir, 0755)
err = copyFile(u.configName, filepath.Join(u.backupDir, "AdGuardHome.yaml"))
if err != nil {
return fmt.Errorf("copyFile() failed: %s", err)
}
// ./README.md -> backup/README.md
err = copySupportingFiles(files, config.ourWorkingDir, u.backupDir, true, true)
if err != nil {
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s",
config.ourWorkingDir, u.backupDir, err)
}
// update/[AdGuardHome/]README.md -> ./README.md
err = copySupportingFiles(files, u.updateDir, config.ourWorkingDir, false, true)
if err != nil {
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s",
u.updateDir, config.ourWorkingDir, err)
}
log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName)
err = os.Rename(u.curBinName, u.bkpBinName)
if err != nil {
return err
}
if runtime.GOOS == "windows" {
// rename fails with "File in use" error
err = copyFile(u.newBinName, u.curBinName)
} else {
err = os.Rename(u.newBinName, u.curBinName)
}
if err != nil {
return err
}
log.Tracef("Renamed: %s -> %s", u.newBinName, u.curBinName)
_ = os.Remove(u.pkgName)
_ = os.RemoveAll(u.updateDir)
return nil
}
// Complete an update procedure
func finishUpdate(u *updateInfo) {
log.Info("Stopping all tasks")
cleanup()
stopHTTPServer()
cleanupAlways()
if runtime.GOOS == "windows" {
if config.runningAsService {
// Note:
// we can't restart the service via "kardianos/service" package - it kills the process first
// we can't start a new instance - Windows doesn't allow it
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
err := cmd.Start()
if err != nil {
log.Fatalf("exec.Command() failed: %s", err)
}
os.Exit(0)
}
cmd := exec.Command(u.curBinName, os.Args[1:]...)
log.Info("Restarting: %v", cmd.Args)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
if err != nil {
log.Fatalf("exec.Command() failed: %s", err)
}
os.Exit(0)
} else {
log.Info("Restarting: %v", os.Args)
err := syscall.Exec(u.curBinName, os.Args, os.Environ())
if err != nil {
log.Fatalf("syscall.Exec() failed: %s", err)
}
// Unreachable code
}
}
// Perform an update procedure to the latest available version
func handleUpdate(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
if len(versionCheckJSON) == 0 {
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
return
}
u, err := getUpdateInfo(versionCheckJSON)
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return
}
err = doUpdate(u)
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return
}
returnOK(w)
time.Sleep(time.Second) // wait (hopefully) until response is sent (not sure whether it's really necessary)
go finishUpdate(u)
}

View File

@@ -0,0 +1,54 @@
// +build ignore
package home
import (
"os"
"testing"
)
func TestDoUpdate(t *testing.T) {
config.DNS.Port = 0
config.ourWorkingDir = "."
u := updateInfo{
pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.95/AdGuardHome_v0.95_linux_amd64.tar.gz",
pkgName: "./AdGuardHome_v0.95_linux_amd64.tar.gz",
newVer: "v0.95",
updateDir: "./agh-update-v0.95",
backupDir: "./agh-backup-v0.94",
configName: "./AdGuardHome.yaml",
updateConfigName: "./agh-update-v0.95/AdGuardHome/AdGuardHome.yaml",
curBinName: "./AdGuardHome",
bkpBinName: "./agh-backup-v0.94/AdGuardHome",
newBinName: "./agh-update-v0.95/AdGuardHome/AdGuardHome",
}
e := doUpdate(&u)
if e != nil {
t.Fatalf("FAILED: %s", e)
}
os.RemoveAll(u.backupDir)
}
func TestTargzFileUnpack(t *testing.T) {
fn := "./dist/AdGuardHome_v0.95_linux_amd64.tar.gz"
outdir := "./test-unpack"
_ = os.Mkdir(outdir, 0755)
files, e := targzFileUnpack(fn, outdir)
if e != nil {
t.Fatalf("FAILED: %s", e)
}
t.Logf("%v", files)
os.RemoveAll(outdir)
}
func TestZipFileUnpack(t *testing.T) {
fn := "./dist/AdGuardHome_v0.95_Windows_amd64.zip"
outdir := "./test-unpack"
_ = os.Mkdir(outdir, 0755)
files, e := zipFileUnpack(fn, outdir)
if e != nil {
t.Fatalf("FAILED: %s", e)
}
t.Logf("%v", files)
os.RemoveAll(outdir)
}

460
home/dhcp.go Normal file
View File

@@ -0,0 +1,460 @@
package home
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"os/exec"
"runtime"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
var dhcpServer = dhcpd.Server{}
// []dhcpd.Lease -> JSON
func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]string {
leases := []map[string]string{}
for _, l := range inputLeases {
lease := map[string]string{
"mac": l.HWAddr.String(),
"ip": l.IP.String(),
"hostname": l.Hostname,
}
if includeExpires {
lease["expires"] = l.Expiry.Format(time.RFC3339)
}
leases = append(leases, lease)
}
return leases
}
func handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
leases := convertLeases(dhcpServer.Leases(), true)
staticLeases := convertLeases(dhcpServer.StaticLeases(), false)
status := map[string]interface{}{
"config": config.DHCP,
"leases": leases,
"static_leases": staticLeases,
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(status)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal DHCP status json: %s", err)
return
}
}
type leaseJSON struct {
HWAddr string `json:"mac"`
IP string `json:"ip"`
Hostname string `json:"hostname"`
}
type dhcpServerConfigJSON struct {
dhcpd.ServerConfig `json:",inline"`
StaticLeases []leaseJSON `json:"static_leases"`
}
func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
newconfig := dhcpServerConfigJSON{}
err := json.NewDecoder(r.Body).Decode(&newconfig)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse new DHCP config json: %s", err)
return
}
err = dhcpServer.CheckConfig(newconfig.ServerConfig)
if err != nil {
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
return
}
err = dhcpServer.Stop()
if err != nil {
log.Error("failed to stop the DHCP server: %s", err)
}
err = dhcpServer.Init(newconfig.ServerConfig)
if err != nil {
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
return
}
if newconfig.Enabled {
staticIP, err := hasStaticIP(newconfig.InterfaceName)
if !staticIP && err == nil {
err = setStaticIP(newconfig.InterfaceName)
if err != nil {
httpError(w, http.StatusInternalServerError, "Failed to configure static IP: %s", err)
return
}
}
err = dhcpServer.Start()
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to start DHCP server: %s", err)
return
}
}
config.DHCP = newconfig.ServerConfig
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
response := map[string]interface{}{}
ifaces, err := getValidNetInterfaces()
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return
}
for _, iface := range ifaces {
if iface.Flags&net.FlagLoopback != 0 {
// it's a loopback, skip it
continue
}
if iface.Flags&net.FlagBroadcast == 0 {
// this interface doesn't support broadcast, skip it
continue
}
addrs, err := iface.Addrs()
if err != nil {
httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err)
return
}
jsonIface := netInterface{
Name: iface.Name,
MTU: iface.MTU,
HardwareAddr: iface.HardwareAddr.String(),
}
if iface.Flags != 0 {
jsonIface.Flags = iface.Flags.String()
}
// we don't want link-local addresses in json, so skip them
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
// not an IPNet, should not happen
httpError(w, http.StatusInternalServerError, "SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
return
}
// ignore link-local
if ipnet.IP.IsLinkLocalUnicast() {
continue
}
jsonIface.Addresses = append(jsonIface.Addresses, ipnet.IP.String())
}
if len(jsonIface.Addresses) != 0 {
response[iface.Name] = jsonIface
}
}
err = json.NewEncoder(w).Encode(response)
if err != nil {
httpError(w, http.StatusInternalServerError, "Failed to marshal json with available interfaces: %s", err)
return
}
}
// Perform the following tasks:
// . Search for another DHCP server running
// . Check if a static IP is configured for the network interface
// Respond with results
func handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errorText := fmt.Sprintf("failed to read request body: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
interfaceName := strings.TrimSpace(string(body))
if interfaceName == "" {
errorText := fmt.Sprintf("empty interface name specified")
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
found, err := dhcpd.CheckIfOtherDHCPServersPresent(interfaceName)
othSrv := map[string]interface{}{}
foundVal := "no"
if found {
foundVal = "yes"
} else if err != nil {
foundVal = "error"
othSrv["error"] = err.Error()
}
othSrv["found"] = foundVal
staticIP := map[string]interface{}{}
isStaticIP, err := hasStaticIP(interfaceName)
staticIPStatus := "yes"
if err != nil {
staticIPStatus = "error"
staticIP["error"] = err.Error()
} else if !isStaticIP {
staticIPStatus = "no"
staticIP["ip"] = getFullIP(interfaceName)
}
staticIP["static"] = staticIPStatus
result := map[string]interface{}{}
result["other_server"] = othSrv
result["static_ip"] = staticIP
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(result)
if err != nil {
httpError(w, http.StatusInternalServerError, "Failed to marshal DHCP found json: %s", err)
return
}
}
// Check if network interface has a static IP configured
func hasStaticIP(ifaceName string) (bool, error) {
if runtime.GOOS == "windows" {
return false, errors.New("Can't detect static IP: not supported on Windows")
}
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
if err != nil {
return false, err
}
lines := strings.Split(string(body), "\n")
nameLine := fmt.Sprintf("interface %s", ifaceName)
withinInterfaceCtx := false
for _, line := range lines {
line = strings.TrimSpace(line)
if withinInterfaceCtx && len(line) == 0 {
// an empty line resets our state
withinInterfaceCtx = false
}
if len(line) == 0 || line[0] == '#' {
continue
}
line = strings.TrimSpace(line)
if !withinInterfaceCtx {
if line == nameLine {
// we found our interface
withinInterfaceCtx = true
}
} else {
if strings.HasPrefix(line, "interface ") {
// we found another interface - reset our state
withinInterfaceCtx = false
continue
}
if strings.HasPrefix(line, "static ip_address=") {
return true, nil
}
}
}
return false, nil
}
// Get IP address with netmask
func getFullIP(ifaceName string) string {
cmd := exec.Command("ip", "-oneline", "-family", "inet", "address", "show", ifaceName)
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
d, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
return ""
}
fields := strings.Fields(string(d))
if len(fields) < 4 {
return ""
}
_, _, err = net.ParseCIDR(fields[3])
if err != nil {
return ""
}
return fields[3]
}
// Get gateway IP address
func getGatewayIP(ifaceName string) string {
cmd := exec.Command("ip", "route", "show", "dev", ifaceName)
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
d, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
return ""
}
fields := strings.Fields(string(d))
if len(fields) < 3 || fields[0] != "default" {
return ""
}
ip := net.ParseIP(fields[2])
if ip == nil {
return ""
}
return fields[2]
}
// Set a static IP for network interface
func setStaticIP(ifaceName string) error {
ip := getFullIP(ifaceName)
if len(ip) == 0 {
return errors.New("Can't get IP address")
}
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
if err != nil {
return err
}
ip4, _, err := net.ParseCIDR(ip)
if err != nil {
return err
}
add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n",
ifaceName, ip)
body = append(body, []byte(add)...)
gatewayIP := getGatewayIP(ifaceName)
if len(gatewayIP) != 0 {
add = fmt.Sprintf("static routers=%s\n",
gatewayIP)
body = append(body, []byte(add)...)
}
add = fmt.Sprintf("static domain_name_servers=%s\n\n",
ip4)
body = append(body, []byte(add)...)
err = file.SafeWrite("/etc/dhcpcd.conf", body)
if err != nil {
return err
}
return nil
}
func handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
lj := leaseJSON{}
err := json.NewDecoder(r.Body).Decode(&lj)
if err != nil {
httpError(w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ip := parseIPv4(lj.IP)
if ip == nil {
httpError(w, http.StatusBadRequest, "invalid IP")
return
}
mac, _ := net.ParseMAC(lj.HWAddr)
lease := dhcpd.Lease{
IP: ip,
HWAddr: mac,
Hostname: lj.Hostname,
}
err = dhcpServer.AddStaticLease(lease)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
returnOK(w)
}
func handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
lj := leaseJSON{}
err := json.NewDecoder(r.Body).Decode(&lj)
if err != nil {
httpError(w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ip := parseIPv4(lj.IP)
if ip == nil {
httpError(w, http.StatusBadRequest, "invalid IP")
return
}
mac, _ := net.ParseMAC(lj.HWAddr)
lease := dhcpd.Lease{
IP: ip,
HWAddr: mac,
Hostname: lj.Hostname,
}
err = dhcpServer.RemoveStaticLease(lease)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
returnOK(w)
}
func startDHCPServer() error {
if !config.DHCP.Enabled {
// not enabled, don't do anything
return nil
}
err := dhcpServer.Init(config.DHCP)
if err != nil {
return errorx.Decorate(err, "Couldn't init DHCP server")
}
err = dhcpServer.Start()
if err != nil {
return errorx.Decorate(err, "Couldn't start DHCP server")
}
return nil
}
func stopDHCPServer() error {
if !config.DHCP.Enabled {
return nil
}
err := dhcpServer.Stop()
if err != nil {
return errorx.Decorate(err, "Couldn't stop DHCP server")
}
return nil
}

274
home/dns.go Normal file
View File

@@ -0,0 +1,274 @@
package home
import (
"fmt"
"net"
"os"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
"github.com/miekg/dns"
)
var dnsServer *dnsforward.Server
const (
rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
)
type dnsContext struct {
rdnsChannel chan string // pass data from DNS request handling thread to rDNS thread
// contains IP addresses of clients to be resolved by rDNS
// if IP address couldn't be resolved, it stays here forever to prevent further attempts to resolve the same IP
rdnsIP map[string]bool
rdnsLock sync.Mutex // synchronize access to rdnsIP
upstream upstream.Upstream // Upstream object for our own DNS server
}
var dnsctx dnsContext
// initDNSServer creates an instance of the dnsforward.Server
// Please note that we must do it even if we don't start it
// so that we had access to the query log and the stats
func initDNSServer(baseDir string) {
err := os.MkdirAll(baseDir, 0755)
if err != nil {
log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err)
}
dnsServer = dnsforward.NewServer(baseDir)
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
bindhost = "127.0.0.1"
}
resolverAddress := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
opts := upstream.Options{
Timeout: rdnsTimeout,
}
dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
if err != nil {
log.Error("upstream.AddressToUpstream: %s", err)
return
}
dnsctx.rdnsIP = make(map[string]bool)
dnsctx.rdnsChannel = make(chan string, 256)
go asyncRDNSLoop()
}
func isRunning() bool {
return dnsServer != nil && dnsServer.IsRunning()
}
func beginAsyncRDNS(ip string) {
if clientExists(ip) {
return
}
// add IP to rdnsIP, if not exists
dnsctx.rdnsLock.Lock()
defer dnsctx.rdnsLock.Unlock()
_, ok := dnsctx.rdnsIP[ip]
if ok {
return
}
dnsctx.rdnsIP[ip] = true
log.Tracef("Adding %s for rDNS resolve", ip)
select {
case dnsctx.rdnsChannel <- ip:
//
default:
log.Tracef("rDNS queue is full")
}
}
// Use rDNS to get hostname by IP address
func resolveRDNS(ip string) string {
log.Tracef("Resolving host for %s", ip)
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
}
var err error
req.Question[0].Name, err = dns.ReverseAddr(ip)
if err != nil {
log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err)
return ""
}
resp, err := dnsctx.upstream.Exchange(&req)
if err != nil {
log.Error("Error while making an rDNS lookup for %s: %s", ip, err)
return ""
}
if len(resp.Answer) != 1 {
log.Debug("No answer for rDNS lookup of %s", ip)
return ""
}
ptr, ok := resp.Answer[0].(*dns.PTR)
if !ok {
log.Error("not a PTR response for %s", ip)
return ""
}
log.Tracef("PTR response for %s: %s", ip, ptr.String())
if strings.HasSuffix(ptr.Ptr, ".") {
ptr.Ptr = ptr.Ptr[:len(ptr.Ptr)-1]
}
return ptr.Ptr
}
// Wait for a signal and then synchronously resolve hostname by IP address
// Add the hostname:IP pair to "Clients" array
func asyncRDNSLoop() {
for {
var ip string
ip = <-dnsctx.rdnsChannel
host := resolveRDNS(ip)
if len(host) == 0 {
continue
}
dnsctx.rdnsLock.Lock()
delete(dnsctx.rdnsIP, ip)
dnsctx.rdnsLock.Unlock()
_, _ = clientAddHost(ip, host, ClientSourceRDNS)
}
}
func onDNSRequest(d *proxy.DNSContext) {
qType := d.Req.Question[0].Qtype
if qType != dns.TypeA && qType != dns.TypeAAAA {
return
}
ip := dnsforward.GetIPString(d.Addr)
if ip == "" {
// This would be quite weird if we get here
return
}
beginAsyncRDNS(ip)
}
func generateServerConfig() dnsforward.ServerConfig {
filters := []dnsfilter.Filter{}
userFilter := userFilter()
filters = append(filters, dnsfilter.Filter{
ID: userFilter.ID,
Data: userFilter.Data,
})
for _, filter := range config.Filters {
filters = append(filters, dnsfilter.Filter{
ID: filter.ID,
Data: filter.Data,
})
}
newconfig := dnsforward.ServerConfig{
UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
FilteringConfig: config.DNS.FilteringConfig,
Filters: filters,
}
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
bindhost = "127.0.0.1"
}
newconfig.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
if config.TLS.Enabled {
newconfig.TLSConfig = config.TLS.TLSConfig
if config.TLS.PortDNSOverTLS != 0 {
newconfig.TLSListenAddr = &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.TLS.PortDNSOverTLS}
}
}
upstreamConfig, err := proxy.ParseUpstreamsConfig(config.DNS.UpstreamDNS, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
if err != nil {
log.Error("Couldn't get upstreams configuration cause: %s", err)
}
newconfig.Upstreams = upstreamConfig.Upstreams
newconfig.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams
newconfig.AllServers = config.DNS.AllServers
newconfig.FilterHandler = applyClientSettings
newconfig.OnDNSRequest = onDNSRequest
return newconfig
}
// If a client has his own settings, apply them
func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
c, ok := clientFind(clientAddr)
if !ok || !c.UseOwnSettings {
return
}
log.Debug("Using settings for client with IP %s", clientAddr)
setts.FilteringEnabled = c.FilteringEnabled
setts.SafeSearchEnabled = c.SafeSearchEnabled
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
setts.ParentalEnabled = c.ParentalEnabled
}
func startDNSServer() error {
if isRunning() {
return fmt.Errorf("unable to start forwarding DNS server: Already running")
}
newconfig := generateServerConfig()
err := dnsServer.Start(&newconfig)
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
top := dnsServer.GetStatsTop()
for k := range top.Clients {
beginAsyncRDNS(k)
}
return nil
}
func reconfigureDNSServer() error {
if !isRunning() {
return fmt.Errorf("Refusing to reconfigure forwarding DNS server: not running")
}
config := generateServerConfig()
err := dnsServer.Reconfigure(&config)
if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
}
return nil
}
func stopDNSServer() error {
if !isRunning() {
return fmt.Errorf("Refusing to stop forwarding DNS server: not running")
}
err := dnsServer.Stop()
if err != nil {
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
}
return nil
}

13
home/dns_test.go Normal file
View File

@@ -0,0 +1,13 @@
package home
import (
"testing"
)
func TestResolveRDNS(t *testing.T) {
config.DNS.BindHost = "1.1.1.1"
initDNSServer(".")
if r := resolveRDNS("1.1.1.1"); r != "one.one.one.one" {
t.Errorf("resolveRDNS(): %s", r)
}
}

406
home/filter.go Normal file
View File

@@ -0,0 +1,406 @@
package home
import (
"fmt"
"hash/crc32"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
)
var (
nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
)
// field ordering is important -- yaml fields will mirror ordering from here
type filter struct {
Enabled bool `json:"enabled"`
URL string `json:"url"`
Name string `json:"name" yaml:"name"`
RulesCount int `json:"rulesCount" yaml:"-"`
LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"-"`
checksum uint32 // checksum of the file data
dnsfilter.Filter `yaml:",inline"`
}
// Creates a helper object for working with the user rules
func userFilter() filter {
f := filter{
// User filter always has constant ID=0
Enabled: true,
}
f.Filter.Data = []byte(strings.Join(config.UserRules, "\n"))
return f
}
// Enable or disable a filter
func filterEnable(url string, enable bool) bool {
r := false
config.Lock()
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we will be operating on a copy
if filter.URL == url {
filter.Enabled = enable
if enable {
e := filter.load()
if e != nil {
// This isn't a fatal error,
// because it may occur when someone removes the file from disk.
// In this case the periodic update task will try to download the file.
filter.LastUpdated = time.Time{}
log.Tracef("%s filter load: %v", url, e)
}
} else {
filter.unload()
}
r = true
break
}
}
config.Unlock()
return r
}
// Return TRUE if a filter with this URL exists
func filterExists(url string) bool {
r := false
config.RLock()
for i := range config.Filters {
if config.Filters[i].URL == url {
r = true
break
}
}
config.RUnlock()
return r
}
// Add a filter
// Return FALSE if a filter with this URL exists
func filterAdd(f filter) bool {
config.Lock()
// Check for duplicates
for i := range config.Filters {
if config.Filters[i].URL == f.URL {
config.Unlock()
return false
}
}
config.Filters = append(config.Filters, f)
config.Unlock()
return true
}
// Load filters from the disk
// And if any filter has zero ID, assign a new one
func loadFilters() {
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we're operating on a copy
if filter.ID == 0 {
filter.ID = assignUniqueFilterID()
}
if !filter.Enabled {
// No need to load a filter that is not enabled
continue
}
err := filter.load()
if err != nil {
// This is okay for the first start, the filter will be loaded later
log.Debug("Couldn't load filter %d contents due to %s", filter.ID, err)
}
}
}
func deduplicateFilters() {
// Deduplicate filters
i := 0 // output index, used for deletion later
urls := map[string]bool{}
for _, filter := range config.Filters {
if _, ok := urls[filter.URL]; !ok {
// we didn't see it before, keep it
urls[filter.URL] = true // remember the URL
config.Filters[i] = filter
i++
}
}
// all entries we want to keep are at front, delete the rest
config.Filters = config.Filters[:i]
}
// Set the next filter ID to max(filter.ID) + 1
func updateUniqueFilterID(filters []filter) {
for _, filter := range filters {
if nextFilterID < filter.ID {
nextFilterID = filter.ID + 1
}
}
}
func assignUniqueFilterID() int64 {
value := nextFilterID
nextFilterID++
return value
}
// Sets up a timer that will be checking for filters updates periodically
func periodicallyRefreshFilters() {
for range time.Tick(time.Minute) {
refreshFiltersIfNecessary(false)
}
}
// Checks filters updates if necessary
// If force is true, it ignores the filter.LastUpdated field value
//
// Algorithm:
// . Get the list of filters to be updated
// . For each filter run the download and checksum check operation
// . If filter data hasn't changed, set new update time
// . If filter data has changed, parse it, save it on disk, set new update time
// . Apply changes to the current configuration
// . Restart server
func refreshFiltersIfNecessary(force bool) int {
var updateFilters []filter
if config.firstRun {
return 0
}
config.RLock()
for i := range config.Filters {
f := &config.Filters[i] // otherwise we will be operating on a copy
if !f.Enabled {
continue
}
if !force && time.Since(f.LastUpdated) <= updatePeriod {
continue
}
var uf filter
uf.ID = f.ID
uf.URL = f.URL
uf.Name = f.Name
uf.checksum = f.checksum
updateFilters = append(updateFilters, uf)
}
config.RUnlock()
updateCount := 0
for i := range updateFilters {
uf := &updateFilters[i]
updated, err := uf.update()
if err != nil {
log.Printf("Failed to update filter %s: %s\n", uf.URL, err)
continue
}
if updated {
// Saving it to the filters dir now
err = uf.save()
if err != nil {
log.Printf("Failed to save the updated filter %d: %s", uf.ID, err)
continue
}
} else {
mtime := time.Now()
e := os.Chtimes(uf.Path(), mtime, mtime)
if e != nil {
log.Error("os.Chtimes(): %v", e)
}
uf.LastUpdated = mtime
}
config.Lock()
for k := range config.Filters {
f := &config.Filters[k]
if f.ID != uf.ID || f.URL != uf.URL {
continue
}
f.LastUpdated = uf.LastUpdated
if !updated {
continue
}
log.Info("Updated filter #%d. Rules: %d -> %d",
f.ID, f.RulesCount, uf.RulesCount)
f.Name = uf.Name
f.Data = uf.Data
f.RulesCount = uf.RulesCount
f.checksum = uf.checksum
updateCount++
}
config.Unlock()
}
if updateCount > 0 && isRunning() {
err := reconfigureDNSServer()
if err != nil {
msg := fmt.Sprintf("SHOULD NOT HAPPEN: cannot reconfigure DNS server with the new filters: %s", err)
panic(msg)
}
}
return updateCount
}
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
func parseFilterContents(contents []byte) (int, string) {
lines := strings.Split(string(contents), "\n")
rulesCount := 0
name := ""
seenTitle := false
// Count lines in the filter
for _, line := range lines {
line = strings.TrimSpace(line)
if len(line) == 0 {
continue
}
if line[0] == '!' {
m := filterTitleRegexp.FindAllStringSubmatch(line, -1)
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
name = m[0][1]
seenTitle = true
}
} else {
rulesCount++
}
}
return rulesCount, name
}
// Perform upgrade on a filter
func (filter *filter) update() (bool, error) {
log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
resp, err := client.Get(filter.URL)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil {
log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err)
return false, err
}
if resp.StatusCode != 200 {
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
}
contentType := strings.ToLower(resp.Header.Get("content-type"))
if !strings.HasPrefix(contentType, "text/plain") {
log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL)
return false, fmt.Errorf("non-text response %s", contentType)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
return false, err
}
// Check if the filter has been really changed
checksum := crc32.ChecksumIEEE(body)
if filter.checksum == checksum {
log.Tracef("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL)
return false, nil
}
// Extract filter name and count number of rules
rulesCount, filterName := parseFilterContents(body)
log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount)
if filterName != "" {
filter.Name = filterName
}
filter.RulesCount = rulesCount
filter.Data = body
filter.checksum = checksum
return true, nil
}
// saves filter contents to the file in dataDir
func (filter *filter) save() error {
filterFilePath := filter.Path()
log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
err := file.SafeWrite(filterFilePath, filter.Data)
// update LastUpdated field after saving the file
filter.LastUpdated = filter.LastTimeUpdated()
return err
}
// loads filter contents from the file in dataDir
func (filter *filter) load() error {
filterFilePath := filter.Path()
log.Tracef("Loading filter %d contents to: %s", filter.ID, filterFilePath)
if _, err := os.Stat(filterFilePath); os.IsNotExist(err) {
// do nothing, file doesn't exist
return err
}
filterFileContents, err := ioutil.ReadFile(filterFilePath)
if err != nil {
return err
}
log.Tracef("File %s, id %d, length %d", filterFilePath, filter.ID, len(filterFileContents))
rulesCount, _ := parseFilterContents(filterFileContents)
filter.RulesCount = rulesCount
filter.Data = filterFileContents
filter.checksum = crc32.ChecksumIEEE(filterFileContents)
filter.LastUpdated = filter.LastTimeUpdated()
return nil
}
// Clear filter rules
func (filter *filter) unload() {
filter.Data = nil
filter.RulesCount = 0
}
// Path to the filter contents
func (filter *filter) Path() string {
return filepath.Join(config.ourWorkingDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
}
// LastTimeUpdated returns the time when the filter was last time updated
func (filter *filter) LastTimeUpdated() time.Time {
filterFilePath := filter.Path()
s, err := os.Stat(filterFilePath)
if os.IsNotExist(err) {
// if the filter file does not exist, return 0001-01-01
return time.Time{}
}
if err != nil {
// if the filter file does not exist, return 0001-01-01
return time.Time{}
}
// filter file modified time
return s.ModTime()
}

405
home/helpers.go Normal file
View File

@@ -0,0 +1,405 @@
package home
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
// ----------------------------------
// helper functions for HTTP handlers
// ----------------------------------
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != method {
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
return
}
if method == "POST" || method == "PUT" || method == "DELETE" {
controlLock.Lock()
defer controlLock.Unlock()
}
handler(w, r)
}
}
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure("POST", handler)
}
func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure("GET", handler)
}
// Bridge between http.Handler object and Go function
type httpHandler struct {
handler func(http.ResponseWriter, *http.Request)
}
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.handler(w, r)
}
func ensureGETHandler(handler func(http.ResponseWriter, *http.Request)) http.Handler {
h := httpHandler{}
h.handler = ensure("GET", handler)
return &h
}
func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if config.AuthName == "" || config.AuthPass == "" {
handler(w, r)
return
}
user, pass, ok := r.BasicAuth()
if !ok || user != config.AuthName || pass != config.AuthPass {
w.Header().Set("WWW-Authenticate", `Basic realm="dnsfilter"`)
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Unauthorised.\n"))
return
}
handler(w, r)
}
}
type authHandler struct {
handler http.Handler
}
func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
optionalAuth(a.handler.ServeHTTP)(w, r)
}
func optionalAuthHandler(handler http.Handler) http.Handler {
return &authHandler{handler}
}
// -------------------
// first run / install
// -------------------
func detectFirstRun() bool {
configfile := config.ourConfigFilename
if !filepath.IsAbs(configfile) {
configfile = filepath.Join(config.ourWorkingDir, config.ourConfigFilename)
}
_, err := os.Stat(configfile)
if !os.IsNotExist(err) {
// do nothing, file exists
return false
}
return true
}
// preInstall lets the handler run only if firstRun is true, no redirects
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if !config.firstRun {
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
handler(w, r)
}
}
// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary
type preInstallHandlerStruct struct {
handler http.Handler
}
func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
preInstall(p.handler.ServeHTTP)(w, r)
}
// preInstallHandler returns http.Handler interface for preInstall wrapper
func preInstallHandler(handler http.Handler) http.Handler {
return &preInstallHandlerStruct{handler}
}
// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise
// it also enforces HTTPS if it is enabled and configured
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if config.firstRun &&
!strings.HasPrefix(r.URL.Path, "/install.") &&
r.URL.Path != "/favicon.png" {
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable
return
}
// enforce https?
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && httpsServer.server != nil {
// yes, and we want host from host:port
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
// no port in host
host = r.Host
}
// construct new URL to redirect to
newURL := url.URL{
Scheme: "https",
Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)),
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
}
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
handler(w, r)
}
}
type postInstallHandlerStruct struct {
handler http.Handler
}
func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
postInstall(p.handler.ServeHTTP)(w, r)
}
func postInstallHandler(handler http.Handler) http.Handler {
return &postInstallHandlerStruct{handler}
}
// -------------------------------------------------
// helper functions for parsing parameters from body
// -------------------------------------------------
func parseParametersFromBody(r io.Reader) (map[string]string, error) {
parameters := map[string]string{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if len(line) == 0 {
// skip empty lines
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
return parameters, errors.New("Got invalid request body")
}
parameters[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
}
return parameters, nil
}
// ------------------
// network interfaces
// ------------------
type netInterface struct {
Name string `json:"name"`
MTU int `json:"mtu"`
HardwareAddr string `json:"hardware_address"`
Addresses []string `json:"ip_addresses"`
Flags string `json:"flags"`
}
// getValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP
// invalid interface is a ppp interface or the one that doesn't allow broadcasts
func getValidNetInterfaces() ([]net.Interface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err)
}
netIfaces := []net.Interface{}
for i := range ifaces {
if ifaces[i].Flags&net.FlagPointToPoint != 0 {
// this interface is ppp, we're not interested in this one
continue
}
iface := ifaces[i]
netIfaces = append(netIfaces, iface)
}
return netIfaces, nil
}
// getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only
// we do not return link-local addresses here
func getValidNetInterfacesForWeb() ([]netInterface, error) {
ifaces, err := getValidNetInterfaces()
if err != nil {
return nil, errorx.Decorate(err, "Couldn't get interfaces")
}
if len(ifaces) == 0 {
return nil, errors.New("couldn't find any legible interface")
}
var netInterfaces []netInterface
for _, iface := range ifaces {
addrs, e := iface.Addrs()
if e != nil {
return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name)
}
netIface := netInterface{
Name: iface.Name,
MTU: iface.MTU,
HardwareAddr: iface.HardwareAddr.String(),
}
if iface.Flags != 0 {
netIface.Flags = iface.Flags.String()
}
// we don't want link-local addresses in json, so skip them
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
// not an IPNet, should not happen
return nil, fmt.Errorf("SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
}
// ignore link-local
if ipnet.IP.IsLinkLocalUnicast() {
continue
}
netIface.Addresses = append(netIface.Addresses, ipnet.IP.String())
}
if len(netIface.Addresses) != 0 {
netInterfaces = append(netInterfaces, netIface)
}
}
return netInterfaces, nil
}
// checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily
func checkPortAvailable(host string, port int) error {
ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port)))
if err != nil {
return err
}
ln.Close()
return nil
}
func checkPacketPortAvailable(host string, port int) error {
ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port)))
if err != nil {
return err
}
ln.Close()
return err
}
// Connect to a remote server resolving hostname using our own DNS server
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
log.Tracef("network:%v addr:%v", network, addr)
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
dialer := &net.Dialer{
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
con, err := dialer.DialContext(ctx, network, addr)
return con, err
}
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
bindhost = "127.0.0.1"
}
resolverAddr := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
r := upstream.NewResolver(resolverAddr, 30*time.Second)
addrs, e := r.LookupIPAddr(ctx, host)
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
if e != nil {
return nil, e
}
var firstErr error
firstErr = nil
for _, a := range addrs {
addr = net.JoinHostPort(a.String(), port)
con, err := dialer.DialContext(ctx, network, addr)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
return con, err
}
return nil, firstErr
}
// check if error is "address already in use"
func errorIsAddrInUse(err error) bool {
errOpError, ok := err.(*net.OpError)
if !ok {
return false
}
errSyscallError, ok := errOpError.Err.(*os.SyscallError)
if !ok {
return false
}
errErrno, ok := errSyscallError.Err.(syscall.Errno)
if !ok {
return false
}
if runtime.GOOS == "windows" {
const WSAEADDRINUSE = 10048
return errErrno == WSAEADDRINUSE
}
return errErrno == syscall.EADDRINUSE
}
// ---------------------
// debug logging helpers
// ---------------------
func _Func() string {
pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(2, pc)
f := runtime.FuncForPC(pc[0])
return path.Base(f.Name())
}
// Parse input string and return IPv4 address
func parseIPv4(s string) net.IP {
ip := net.ParseIP(s)
if ip == nil {
return nil
}
v4InV6Prefix := []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}
if !bytes.Equal(ip[0:12], v4InV6Prefix) {
return nil
}
return ip.To4()
}

25
home/helpers_test.go Normal file
View File

@@ -0,0 +1,25 @@
package home
import (
"testing"
"github.com/AdguardTeam/golibs/log"
)
func TestGetValidNetInterfacesForWeb(t *testing.T) {
ifaces, err := getValidNetInterfacesForWeb()
if err != nil {
t.Fatalf("Cannot get net interfaces: %s", err)
}
if len(ifaces) == 0 {
t.Fatalf("No net interfaces found")
}
for _, iface := range ifaces {
if len(iface.Addresses) == 0 {
t.Fatalf("No addresses found for %s", iface.Name)
}
log.Printf("%v", iface)
}
}

542
home/home.go Normal file
View File

@@ -0,0 +1,542 @@
package home
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/NYTimes/gziphandler"
"github.com/gobuffalo/packr"
)
// VersionString will be set through ldflags, contains current version
var VersionString = "undefined"
// updateChannel can be set via ldflags
var updateChannel = "release"
var versionCheckURL = "https://static.adguard.com/adguardhome/" + updateChannel + "/version.json"
const versionCheckPeriod = time.Hour * 8
var httpServer *http.Server
var httpsServer struct {
server *http.Server
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
sync.Mutex // protects config.TLS
shutdown bool // if TRUE, don't restart the server
}
var pidFileName string // PID file name. Empty if no PID file was created.
const (
// Used in config to indicate that syslog or eventlog (win) should be used for logger output
configSyslog = "syslog"
)
// main is the entry point
func Main() {
// config can be specified, which reads options from there, but other command line flags have to override config values
// therefore, we must do it manually instead of using a lib
args := loadOptions()
if args.serviceControlAction != "" {
handleServiceControlAction(args.serviceControlAction)
return
}
// run the protection
run(args)
}
// run initializes configuration and runs the AdGuard Home
// run is a blocking method and it won't exit until the service is stopped!
func run(args options) {
// config file path can be overridden by command-line arguments:
if args.configFilename != "" {
config.ourConfigFilename = args.configFilename
}
// configure working dir and config path
initWorkingDir(args)
// configure log level and output
configureLogger(args)
// enable TLS 1.3
enableTLS13()
// print the first message after logger is configured
log.Printf("AdGuard Home, version %s, channel %s\n", VersionString, updateChannel)
log.Debug("Current working directory is %s", config.ourWorkingDir)
if args.runningAsService {
log.Info("AdGuard Home is running as a service")
}
config.runningAsService = args.runningAsService
config.disableUpdate = args.disableUpdate
config.firstRun = detectFirstRun()
if config.firstRun {
requireAdminRights()
}
signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-signalChannel
cleanup()
cleanupAlways()
os.Exit(0)
}()
clientsInit()
if !config.firstRun {
// Do the upgrade if necessary
err := upgradeConfig()
if err != nil {
log.Fatal(err)
}
err = parseConfig()
if err != nil {
os.Exit(1)
}
if args.checkConfig {
log.Info("Configuration file is OK")
os.Exit(0)
}
}
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
config.RlimitNoFile != 0 {
setRlimit(config.RlimitNoFile)
}
// override bind host/port from the console
if args.bindHost != "" {
config.BindHost = args.bindHost
}
if args.bindPort != 0 {
config.BindPort = args.bindPort
}
loadFilters()
if !config.firstRun {
// Save the updated config
err := config.write()
if err != nil {
log.Fatal(err)
}
}
// Init the DNS server instance before registering HTTP handlers
dnsBaseDir := filepath.Join(config.ourWorkingDir, dataDir)
initDNSServer(dnsBaseDir)
if !config.firstRun {
err := startDNSServer()
if err != nil {
log.Fatal(err)
}
err = startDHCPServer()
if err != nil {
log.Fatal(err)
}
}
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
pidFileName = args.pidFile
}
// Update filters we've just loaded right away, don't wait for periodic update timer
go func() {
refreshFiltersIfNecessary(false)
}()
// Schedule automatic filters updates
go periodicallyRefreshFilters()
// Initialize and run the admin Web interface
box := packr.NewBox("../build/static")
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
registerControlHandlers()
// add handlers for /install paths, we only need them when we're not configured yet
if config.firstRun {
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
registerInstallHandlers()
}
httpsServer.cond = sync.NewCond(&httpsServer.Mutex)
// for https, we have a separate goroutine loop
go httpServerLoop()
// this loop is used as an ability to change listening host and/or port
for !httpsServer.shutdown {
printHTTPAddresses("http")
// we need to have new instance, because after Shutdown() the Server is not usable
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
httpServer = &http.Server{
Addr: address,
}
err := httpServer.ListenAndServe()
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
}
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
}
// wait indefinitely for other go-routines to complete their job
select {}
}
func httpServerLoop() {
for !httpsServer.shutdown {
httpsServer.cond.L.Lock()
// this mechanism doesn't let us through until all conditions are met
for config.TLS.Enabled == false ||
config.TLS.PortHTTPS == 0 ||
config.TLS.PrivateKey == "" ||
config.TLS.CertificateChain == "" { // sleep until necessary data is supplied
httpsServer.cond.Wait()
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
// validate current TLS config and update warnings (it could have been loaded from file)
data := validateCertificates(config.TLS.CertificateChain, config.TLS.PrivateKey, config.TLS.ServerName)
if !data.ValidPair {
cleanupAlways()
log.Fatal(data.WarningValidation)
}
config.Lock()
config.TLS.tlsConfigStatus = data // update warnings
config.Unlock()
// prepare certs for HTTPS server
// important -- they have to be copies, otherwise changing the contents in config.TLS will break encryption for in-flight requests
certchain := make([]byte, len(config.TLS.CertificateChain))
copy(certchain, []byte(config.TLS.CertificateChain))
privatekey := make([]byte, len(config.TLS.PrivateKey))
copy(privatekey, []byte(config.TLS.PrivateKey))
cert, err := tls.X509KeyPair(certchain, privatekey)
if err != nil {
cleanupAlways()
log.Fatal(err)
}
httpsServer.cond.L.Unlock()
// prepare HTTPS server
httpsServer.server = &http.Server{
Addr: address,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
},
}
printHTTPAddresses("https")
err = httpsServer.server.ListenAndServeTLS("", "")
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
}
}
}
// Check if the current user has root (administrator) rights
// and if not, ask and try to run as root
func requireAdminRights() {
admin, _ := haveAdminRights()
if admin {
return
}
if runtime.GOOS == "windows" {
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
} else {
log.Error("This is the first launch of AdGuard Home. You must run it as root.")
_, _ = io.WriteString(os.Stdout, "Do you want to start AdGuard Home as root user? [y/n] ")
stdin := bufio.NewReader(os.Stdin)
buf, _ := stdin.ReadString('\n')
buf = strings.TrimSpace(buf)
if buf != "y" {
os.Exit(1)
}
cmd := exec.Command("sudo", os.Args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
_ = cmd.Run()
os.Exit(1)
}
}
// Write PID to a file
func writePIDFile(fn string) bool {
data := fmt.Sprintf("%d", os.Getpid())
err := ioutil.WriteFile(fn, []byte(data), 0644)
if err != nil {
log.Error("Couldn't write PID to file %s: %v", fn, err)
return false
}
return true
}
// initWorkingDir initializes the ourWorkingDir
// if no command-line arguments specified, we use the directory where our binary file is located
func initWorkingDir(args options) {
exec, err := os.Executable()
if err != nil {
panic(err)
}
if args.workDir != "" {
// If there is a custom config file, use it's directory as our working dir
config.ourWorkingDir = args.workDir
} else {
config.ourWorkingDir = filepath.Dir(exec)
}
}
// configureLogger configures logger level and output
func configureLogger(args options) {
ls := getLogSettings()
// command-line arguments can override config settings
if args.verbose {
ls.Verbose = true
}
if args.logFile != "" {
ls.LogFile = args.logFile
}
level := log.INFO
if ls.Verbose {
level = log.DEBUG
}
log.SetLevel(level)
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if nothing else is configured
// Otherwise, we'll simply loose the log output
ls.LogFile = configSyslog
}
if ls.LogFile == "" {
return
}
if ls.LogFile == configSyslog {
// Use syslog where it is possible and eventlog on Windows
err := configureSyslog()
if err != nil {
log.Fatalf("cannot initialize syslog: %s", err)
}
} else {
logFilePath := filepath.Join(config.ourWorkingDir, ls.LogFile)
if filepath.IsAbs(ls.LogFile) {
logFilePath = ls.LogFile
}
file, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil {
log.Fatalf("cannot create a log file: %s", err)
}
log.SetOutput(file)
}
}
// TODO after GO 1.13 release TLS 1.3 will be enabled by default. Remove this afterward
func enableTLS13() {
err := os.Setenv("GODEBUG", os.Getenv("GODEBUG")+",tls13=1")
if err != nil {
log.Fatalf("Failed to enable TLS 1.3: %s", err)
}
}
func cleanup() {
log.Info("Stopping AdGuard Home")
err := stopDNSServer()
if err != nil {
log.Error("Couldn't stop DNS server: %s", err)
}
err = stopDHCPServer()
if err != nil {
log.Error("Couldn't stop DHCP server: %s", err)
}
}
// Stop HTTP server, possibly waiting for all active connections to be closed
func stopHTTPServer() {
httpsServer.shutdown = true
if httpsServer.server != nil {
httpsServer.server.Shutdown(context.TODO())
}
httpServer.Shutdown(context.TODO())
}
// This function is called before application exits
func cleanupAlways() {
if len(pidFileName) != 0 {
os.Remove(pidFileName)
}
log.Info("Stopped")
}
// command-line arguments
type options struct {
verbose bool // is verbose logging enabled
configFilename string // path to the config file
workDir string // path to the working directory where we will store the filters data and the querylog
bindHost string // host address to bind HTTP server on
bindPort int // port to serve HTTP pages on
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
pidFile string // File name to save PID to
checkConfig bool // Check configuration and exit
disableUpdate bool // If set, don't check for updates
// service control action (see service.ControlAction array + "status" command)
serviceControlAction string
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
}
// loadOptions reads command line arguments and initializes configuration
func loadOptions() options {
o := options{}
var printHelp func()
var opts = []struct {
longName string
shortName string
description string
callbackWithValue func(value string)
callbackNoValue func()
}{
{"config", "c", "Path to the config file", func(value string) { o.configFilename = value }, nil},
{"work-dir", "w", "Path to the working directory", func(value string) { o.workDir = value }, nil},
{"host", "h", "Host address to bind HTTP server on", func(value string) { o.bindHost = value }, nil},
{"port", "p", "Port to serve HTTP pages on", func(value string) {
v, err := strconv.Atoi(value)
if err != nil {
panic("Got port that is not a number")
}
o.bindPort = v
}, nil},
{"service", "s", "Service control action: status, install, uninstall, start, stop, restart", func(value string) {
o.serviceControlAction = value
}, nil},
{"logfile", "l", "Path to log file. If empty: write to stdout; if 'syslog': write to system log", func(value string) {
o.logFile = value
}, nil},
{"pidfile", "", "Path to a file where PID is stored", func(value string) { o.pidFile = value }, nil},
{"check-config", "", "Check configuration and exit", nil, func() { o.checkConfig = true }},
{"no-check-update", "", "Don't check for updates", nil, func() { o.disableUpdate = true }},
{"verbose", "v", "Enable verbose output", nil, func() { o.verbose = true }},
{"help", "", "Print this help", nil, func() {
printHelp()
os.Exit(64)
}},
}
printHelp = func() {
fmt.Printf("Usage:\n\n")
fmt.Printf("%s [options]\n\n", os.Args[0])
fmt.Printf("Options:\n")
for _, opt := range opts {
val := ""
if opt.callbackWithValue != nil {
val = " VALUE"
}
if opt.shortName != "" {
fmt.Printf(" -%s, %-30s %s\n", opt.shortName, "--"+opt.longName+val, opt.description)
} else {
fmt.Printf(" %-34s %s\n", "--"+opt.longName+val, opt.description)
}
}
}
for i := 1; i < len(os.Args); i++ {
v := os.Args[i]
knownParam := false
for _, opt := range opts {
if v == "--"+opt.longName || (opt.shortName != "" && v == "-"+opt.shortName) {
if opt.callbackWithValue != nil {
if i+1 >= len(os.Args) {
log.Error("Got %s without argument\n", v)
os.Exit(64)
}
i++
opt.callbackWithValue(os.Args[i])
} else if opt.callbackNoValue != nil {
opt.callbackNoValue()
}
knownParam = true
break
}
}
if !knownParam {
log.Error("unknown option %v\n", v)
printHelp()
os.Exit(64)
}
}
return o
}
// prints IP addresses which user can use to open the admin interface
// proto is either "http" or "https"
func printHTTPAddresses(proto string) {
var address string
if proto == "https" && config.TLS.ServerName != "" {
if config.TLS.PortHTTPS == 443 {
log.Printf("Go to https://%s", config.TLS.ServerName)
} else {
log.Printf("Go to https://%s:%d", config.TLS.ServerName, config.TLS.PortHTTPS)
}
} else if config.BindHost == "0.0.0.0" {
log.Println("AdGuard Home is available on the following addresses:")
ifaces, err := getValidNetInterfacesForWeb()
if err != nil {
// That's weird, but we'll ignore it
address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
log.Printf("Go to %s://%s", proto, address)
return
}
for _, iface := range ifaces {
address = net.JoinHostPort(iface.Addresses[0], strconv.Itoa(config.BindPort))
log.Printf("Go to %s://%s", proto, address)
}
} else {
address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
log.Printf("Go to %s://%s", proto, address)
}
}

72
home/i18n.go Normal file
View File

@@ -0,0 +1,72 @@
package home
import (
"fmt"
"io/ioutil"
"net/http"
"strings"
"github.com/AdguardTeam/golibs/log"
)
// --------------------
// internationalization
// --------------------
var allowedLanguages = map[string]bool{
"en": true,
"ru": true,
"vi": true,
"es": true,
"fr": true,
"ja": true,
"sv": true,
"pt-br": true,
"zh-tw": true,
"bg": true,
"zh-cn": true,
}
func isLanguageAllowed(language string) bool {
l := strings.ToLower(language)
return allowedLanguages[l]
}
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
log.Printf("config.Language is %s", config.Language)
_, err := fmt.Fprintf(w, "%s\n", config.Language)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
}
func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errorText := fmt.Sprintf("failed to read request body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
language := strings.TrimSpace(string(body))
if language == "" {
errorText := fmt.Sprintf("empty language specified")
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
if !isLanguageAllowed(language) {
errorText := fmt.Sprintf("unknown language specified: %s", language)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
config.Language = language
httpUpdateConfigReloadDNSReturnOK(w, r)
}

27
home/os_unix.go Normal file
View File

@@ -0,0 +1,27 @@
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package home
import (
"os"
"syscall"
"github.com/AdguardTeam/golibs/log"
)
// Set user-specified limit of how many fd's we can use
// https://github.com/AdguardTeam/AdGuardHome/issues/659
func setRlimit(val uint) {
var rlim syscall.Rlimit
rlim.Max = uint64(val)
rlim.Cur = uint64(val)
err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlim)
if err != nil {
log.Error("Setrlimit() failed: %v", err)
}
}
// Check if the current user has root (administrator) rights
func haveAdminRights() (bool, error) {
return os.Getuid() == 0, nil
}

28
home/os_windows.go Normal file
View File

@@ -0,0 +1,28 @@
package home
import "golang.org/x/sys/windows"
// Set user-specified limit of how many fd's we can use
func setRlimit(val uint) {
}
func haveAdminRights() (bool, error) {
var token windows.Token
h, _ := windows.GetCurrentProcess()
err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &token)
if err != nil {
return false, err
}
info := make([]byte, 4)
var returnedLen uint32
err = windows.GetTokenInformation(token, windows.TokenElevation, &info[0], uint32(len(info)), &returnedLen)
token.Close()
if err != nil {
return false, err
}
if info[0] == 0 {
return false, nil
}
return true, nil
}

185
home/service.go Normal file
View File

@@ -0,0 +1,185 @@
package home
import (
"os"
"runtime"
"github.com/AdguardTeam/golibs/log"
"github.com/kardianos/service"
)
const (
launchdStdoutPath = "/var/log/AdGuardHome.stdout.log"
launchdStderrPath = "/var/log/AdGuardHome.stderr.log"
serviceName = "AdGuardHome"
serviceDisplayName = "AdGuard Home service"
serviceDescription = "AdGuard Home: Network-level blocker"
)
// Represents the program that will be launched by a service or daemon
type program struct {
}
// Start should quickly start the program
func (p *program) Start(s service.Service) error {
// Start should not block. Do the actual work async.
args := options{runningAsService: true}
go run(args)
return nil
}
// Stop stops the program
func (p *program) Stop(s service.Service) error {
// Stop should not block. Return with a few seconds.
cleanup()
cleanupAlways()
os.Exit(0)
return nil
}
// handleServiceControlAction one of the possible control actions:
// install -- installs a service/daemon
// uninstall -- uninstalls it
// status -- prints the service status
// start -- starts the previously installed service
// stop -- stops the previously installed service
// restart - restarts the previously installed service
// run - this is a special command that is not supposed to be used directly
// it is specified when we register a service, and it indicates to the app
// that it is being run as a service/daemon.
func handleServiceControlAction(action string) {
log.Printf("Service control action: %s", action)
pwd, err := os.Getwd()
if err != nil {
log.Fatal("Unable to find the path to the current directory")
}
svcConfig := &service.Config{
Name: serviceName,
DisplayName: serviceDisplayName,
Description: serviceDescription,
WorkingDirectory: pwd,
Arguments: []string{"-s", "run"},
}
configureService(svcConfig)
prg := &program{}
s, err := service.New(prg, svcConfig)
if err != nil {
log.Fatal(err)
}
if action == "status" {
status, errSt := s.Status()
if errSt != nil {
log.Fatalf("failed to get service status: %s", errSt)
}
switch status {
case service.StatusUnknown:
log.Printf("Service status is unknown")
case service.StatusStopped:
log.Printf("Service is stopped")
case service.StatusRunning:
log.Printf("Service is running")
}
} else if action == "run" {
err = s.Run()
if err != nil {
log.Fatalf("Failed to run service: %s", err)
}
} else {
if action == "uninstall" {
// In case of Windows and Linux when a running service is being uninstalled,
// it is just marked for deletion but not stopped
// So we explicitly stop it here
_ = s.Stop()
}
err = service.Control(s, action)
if err != nil {
log.Fatal(err)
}
log.Printf("Action %s has been done successfully on %s", action, service.ChosenSystem().String())
if action == "install" {
// Start automatically after install
err = service.Control(s, "start")
if err != nil {
log.Fatalf("Failed to start the service: %s", err)
}
log.Printf("Service has been started")
if detectFirstRun() {
log.Printf(`Almost ready!
AdGuard Home is successfully installed and will automatically start on boot.
There are a few more things that must be configured before you can use it.
Click on the link below and follow the Installation Wizard steps to finish setup.`)
printHTTPAddresses("http")
}
} else if action == "uninstall" {
cleanupService()
}
}
}
// configureService defines additional settings of the service
func configureService(c *service.Config) {
c.Option = service.KeyValue{}
// OS X
// Redefines the launchd config file template
// The purpose is to enable stdout/stderr redirect by default
c.Option["LaunchdConfig"] = launchdConfig
// This key is used to start the job as soon as it has been loaded. For daemons this means execution at boot time, for agents execution at login.
c.Option["RunAtLoad"] = true
// POSIX
// Redirect StdErr & StdOut to files.
c.Option["LogOutput"] = true
}
// cleanupService called on the service uninstall, cleans up additional files if needed
func cleanupService() {
if runtime.GOOS == "darwin" {
// Removing log files on cleanup and ignore errors
err := os.Remove(launchdStdoutPath)
if err != nil && !os.IsNotExist(err) {
log.Printf("cannot remove %s", launchdStdoutPath)
}
err = os.Remove(launchdStderrPath)
if err != nil && !os.IsNotExist(err) {
log.Printf("cannot remove %s", launchdStderrPath)
}
}
}
// Basically the same template as the one defined in github.com/kardianos/service
// but with two additional keys - StandardOutPath and StandardErrorPath
var launchdConfig = `<?xml version='1.0' encoding='UTF-8'?>
<!DOCTYPE plist PUBLIC "-//Apple Computer//DTD PLIST 1.0//EN"
"http://www.apple.com/DTDs/PropertyList-1.0.dtd" >
<plist version='1.0'>
<dict>
<key>Label</key><string>{{html .Name}}</string>
<key>ProgramArguments</key>
<array>
<string>{{html .Path}}</string>
{{range .Config.Arguments}}
<string>{{html .}}</string>
{{end}}
</array>
{{if .UserName}}<key>UserName</key><string>{{html .UserName}}</string>{{end}}
{{if .ChRoot}}<key>RootDirectory</key><string>{{html .ChRoot}}</string>{{end}}
{{if .WorkingDirectory}}<key>WorkingDirectory</key><string>{{html .WorkingDirectory}}</string>{{end}}
<key>SessionCreate</key><{{bool .SessionCreate}}/>
<key>KeepAlive</key><{{bool .KeepAlive}}/>
<key>RunAtLoad</key><{{bool .RunAtLoad}}/>
<key>Disabled</key><false/>
<key>StandardOutPath</key>
<string>` + launchdStdoutPath + `</string>
<key>StandardErrorPath</key>
<string>` + launchdStderrPath + `</string>
</dict>
</plist>
`

18
home/syslog_others.go Normal file
View File

@@ -0,0 +1,18 @@
// +build !windows,!nacl,!plan9
package home
import (
"log"
"log/syslog"
)
// configureSyslog reroutes standard logger output to syslog
func configureSyslog() error {
w, err := syslog.New(syslog.LOG_NOTICE|syslog.LOG_USER, serviceName)
if err != nil {
return err
}
log.SetOutput(w)
return nil
}

39
home/syslog_windows.go Normal file
View File

@@ -0,0 +1,39 @@
package home
import (
"log"
"strings"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc/eventlog"
)
type eventLogWriter struct {
el *eventlog.Log
}
// Write sends a log message to the Event Log.
func (w *eventLogWriter) Write(b []byte) (int, error) {
return len(b), w.el.Info(1, string(b))
}
func configureSyslog() error {
// Note that the eventlog src is the same as the service name
// Otherwise, we will get "the description for event id cannot be found" warning in every log record
// Continue if we receive "registry key already exists" or if we get
// ERROR_ACCESS_DENIED so that we can log without administrative permissions
// for pre-existing eventlog sources.
if err := eventlog.InstallAsEventCreate(serviceName, eventlog.Info|eventlog.Warning|eventlog.Error); err != nil {
if !strings.Contains(err.Error(), "registry key already exists") && err != windows.ERROR_ACCESS_DENIED {
return err
}
}
el, err := eventlog.Open(serviceName)
if err != nil {
return err
}
log.SetOutput(&eventLogWriter{el: el})
return nil
}

194
home/upgrade.go Normal file
View File

@@ -0,0 +1,194 @@
package home
import (
"fmt"
"os"
"path/filepath"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
yaml "gopkg.in/yaml.v2"
)
const currentSchemaVersion = 3 // used for upgrading from old configs to new config
// Performs necessary upgrade operations if needed
func upgradeConfig() error {
// read a config file into an interface map, so we can manipulate values without losing any
diskConfig := map[string]interface{}{}
body, err := readConfigFile()
if err != nil {
return err
}
err = yaml.Unmarshal(body, &diskConfig)
if err != nil {
log.Printf("Couldn't parse config file: %s", err)
return err
}
schemaVersionInterface, ok := diskConfig["schema_version"]
log.Tracef("got schema version %v", schemaVersionInterface)
if !ok {
// no schema version, set it to 0
schemaVersionInterface = 0
}
schemaVersion, ok := schemaVersionInterface.(int)
if !ok {
err = fmt.Errorf("configuration file contains non-integer schema_version, abort")
log.Println(err)
return err
}
if schemaVersion == currentSchemaVersion {
// do nothing
return nil
}
return upgradeConfigSchema(schemaVersion, &diskConfig)
}
// Upgrade from oldVersion to newVersion
func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) error {
switch oldVersion {
case 0:
err := upgradeSchema0to3(diskConfig)
if err != nil {
return err
}
case 1:
err := upgradeSchema1to3(diskConfig)
if err != nil {
return err
}
case 2:
err := upgradeSchema2to3(diskConfig)
if err != nil {
return err
}
default:
err := fmt.Errorf("configuration file contains unknown schema_version, abort")
log.Println(err)
return err
}
configFile := config.getConfigFilename()
body, err := yaml.Marshal(diskConfig)
if err != nil {
log.Printf("Couldn't generate YAML file: %s", err)
return err
}
config.fileData = body
err = file.SafeWrite(configFile, body)
if err != nil {
log.Printf("Couldn't save YAML config: %s", err)
return err
}
return nil
}
// The first schema upgrade:
// No more "dnsfilter.txt", filters are now kept in data/filters/
func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
dnsFilterPath := filepath.Join(config.ourWorkingDir, "dnsfilter.txt")
if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) {
log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath)
err = os.Remove(dnsFilterPath)
if err != nil {
log.Printf("Cannot remove %s due to %s", dnsFilterPath, err)
// not fatal, move on
}
}
(*diskConfig)["schema_version"] = 1
return nil
}
// Second schema upgrade:
// coredns is now dns in config
// delete 'Corefile', since we don't use that anymore
func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
coreFilePath := filepath.Join(config.ourWorkingDir, "Corefile")
if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) {
log.Printf("Deleting %s as we don't need it anymore", coreFilePath)
err = os.Remove(coreFilePath)
if err != nil {
log.Printf("Cannot remove %s due to %s", coreFilePath, err)
// not fatal, move on
}
}
if _, ok := (*diskConfig)["dns"]; !ok {
(*diskConfig)["dns"] = (*diskConfig)["coredns"]
delete((*diskConfig), "coredns")
}
(*diskConfig)["schema_version"] = 2
return nil
}
// Third schema upgrade:
// Bootstrap DNS becomes an array
func upgradeSchema2to3(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
// Let's read dns configuration from diskConfig
dnsConfig, ok := (*diskConfig)["dns"]
if !ok {
return fmt.Errorf("no DNS configuration in config file")
}
// Convert interface{} to map[string]interface{}
newDNSConfig := make(map[string]interface{})
switch v := dnsConfig.(type) {
case map[interface{}]interface{}:
for k, v := range v {
newDNSConfig[fmt.Sprint(k)] = v
}
default:
return fmt.Errorf("DNS configuration is not a map")
}
// Replace bootstrap_dns value filed with new array contains old bootstrap_dns inside
if bootstrapDNS, ok := (newDNSConfig)["bootstrap_dns"]; ok {
newBootstrapConfig := []string{fmt.Sprint(bootstrapDNS)}
(newDNSConfig)["bootstrap_dns"] = newBootstrapConfig
(*diskConfig)["dns"] = newDNSConfig
} else {
return fmt.Errorf("no bootstrap DNS in DNS config")
}
// Bump schema version
(*diskConfig)["schema_version"] = 3
return nil
}
// jump three schemas at once -- this time we just do it sequentially
func upgradeSchema0to3(diskConfig *map[string]interface{}) error {
err := upgradeSchema0to1(diskConfig)
if err != nil {
return err
}
return upgradeSchema1to3(diskConfig)
}
// jump two schemas at once -- this time we just do it sequentially
func upgradeSchema1to3(diskConfig *map[string]interface{}) error {
err := upgradeSchema1to2(diskConfig)
if err != nil {
return err
}
return upgradeSchema2to3(diskConfig)
}

230
home/upgrade_test.go Normal file
View File

@@ -0,0 +1,230 @@
package home
import (
"fmt"
"testing"
)
func TestUpgrade1to2(t *testing.T) {
// let's create test config for 1 schema version
diskConfig := createTestDiskConfig(1)
// update config
err := upgradeSchema1to2(&diskConfig)
if err != nil {
t.Fatalf("Can't upgrade schema version from 1 to 2")
}
// ensure that schema version was bumped
compareSchemaVersion(t, diskConfig["schema_version"], 2)
// old coredns entry should be removed
_, ok := diskConfig["coredns"]
if ok {
t.Fatalf("Core DNS config was not removed after upgrade schema version from 1 to 2")
}
// pull out new dns config
dnsMap, ok := diskConfig["dns"]
if !ok {
t.Fatalf("No DNS config after upgrade schema version from 1 to 2")
}
// cast dns configurations to maps and compare them
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(1))
newDNSConfig := castInterfaceToMap(t, dnsMap)
compareConfigs(t, &oldDNSConfig, &newDNSConfig)
// exclude dns config and schema version from disk config comparison
oldExcludedEntries := []string{"coredns", "schema_version"}
newExcludedEntries := []string{"dns", "schema_version"}
oldDiskConfig := createTestDiskConfig(1)
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, oldExcludedEntries, newExcludedEntries)
}
func TestUpgrade2to3(t *testing.T) {
// let's create test config
diskConfig := createTestDiskConfig(2)
// upgrade schema from 2 to 3
err := upgradeSchema2to3(&diskConfig)
if err != nil {
t.Fatalf("Can't update schema version from 2 to 3: %s", err)
}
// check new schema version
compareSchemaVersion(t, diskConfig["schema_version"], 3)
// pull out new dns configuration
dnsMap, ok := diskConfig["dns"]
if !ok {
t.Fatalf("No dns config in new configuration")
}
// cast dns configuration to map
newDNSConfig := castInterfaceToMap(t, dnsMap)
// check if bootstrap DNS becomes an array
bootstrapDNS := newDNSConfig["bootstrap_dns"]
switch v := bootstrapDNS.(type) {
case []string:
if len(v) != 1 {
t.Fatalf("Wrong count of bootsrap DNS servers: %d", len(v))
}
if v[0] != "8.8.8.8:53" {
t.Fatalf("Bootsrap DNS server is not 8.8.8.8:53 : %s", v[0])
}
default:
t.Fatalf("Wrong type for bootsrap DNS: %T", v)
}
// exclude bootstrap DNS from DNS configs comparison
excludedEntries := []string{"bootstrap_dns"}
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(2))
compareConfigsWithoutEntries(t, &oldDNSConfig, &newDNSConfig, excludedEntries, excludedEntries)
// excluded dns config and schema version from disk config comparison
excludedEntries = []string{"dns", "schema_version"}
oldDiskConfig := createTestDiskConfig(2)
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, excludedEntries, excludedEntries)
}
func castInterfaceToMap(t *testing.T, oldConfig interface{}) (newConfig map[string]interface{}) {
newConfig = make(map[string]interface{})
switch v := oldConfig.(type) {
case map[interface{}]interface{}:
for key, value := range v {
newConfig[fmt.Sprint(key)] = value
}
case map[string]interface{}:
for key, value := range v {
newConfig[key] = value
}
default:
t.Fatalf("DNS configuration is not a map")
}
return
}
// compareConfigsWithoutEntry removes entries from configs and returns result of compareConfigs
func compareConfigsWithoutEntries(t *testing.T, oldConfig, newConfig *map[string]interface{}, oldKey, newKey []string) {
for _, k := range oldKey {
delete(*oldConfig, k)
}
for _, k := range newKey {
delete(*newConfig, k)
}
compareConfigs(t, oldConfig, newConfig)
}
// compares configs before and after schema upgrade
func compareConfigs(t *testing.T, oldConfig, newConfig *map[string]interface{}) {
if len(*oldConfig) != len(*newConfig) {
t.Fatalf("wrong config entries count! Before upgrade: %d; After upgrade: %d", len(*oldConfig), len(*oldConfig))
}
// Check old and new entries
for k, v := range *newConfig {
switch value := v.(type) {
case string:
if value != (*oldConfig)[k] {
t.Fatalf("wrong value for string %s. Before update: %s; After update: %s", k, (*oldConfig)[k], value)
}
case int:
if value != (*oldConfig)[k] {
t.Fatalf("wrong value for int %s. Before update: %d; After update: %d", k, (*oldConfig)[k], value)
}
case []string:
for i, line := range value {
if len((*oldConfig)[k].([]string)) != len(value) {
t.Fatalf("wrong array length for %s. Before update: %d; After update: %d", k, len((*oldConfig)[k].([]string)), len(value))
}
if (*oldConfig)[k].([]string)[i] != line {
t.Fatalf("wrong data for string array %s. Before update: %s; After update: %s", k, (*oldConfig)[k].([]string)[i], line)
}
}
case bool:
if v != (*oldConfig)[k].(bool) {
t.Fatalf("wrong boolean value for %s", k)
}
case []filter:
if len((*oldConfig)[k].([]filter)) != len(value) {
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filter)), len(value))
}
for i, newFilter := range value {
oldFilter := (*oldConfig)[k].([]filter)[i]
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RulesCount != newFilter.RulesCount {
t.Fatalf("old filter %s not equals new filter %s", oldFilter.Name, newFilter.Name)
}
}
default:
t.Fatalf("uknown data type for %s: %T", k, value)
}
}
}
// compareSchemaVersion check if newSchemaVersion equals schemaVersion
func compareSchemaVersion(t *testing.T, newSchemaVersion interface{}, schemaVersion int) {
switch v := newSchemaVersion.(type) {
case int:
if v != schemaVersion {
t.Fatalf("Wrong schema version in new config file")
}
default:
t.Fatalf("Schema version is not an integer after update")
}
}
func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{}) {
diskConfig = make(map[string]interface{})
diskConfig["language"] = "en"
diskConfig["filters"] = []filter{
{
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
Name: "Latvian filter",
RulesCount: 100,
},
{
URL: "https://easylist.to/easylistgermany/easylistgermany.txt",
Name: "Germany filter",
RulesCount: 200,
},
}
diskConfig["user_rules"] = []string{}
diskConfig["schema_version"] = schemaVersion
diskConfig["bind_host"] = "0.0.0.0"
diskConfig["bind_port"] = 80
diskConfig["auth_name"] = "name"
diskConfig["auth_pass"] = "pass"
dnsConfig := createTestDNSConfig(schemaVersion)
if schemaVersion > 1 {
diskConfig["dns"] = dnsConfig
} else {
diskConfig["coredns"] = dnsConfig
}
return diskConfig
}
func createTestDNSConfig(schemaVersion int) map[interface{}]interface{} {
dnsConfig := make(map[interface{}]interface{})
dnsConfig["port"] = 53
dnsConfig["blocked_response_ttl"] = 10
dnsConfig["querylog_enabled"] = true
dnsConfig["ratelimit"] = 20
dnsConfig["bootstrap_dns"] = "8.8.8.8:53"
if schemaVersion > 2 {
dnsConfig["bootstrap_dns"] = []string{"8.8.8.8:53"}
}
dnsConfig["parental_sensitivity"] = 13
dnsConfig["ratelimit_whitelist"] = []string{}
dnsConfig["upstream_dns"] = []string{"tls://1.1.1.1", "tls://1.0.0.1", "8.8.8.8"}
dnsConfig["filtering_enabled"] = true
dnsConfig["refuse_any"] = true
dnsConfig["parental_enabled"] = true
dnsConfig["bind_host"] = "0.0.0.0"
dnsConfig["protection_enabled"] = true
dnsConfig["safesearch_enabled"] = true
dnsConfig["safebrowsing_enabled"] = true
return dnsConfig
}