Merge branch 'master' into 4387-fix-openapi-schema
This commit is contained in:
@@ -2,10 +2,9 @@ package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -56,12 +55,49 @@ type clientSource uint
|
||||
// Client sources. The order determines the priority.
|
||||
const (
|
||||
ClientSourceWHOIS clientSource = iota
|
||||
ClientSourceRDNS
|
||||
ClientSourceARP
|
||||
ClientSourceRDNS
|
||||
ClientSourceDHCP
|
||||
ClientSourceHostsFile
|
||||
)
|
||||
|
||||
var _ fmt.Stringer = clientSource(0)
|
||||
|
||||
// String returns a human-readable name of cs.
|
||||
func (cs clientSource) String() (s string) {
|
||||
switch cs {
|
||||
case ClientSourceWHOIS:
|
||||
return "WHOIS"
|
||||
case ClientSourceARP:
|
||||
return "ARP"
|
||||
case ClientSourceRDNS:
|
||||
return "rDNS"
|
||||
case ClientSourceDHCP:
|
||||
return "DHCP"
|
||||
case ClientSourceHostsFile:
|
||||
return "etc/hosts"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var _ encoding.TextMarshaler = clientSource(0)
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler for the clientSource.
|
||||
func (cs clientSource) MarshalText() (text []byte, err error) {
|
||||
return []byte(cs.String()), nil
|
||||
}
|
||||
|
||||
// clientSourceConf is used to configure where the runtime clients will be
|
||||
// obtained from.
|
||||
type clientSourcesConf struct {
|
||||
WHOIS bool `yaml:"whois"`
|
||||
ARP bool `yaml:"arp"`
|
||||
RDNS bool `yaml:"rdns"`
|
||||
DHCP bool `yaml:"dhcp"`
|
||||
HostsFile bool `yaml:"hosts"`
|
||||
}
|
||||
|
||||
// RuntimeClient information
|
||||
type RuntimeClient struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo
|
||||
@@ -99,6 +135,9 @@ type clientsContainer struct {
|
||||
// hosts database.
|
||||
etcHosts *aghnet.HostsContainer
|
||||
|
||||
// arpdb stores the neighbors retrieved from ARP.
|
||||
arpdb aghnet.ARPDB
|
||||
|
||||
testing bool // if TRUE, this object is used for internal tests
|
||||
}
|
||||
|
||||
@@ -109,6 +148,7 @@ func (clients *clientsContainer) Init(
|
||||
objects []*clientObject,
|
||||
dhcpServer *dhcpd.Server,
|
||||
etcHosts *aghnet.HostsContainer,
|
||||
arpdb aghnet.ARPDB,
|
||||
) {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
@@ -121,6 +161,7 @@ func (clients *clientsContainer) Init(
|
||||
|
||||
clients.dhcpServer = dhcpServer
|
||||
clients.etcHosts = etcHosts
|
||||
clients.arpdb = arpdb
|
||||
clients.addFromConfig(objects)
|
||||
|
||||
if clients.testing {
|
||||
@@ -132,14 +173,14 @@ func (clients *clientsContainer) Init(
|
||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||
}
|
||||
|
||||
go clients.handleHostsUpdates()
|
||||
if clients.etcHosts != nil {
|
||||
go clients.handleHostsUpdates()
|
||||
}
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) handleHostsUpdates() {
|
||||
if clients.etcHosts != nil {
|
||||
for upd := range clients.etcHosts.Upd() {
|
||||
clients.addFromHostsFile(upd)
|
||||
}
|
||||
for upd := range clients.etcHosts.Upd() {
|
||||
clients.addFromHostsFile(upd)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,7 +197,9 @@ func (clients *clientsContainer) Start() {
|
||||
|
||||
// Reload reloads runtime clients.
|
||||
func (clients *clientsContainer) Reload() {
|
||||
clients.addFromSystemARP()
|
||||
if clients.arpdb != nil {
|
||||
clients.addFromSystemARP()
|
||||
}
|
||||
}
|
||||
|
||||
type clientObject struct {
|
||||
@@ -255,6 +298,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) periodicUpdate() {
|
||||
defer log.OnPanic("clients container")
|
||||
|
||||
for {
|
||||
clients.Reload()
|
||||
time.Sleep(clientsUpdatePeriod)
|
||||
@@ -380,6 +425,7 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
|
||||
c.Tags = stringutil.CloneSlice(c.Tags)
|
||||
c.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
|
||||
c.Upstreams = stringutil.CloneSlice(c.Upstreams)
|
||||
|
||||
return c, true
|
||||
}
|
||||
|
||||
@@ -476,7 +522,7 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
// findRuntimeClientLocked finds a runtime client by their IP address. For
|
||||
// internal use only.
|
||||
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
|
||||
var v interface{}
|
||||
var v any
|
||||
v, ok = clients.ipToRC.Get(ip)
|
||||
if !ok {
|
||||
return nil, false
|
||||
@@ -530,7 +576,7 @@ func (clients *clientsContainer) check(c *Client) (err error) {
|
||||
} else if mac, err = net.ParseMAC(id); err == nil {
|
||||
c.IDs[i] = mac.String()
|
||||
} else if err = dnsforward.ValidateClientID(id); err == nil {
|
||||
c.IDs[i] = id
|
||||
c.IDs[i] = strings.ToLower(id)
|
||||
} else {
|
||||
return fmt.Errorf("invalid clientid at index %d: %q", i, id)
|
||||
}
|
||||
@@ -721,20 +767,18 @@ func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSourc
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
ok = clients.addHostLocked(ip, host, src)
|
||||
|
||||
return ok, nil
|
||||
return clients.addHostLocked(ip, host, src), nil
|
||||
}
|
||||
|
||||
// addHostLocked adds a new IP-hostname pairing. For internal use only.
|
||||
func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) {
|
||||
var rc *RuntimeClient
|
||||
rc, ok = clients.findRuntimeClientLocked(ip)
|
||||
rc, ok := clients.findRuntimeClientLocked(ip)
|
||||
if ok {
|
||||
if rc.Source > src {
|
||||
return false
|
||||
}
|
||||
|
||||
rc.Host = host
|
||||
rc.Source = src
|
||||
} else {
|
||||
rc = &RuntimeClient{
|
||||
@@ -754,7 +798,7 @@ func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clien
|
||||
// rmHostsBySrc removes all entries that match the specified source.
|
||||
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||
n := 0
|
||||
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
clients.ipToRC.Range(func(ip net.IP, v any) (cont bool) {
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
|
||||
@@ -782,41 +826,38 @@ func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) {
|
||||
clients.rmHostsBySrc(ClientSourceHostsFile)
|
||||
|
||||
n := 0
|
||||
hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
hosts, ok := v.(*stringutil.Set)
|
||||
hosts.Range(func(ip net.IP, v any) (cont bool) {
|
||||
rec, ok := v.(*aghnet.HostsRecord)
|
||||
if !ok {
|
||||
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
hosts.Range(func(name string) (cont bool) {
|
||||
if clients.addHostLocked(ip, name, ClientSourceHostsFile) {
|
||||
n++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
clients.addHostLocked(ip, rec.Canonical, ClientSourceHostsFile)
|
||||
n++
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
log.Debug("clients: added %d client aliases from system hosts-file", n)
|
||||
log.Debug("clients: added %d client aliases from system hosts file", n)
|
||||
}
|
||||
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
// command.
|
||||
func (clients *clientsContainer) addFromSystemARP() {
|
||||
if runtime.GOOS == "windows" {
|
||||
if err := clients.arpdb.Refresh(); err != nil {
|
||||
log.Error("refreshing arp container: %s", err)
|
||||
|
||||
clients.arpdb = aghnet.EmptyARPDB{}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("arp", "-a")
|
||||
log.Tracef("executing %q %q", cmd.Path, cmd.Args)
|
||||
data, err := cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
log.Debug("command %q has failed: %q code:%d",
|
||||
cmd.Path, err, cmd.ProcessState.ExitCode())
|
||||
ns := clients.arpdb.Neighbors()
|
||||
if len(ns) == 0 {
|
||||
log.Debug("refreshing arp container: the update is empty")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -825,36 +866,20 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
|
||||
clients.rmHostsBySrc(ClientSourceARP)
|
||||
|
||||
n := 0
|
||||
// TODO(a.garipov): Rewrite to use bufio.Scanner.
|
||||
lines := strings.Split(string(data), "\n")
|
||||
for _, ln := range lines {
|
||||
lparen := strings.Index(ln, " (")
|
||||
rparen := strings.Index(ln, ") ")
|
||||
if lparen == -1 || rparen == -1 || lparen >= rparen {
|
||||
continue
|
||||
}
|
||||
|
||||
host := ln[:lparen]
|
||||
ipStr := ln[lparen+2 : rparen]
|
||||
ip := net.ParseIP(ipStr)
|
||||
if netutil.ValidateDomainName(host) != nil || ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ok := clients.addHostLocked(ip, host, ClientSourceARP)
|
||||
if ok {
|
||||
n++
|
||||
added := 0
|
||||
for _, n := range ns {
|
||||
if clients.addHostLocked(n.IP, n.Name, ClientSourceARP) {
|
||||
added++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("clients: added %d client aliases from 'arp -a' command output", n)
|
||||
log.Debug("clients: added %d client aliases from arp neighborhood", added)
|
||||
}
|
||||
|
||||
// updateFromDHCP adds the clients that have a non-empty hostname from the DHCP
|
||||
// server.
|
||||
func (clients *clientsContainer) updateFromDHCP(add bool) {
|
||||
if clients.dhcpServer == nil {
|
||||
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestClients(t *testing.T) {
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
|
||||
clients.Init(nil, nil, nil)
|
||||
clients.Init(nil, nil, nil, nil)
|
||||
|
||||
t.Run("add_success", func(t *testing.T) {
|
||||
c := &Client{
|
||||
@@ -194,7 +194,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
clients := clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
clients.Init(nil, nil, nil)
|
||||
clients.Init(nil, nil, nil, nil)
|
||||
whois := &RuntimeClientWHOISInfo{
|
||||
Country: "AU",
|
||||
Orgname: "Example Org",
|
||||
@@ -253,7 +253,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
clients := clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
clients.Init(nil, nil, nil)
|
||||
clients.Init(nil, nil, nil, nil)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
ip := net.IP{1, 1, 1, 1}
|
||||
@@ -332,7 +332,7 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
clients.Init(nil, nil, nil)
|
||||
clients.Init(nil, nil, nil, nil)
|
||||
|
||||
// Add client with upstreams.
|
||||
ok, err := clients.Add(&Client{
|
||||
|
||||
@@ -47,9 +47,9 @@ type clientJSON struct {
|
||||
type runtimeClientJSON struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||
|
||||
Name string `json:"name"`
|
||||
Source string `json:"source"`
|
||||
IP net.IP `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
Source clientSource `json:"source"`
|
||||
IP net.IP `json:"ip"`
|
||||
}
|
||||
|
||||
type clientListJSON struct {
|
||||
@@ -70,7 +70,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
data.Clients = append(data.Clients, cj)
|
||||
}
|
||||
|
||||
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
clients.ipToRC.Range(func(ip net.IP, v any) (cont bool) {
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||
@@ -81,20 +81,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
cj := runtimeClientJSON{
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
|
||||
Name: rc.Host,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
cj.Source = "etc/hosts"
|
||||
switch rc.Source {
|
||||
case ClientSourceDHCP:
|
||||
cj.Source = "DHCP"
|
||||
case ClientSourceRDNS:
|
||||
cj.Source = "rDNS"
|
||||
case ClientSourceARP:
|
||||
cj.Source = "ARP"
|
||||
case ClientSourceWHOIS:
|
||||
cj.Source = "WHOIS"
|
||||
Name: rc.Host,
|
||||
Source: rc.Source,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||
@@ -107,13 +96,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
e := json.NewEncoder(w).Encode(data)
|
||||
if e != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Failed to encode to json: %v",
|
||||
e,
|
||||
)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "failed to encode to json: %v", e)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -279,9 +262,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
|
||||
rc, ok := clients.FindRuntimeClient(ip)
|
||||
if !ok {
|
||||
// It is still possible that the IP used to be in the runtime
|
||||
// clients list, but then the server was reloaded. So, check
|
||||
// the DNS server's blocked IP list.
|
||||
// It is still possible that the IP used to be in the runtime clients
|
||||
// list, but then the server was reloaded. So, check the DNS server's
|
||||
// blocked IP list.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@@ -19,7 +20,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -27,15 +28,36 @@ const (
|
||||
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
|
||||
)
|
||||
|
||||
// logSettings
|
||||
// logSettings are the logging settings part of the configuration file.
|
||||
//
|
||||
// TODO(a.garipov): Put them into a separate object.
|
||||
type logSettings struct {
|
||||
LogCompress bool `yaml:"log_compress"` // Compress determines if the rotated log files should be compressed using gzip (default: false)
|
||||
LogLocalTime bool `yaml:"log_localtime"` // If the time used for formatting the timestamps in is the computer's local time (default: false [UTC])
|
||||
LogMaxBackups int `yaml:"log_max_backups"` // Maximum number of old log files to retain (MaxAge may still cause them to get deleted)
|
||||
LogMaxSize int `yaml:"log_max_size"` // Maximum size in megabytes of the log file before it gets rotated (default 100 MB)
|
||||
LogMaxAge int `yaml:"log_max_age"` // MaxAge is the maximum number of days to retain old log files
|
||||
LogFile string `yaml:"log_file"` // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
|
||||
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
|
||||
// File is the path to the log file. If empty, logs are written to stdout.
|
||||
// If "syslog", logs are written to syslog.
|
||||
File string `yaml:"log_file"`
|
||||
|
||||
// MaxBackups is the maximum number of old log files to retain.
|
||||
//
|
||||
// NOTE: MaxAge may still cause them to get deleted.
|
||||
MaxBackups int `yaml:"log_max_backups"`
|
||||
|
||||
// MaxSize is the maximum size of the log file before it gets rotated, in
|
||||
// megabytes. The default value is 100 MB.
|
||||
MaxSize int `yaml:"log_max_size"`
|
||||
|
||||
// MaxAge is the maximum duration for retaining old log files, in days.
|
||||
MaxAge int `yaml:"log_max_age"`
|
||||
|
||||
// Compress determines, if the rotated log files should be compressed using
|
||||
// gzip.
|
||||
Compress bool `yaml:"log_compress"`
|
||||
|
||||
// LocalTime determines, if the time used for formatting the timestamps in
|
||||
// is the computer's local time.
|
||||
LocalTime bool `yaml:"log_localtime"`
|
||||
|
||||
// Verbose determines, if verbose (aka debug) logging is enabled.
|
||||
Verbose bool `yaml:"verbose"`
|
||||
}
|
||||
|
||||
// osConfig contains OS-related configuration.
|
||||
@@ -51,6 +73,13 @@ type osConfig struct {
|
||||
RlimitNoFile uint64 `yaml:"rlimit_nofile"`
|
||||
}
|
||||
|
||||
type clientsConfig struct {
|
||||
// Sources defines the set of sources to fetch the runtime clients from.
|
||||
Sources *clientSourcesConf `yaml:"runtime_sources"`
|
||||
// Persistent are the configured clients.
|
||||
Persistent []*clientObject `yaml:"persistent"`
|
||||
}
|
||||
|
||||
// configuration is loaded from YAML
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type configuration struct {
|
||||
@@ -88,7 +117,7 @@ type configuration struct {
|
||||
// Clients contains the YAML representations of the persistent clients.
|
||||
// This field is only used for reading and writing persistent client data.
|
||||
// Keep this field sorted to ensure consistent ordering.
|
||||
Clients []*clientObject `yaml:"clients"`
|
||||
Clients *clientsConfig `yaml:"clients"`
|
||||
|
||||
logSettings `yaml:",inline"`
|
||||
|
||||
@@ -123,8 +152,9 @@ type dnsConfig struct {
|
||||
// UpstreamTimeout is the timeout for querying upstream servers.
|
||||
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
|
||||
|
||||
// ResolveClients enables and disables resolving clients with RDNS.
|
||||
ResolveClients bool `yaml:"resolve_clients"`
|
||||
// PrivateNets is the set of IP networks for which the private reverse DNS
|
||||
// resolver should be used.
|
||||
PrivateNets []string `yaml:"private_networks"`
|
||||
|
||||
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
|
||||
// locally-served networks should be resolved via private PTR resolvers.
|
||||
@@ -179,6 +209,7 @@ var config = &configuration{
|
||||
Ratelimit: 20,
|
||||
RefuseAny: true,
|
||||
AllServers: false,
|
||||
HandleDDR: true,
|
||||
FastestTimeout: timeutil.Duration{
|
||||
Duration: fastip.DefaultPingWaitTimeout,
|
||||
},
|
||||
@@ -194,7 +225,6 @@ var config = &configuration{
|
||||
FilteringEnabled: true, // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours: 24,
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||
ResolveClients: true,
|
||||
UsePrivateRDNS: true,
|
||||
},
|
||||
TLS: tlsConfigSettings{
|
||||
@@ -205,12 +235,21 @@ var config = &configuration{
|
||||
DHCP: &dhcpd.ServerConfig{
|
||||
LocalDomainName: "lan",
|
||||
},
|
||||
Clients: &clientsConfig{
|
||||
Sources: &clientSourcesConf{
|
||||
WHOIS: true,
|
||||
ARP: true,
|
||||
RDNS: true,
|
||||
DHCP: true,
|
||||
HostsFile: true,
|
||||
},
|
||||
},
|
||||
logSettings: logSettings{
|
||||
LogCompress: false,
|
||||
LogLocalTime: false,
|
||||
LogMaxBackups: 0,
|
||||
LogMaxSize: 100,
|
||||
LogMaxAge: 3,
|
||||
Compress: false,
|
||||
LocalTime: false,
|
||||
MaxBackups: 0,
|
||||
MaxSize: 100,
|
||||
MaxAge: 3,
|
||||
},
|
||||
OSConfig: &osConfig{},
|
||||
SchemaVersion: currentSchemaVersion,
|
||||
@@ -285,25 +324,28 @@ func parseConfig() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
config.BindPort,
|
||||
config.BetaBindPort,
|
||||
config.DNS.Port,
|
||||
)
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(config.BindPort), tcpPort(config.BetaBindPort))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
|
||||
if config.TLS.Enabled {
|
||||
addPorts(
|
||||
uc,
|
||||
config.TLS.PortHTTPS,
|
||||
config.TLS.PortDNSOverTLS,
|
||||
config.TLS.PortDNSOverQUIC,
|
||||
config.TLS.PortDNSCrypt,
|
||||
tcpPorts,
|
||||
tcpPort(config.TLS.PortHTTPS),
|
||||
tcpPort(config.TLS.PortDNSOverTLS),
|
||||
tcpPort(config.TLS.PortDNSCrypt),
|
||||
)
|
||||
|
||||
// TODO(e.burkov): Consider adding a udpPort with the same value when
|
||||
// we add support for HTTP/3 for web admin interface.
|
||||
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
|
||||
}
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return fmt.Errorf("validating ports: %w", err)
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
} else if err = udpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||
@@ -317,8 +359,14 @@ func parseConfig() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addPorts is a helper for ports validation. It skips zero ports.
|
||||
func addPorts(uc aghalg.UniqChecker, ports ...int) {
|
||||
// udpPort is the port number for UDP protocol.
|
||||
type udpPort int
|
||||
|
||||
// tcpPort is the port number for TCP protocol.
|
||||
type tcpPort int
|
||||
|
||||
// addPorts is a helper for ports validation that skips zero ports.
|
||||
func addPorts[T tcpPort | udpPort](uc aghalg.UniqChecker[T], ports ...T) {
|
||||
for _, p := range ports {
|
||||
if p != 0 {
|
||||
uc.Add(p)
|
||||
@@ -340,13 +388,14 @@ func readConfigFile() (fileData []byte, err error) {
|
||||
}
|
||||
|
||||
// Saves configuration to the YAML file and also saves the user filter contents to a file
|
||||
func (c *configuration) write() error {
|
||||
func (c *configuration) write() (err error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if Context.auth != nil {
|
||||
config.Users = Context.auth.GetUsers()
|
||||
}
|
||||
|
||||
if Context.tls != nil {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
@@ -380,9 +429,7 @@ func (c *configuration) write() error {
|
||||
s.WriteDiskConfig(&c)
|
||||
dns := &config.DNS
|
||||
dns.FilteringConfig = c
|
||||
dns.LocalPTRResolvers,
|
||||
dns.ResolveClients,
|
||||
dns.UsePrivateRDNS = s.RDNSSettings()
|
||||
dns.LocalPTRResolvers, config.Clients.Sources.RDNS, dns.UsePrivateRDNS = s.RDNSSettings()
|
||||
}
|
||||
|
||||
if Context.dhcpServer != nil {
|
||||
@@ -391,22 +438,23 @@ func (c *configuration) write() error {
|
||||
config.DHCP = c
|
||||
}
|
||||
|
||||
config.Clients = Context.clients.forConfig()
|
||||
config.Clients.Persistent = Context.clients.forConfig()
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
log.Debug("Writing YAML file: %s", configFile)
|
||||
yamlText, err := yaml.Marshal(&config)
|
||||
if err != nil {
|
||||
log.Error("Couldn't generate YAML file: %s", err)
|
||||
log.Debug("writing config file %q", configFile)
|
||||
|
||||
return err
|
||||
buf := &bytes.Buffer{}
|
||||
enc := yaml.NewEncoder(buf)
|
||||
enc.SetIndent(2)
|
||||
|
||||
err = enc.Encode(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating config file: %w", err)
|
||||
}
|
||||
|
||||
err = maybe.WriteFile(configFile, yamlText, 0o644)
|
||||
err = maybe.WriteFile(configFile, buf.Bytes(), 0o644)
|
||||
if err != nil {
|
||||
log.Error("Couldn't save YAML config: %s", err)
|
||||
|
||||
return err
|
||||
return fmt.Errorf("writing config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -189,7 +189,7 @@ func registerControlHandlers() {
|
||||
RegisterAuthHandlers()
|
||||
}
|
||||
|
||||
func httpRegister(method, url string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
func httpRegister(method, url string, handler http.HandlerFunc) {
|
||||
if method == "" {
|
||||
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||
Context.mux.HandleFunc(url, postInstall(handler))
|
||||
|
||||
@@ -105,19 +105,22 @@ type checkConfResp struct {
|
||||
|
||||
// validateWeb returns error is the web part if the initial configuration can't
|
||||
// be set.
|
||||
func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
|
||||
func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := req.Web.Port
|
||||
addPorts(uc, config.BetaBindPort, port)
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
// Avoid duplicating the error into the status of DNS.
|
||||
uc[port] = 1
|
||||
portInt := req.Web.Port
|
||||
port := tcpPort(portInt)
|
||||
addPorts(tcpPorts, tcpPort(config.BetaBindPort), port)
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
// Reset the value for the port to 1 to make sure that validateDNS
|
||||
// doesn't throw the same error, unless the same TCP port is set there
|
||||
// as well.
|
||||
tcpPorts[port] = 1
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
switch port {
|
||||
switch portInt {
|
||||
case 0, config.BindPort:
|
||||
return nil
|
||||
default:
|
||||
@@ -125,21 +128,18 @@ func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
|
||||
// unbound after install.
|
||||
}
|
||||
|
||||
return aghnet.CheckPort("tcp", req.Web.IP, port)
|
||||
return aghnet.CheckPort("tcp", req.Web.IP, portInt)
|
||||
}
|
||||
|
||||
// validateDNS returns error if the DNS part of the initial configuration can't
|
||||
// be set. canAutofix is true if the port can be unbound by AdGuard Home
|
||||
// automatically.
|
||||
func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, err error) {
|
||||
func (req *checkConfReq) validateDNS(
|
||||
tcpPorts aghalg.UniqChecker[tcpPort],
|
||||
) (canAutofix bool, err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := req.DNS.Port
|
||||
addPorts(uc, port)
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch port {
|
||||
case 0:
|
||||
return false, nil
|
||||
@@ -148,6 +148,11 @@ func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, er
|
||||
// by AdGuard Home for web interface.
|
||||
default:
|
||||
// Check TCP as well.
|
||||
addPorts(tcpPorts, tcpPort(port))
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("tcp", req.DNS.IP, port)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -185,13 +190,12 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
resp := &checkConfResp{}
|
||||
uc := aghalg.UniqChecker{}
|
||||
|
||||
if err = req.validateWeb(uc); err != nil {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
if err = req.validateWeb(tcpPorts); err != nil {
|
||||
resp.Web.Status = err.Error()
|
||||
}
|
||||
|
||||
if resp.DNS.CanAutofix, err = req.validateDNS(uc); err != nil {
|
||||
if resp.DNS.CanAutofix, err = req.validateDNS(tcpPorts); err != nil {
|
||||
resp.DNS.Status = err.Error()
|
||||
} else if !req.DNS.IP.IsUnspecified() {
|
||||
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
|
||||
@@ -212,7 +216,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
|
||||
func handleStaticIP(ip net.IP, set bool) staticIPJSON {
|
||||
resp := staticIPJSON{}
|
||||
|
||||
interfaceName := aghnet.GetInterfaceByIP(ip)
|
||||
interfaceName := aghnet.InterfaceByIP(ip)
|
||||
resp.Static = "no"
|
||||
|
||||
if len(interfaceName) == 0 {
|
||||
|
||||
@@ -3,14 +3,15 @@ package home
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
@@ -27,12 +28,16 @@ type temporaryError interface {
|
||||
|
||||
// Get the latest available version from the Internet
|
||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
resp := &versionResponse{}
|
||||
if Context.disableUpdate {
|
||||
// w.Header().Set("Content-Type", "application/json")
|
||||
resp.Disabled = true
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
// TODO(e.burkov): Add error handling and deal with headers.
|
||||
err := json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -44,30 +49,48 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ContentLength != 0 {
|
||||
err = json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "parsing request: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = requestVersionInfo(resp, req.Recheck)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
aghhttp.Error(r, w, http.StatusBadGateway, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = resp.setAllowedToAutoUpdate()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// requestVersionInfo sets the VersionInfo field of resp if it can reach the
|
||||
// update server.
|
||||
func requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
|
||||
for i := 0; i != 3; i++ {
|
||||
func() {
|
||||
Context.controlLock.Lock()
|
||||
defer Context.controlLock.Unlock()
|
||||
|
||||
resp.VersionInfo, err = Context.updater.VersionInfo(req.Recheck)
|
||||
}()
|
||||
|
||||
resp.VersionInfo, err = Context.updater.VersionInfo(recheck)
|
||||
if err != nil {
|
||||
var terr temporaryError
|
||||
if errors.As(err, &terr) && terr.Temporary() {
|
||||
// Temporary network error. This case may happen while
|
||||
// we're restarting our DNS server. Log and sleep for
|
||||
// some time.
|
||||
// Temporary network error. This case may happen while we're
|
||||
// restarting our DNS server. Log and sleep for some time.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/934.
|
||||
d := time.Duration(i) * time.Second
|
||||
log.Info("temp net error: %q; sleeping for %s and retrying", err, d)
|
||||
log.Info("update: temp net error: %q; sleeping for %s and retrying", err, d)
|
||||
time.Sleep(d)
|
||||
|
||||
continue
|
||||
@@ -76,29 +99,14 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
vcu := Context.updater.VersionCheckURL()
|
||||
// TODO(a.garipov): Figure out the purpose of %T verb.
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadGateway,
|
||||
"Couldn't get version check json from %s: %T %s\n",
|
||||
vcu,
|
||||
err,
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
return fmt.Errorf("getting version info from %s: %s", vcu, err)
|
||||
}
|
||||
|
||||
resp.confirmAutoUpdate()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleUpdate performs an update to the latest available version procedure.
|
||||
@@ -109,7 +117,18 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err := Context.updater.Update()
|
||||
// Retain the current absolute path of the executable, since the updater is
|
||||
// likely to change the position current one to the backup directory.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4735.
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "getting path: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = Context.updater.Update()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
@@ -121,85 +140,88 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// The background context is used because the underlying functions wrap
|
||||
// it with timeout and shut down the server, which handles current
|
||||
// request. It also should be done in a separate goroutine due to the
|
||||
// same reason.
|
||||
go func() {
|
||||
finishUpdate(context.Background())
|
||||
}()
|
||||
// The background context is used because the underlying functions wrap it
|
||||
// with timeout and shut down the server, which handles current request. It
|
||||
// also should be done in a separate goroutine for the same reason.
|
||||
go finishUpdate(context.Background(), execPath)
|
||||
}
|
||||
|
||||
// versionResponse is the response for /control/version.json endpoint.
|
||||
type versionResponse struct {
|
||||
Disabled bool `json:"disabled"`
|
||||
updater.VersionInfo
|
||||
Disabled bool `json:"disabled"`
|
||||
}
|
||||
|
||||
// confirmAutoUpdate checks the real possibility of auto update.
|
||||
func (vr *versionResponse) confirmAutoUpdate() {
|
||||
if vr.CanAutoUpdate != nil && *vr.CanAutoUpdate {
|
||||
canUpdate := true
|
||||
|
||||
var tlsConf *tlsConfigSettings
|
||||
if runtime.GOOS != "windows" {
|
||||
tlsConf = &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
}
|
||||
|
||||
if tlsConf != nil &&
|
||||
((tlsConf.Enabled && (tlsConf.PortHTTPS < 1024 ||
|
||||
tlsConf.PortDNSOverTLS < 1024 ||
|
||||
tlsConf.PortDNSOverQUIC < 1024)) ||
|
||||
config.BindPort < 1024 ||
|
||||
config.DNS.Port < 1024) {
|
||||
canUpdate, _ = aghnet.CanBindPrivilegedPorts()
|
||||
}
|
||||
vr.CanAutoUpdate = &canUpdate
|
||||
// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
|
||||
// allowed to perform an automatic update by the OS.
|
||||
func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
||||
if vr.CanAutoUpdate != aghalg.NBTrue {
|
||||
return nil
|
||||
}
|
||||
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
|
||||
canUpdate := true
|
||||
if tlsConfUsesPrivilegedPorts(tlsConf) || config.BindPort < 1024 || config.DNS.Port < 1024 {
|
||||
canUpdate, err = aghnet.CanBindPrivilegedPorts()
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking ability to bind privileged ports: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
vr.CanAutoUpdate = aghalg.BoolToNullBool(canUpdate)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
|
||||
// indicates that privileged ports are used.
|
||||
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
|
||||
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
|
||||
}
|
||||
|
||||
// finishUpdate completes an update procedure.
|
||||
func finishUpdate(ctx context.Context) {
|
||||
log.Info("Stopping all tasks")
|
||||
func finishUpdate(ctx context.Context, execPath string) {
|
||||
var err error
|
||||
|
||||
log.Info("stopping all tasks")
|
||||
|
||||
cleanup(ctx)
|
||||
cleanupAlways()
|
||||
|
||||
exeName := "AdGuardHome"
|
||||
if runtime.GOOS == "windows" {
|
||||
exeName = "AdGuardHome.exe"
|
||||
}
|
||||
curBinName := filepath.Join(Context.workDir, exeName)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
if Context.runningAsService {
|
||||
// Note:
|
||||
// we can't restart the service via "kardianos/service" package - it kills the process first
|
||||
// we can't start a new instance - Windows doesn't allow it
|
||||
// NOTE: We can't restart the service via "kardianos/service"
|
||||
// package, because it kills the process first we can't start a new
|
||||
// instance, because Windows doesn't allow it.
|
||||
//
|
||||
// TODO(a.garipov): Recheck the claim above.
|
||||
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
|
||||
err := cmd.Start()
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("exec.Command() failed: %s", err)
|
||||
log.Fatalf("restarting: stopping: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
cmd := exec.Command(curBinName, os.Args[1:]...)
|
||||
log.Info("Restarting: %v", cmd.Args)
|
||||
cmd := exec.Command(execPath, os.Args[1:]...)
|
||||
log.Info("restarting: %q %q", execPath, os.Args[1:])
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err := cmd.Start()
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("exec.Command() failed: %s", err)
|
||||
log.Fatalf("restarting:: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
} else {
|
||||
log.Info("Restarting: %v", os.Args)
|
||||
err := syscall.Exec(curBinName, os.Args, os.Environ())
|
||||
if err != nil {
|
||||
log.Fatalf("syscall.Exec() failed: %s", err)
|
||||
}
|
||||
// Unreachable code
|
||||
}
|
||||
|
||||
log.Info("restarting: %q %q", execPath, os.Args[1:])
|
||||
err = syscall.Exec(execPath, os.Args, os.Environ())
|
||||
if err != nil {
|
||||
log.Fatalf("restarting: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Default ports.
|
||||
@@ -25,7 +25,7 @@ const (
|
||||
defaultPortDNS = 53
|
||||
defaultPortHTTP = 80
|
||||
defaultPortHTTPS = 443
|
||||
defaultPortQUIC = 784
|
||||
defaultPortQUIC = 853
|
||||
defaultPortTLS = 853
|
||||
)
|
||||
|
||||
@@ -58,6 +58,7 @@ func initDNSServer() (err error) {
|
||||
}
|
||||
|
||||
conf := querylog.Config{
|
||||
Anonymizer: anonymizer,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
FindClient: Context.clients.findMultiple,
|
||||
@@ -67,7 +68,6 @@ func initDNSServer() (err error) {
|
||||
Enabled: config.DNS.QueryLogEnabled,
|
||||
FileEnabled: config.DNS.QueryLogFileEnabled,
|
||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||
Anonymizer: anonymizer,
|
||||
}
|
||||
Context.queryLog = querylog.New(conf)
|
||||
|
||||
@@ -77,13 +77,36 @@ func initDNSServer() (err error) {
|
||||
filterConf.HTTPRegister = httpRegister
|
||||
Context.dnsFilter = filtering.New(&filterConf, nil)
|
||||
|
||||
var privateNets netutil.SubnetSet
|
||||
switch len(config.DNS.PrivateNets) {
|
||||
case 0:
|
||||
// Use an optimized locally-served matcher.
|
||||
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
case 1:
|
||||
var n *net.IPNet
|
||||
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
privateNets = n
|
||||
default:
|
||||
var nets []*net.IPNet
|
||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
privateNets = netutil.SliceSubnetSet(nets)
|
||||
}
|
||||
|
||||
p := dnsforward.DNSCreateParams{
|
||||
DNSFilter: Context.dnsFilter,
|
||||
Stats: Context.stats,
|
||||
QueryLog: Context.queryLog,
|
||||
SubnetDetector: Context.subnetDetector,
|
||||
Anonymizer: anonymizer,
|
||||
LocalDomain: config.DHCP.LocalDomainName,
|
||||
DNSFilter: Context.dnsFilter,
|
||||
Stats: Context.stats,
|
||||
QueryLog: Context.queryLog,
|
||||
PrivateNets: privateNets,
|
||||
Anonymizer: anonymizer,
|
||||
LocalDomain: config.DHCP.LocalDomainName,
|
||||
}
|
||||
if Context.dhcpServer != nil {
|
||||
p.DHCPServer = Context.dhcpServer
|
||||
@@ -112,8 +135,13 @@ func initDNSServer() (err error) {
|
||||
return fmt.Errorf("dnsServer.Prepare: %w", err)
|
||||
}
|
||||
|
||||
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
|
||||
Context.whois = initWHOIS(&Context.clients)
|
||||
if config.Clients.Sources.RDNS {
|
||||
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
|
||||
}
|
||||
|
||||
if config.Clients.Sources.WHOIS {
|
||||
Context.whois = initWHOIS(&Context.clients)
|
||||
}
|
||||
|
||||
Context.filters.Init()
|
||||
return nil
|
||||
@@ -130,10 +158,11 @@ func onDNSRequest(pctx *proxy.DNSContext) {
|
||||
return
|
||||
}
|
||||
|
||||
if config.DNS.ResolveClients && !ip.IsLoopback() {
|
||||
srcs := config.Clients.Sources
|
||||
if srcs.RDNS && !ip.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
if !Context.subnetDetector.IsSpecialNetwork(ip) {
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
}
|
||||
@@ -192,6 +221,10 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
newConf.TLSConfig = tlsConf.TLSConfig
|
||||
newConf.TLSConfig.ServerName = tlsConf.ServerName
|
||||
|
||||
if tlsConf.PortHTTPS != 0 {
|
||||
newConf.HTTPSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortHTTPS)
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
newConf.TLSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortDNSOverTLS)
|
||||
}
|
||||
@@ -216,7 +249,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
newConf.FilterHandler = applyAdditionalFiltering
|
||||
newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams
|
||||
|
||||
newConf.ResolveClients = dnsConf.ResolveClients
|
||||
newConf.ResolveClients = config.Clients.Sources.RDNS
|
||||
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
|
||||
newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers
|
||||
newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration
|
||||
@@ -301,24 +334,28 @@ func getDNSEncryption() (de dnsEncryption) {
|
||||
|
||||
// applyAdditionalFiltering adds additional client information and settings if
|
||||
// the client has them.
|
||||
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *filtering.Settings) {
|
||||
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||
|
||||
if clientAddr == nil {
|
||||
log.Debug("looking up settings for client with ip %s and clientid %q", clientIP, clientID)
|
||||
|
||||
if clientIP == nil {
|
||||
return
|
||||
}
|
||||
|
||||
setts.ClientIP = clientAddr
|
||||
setts.ClientIP = clientIP
|
||||
|
||||
c, ok := Context.clients.Find(clientID)
|
||||
if !ok {
|
||||
c, ok = Context.clients.Find(clientAddr.String())
|
||||
c, ok = Context.clients.Find(clientIP.String())
|
||||
if !ok {
|
||||
log.Debug("client with ip %s and clientid %q not found", clientIP, clientID)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("using settings for client %s with ip %s and clientid %q", c.Name, clientAddr, clientID)
|
||||
log.Debug("using settings for client %q with ip %s and clientid %q", c.Name, clientIP, clientID)
|
||||
|
||||
if c.UseOwnBlockedServices {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
|
||||
@@ -359,11 +396,16 @@ func startDNSServer() error {
|
||||
Context.queryLog.Start()
|
||||
|
||||
const topClientsNumber = 100 // the number of clients to get
|
||||
for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) {
|
||||
if config.DNS.ResolveClients && !ip.IsLoopback() {
|
||||
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
srcs := config.Clients.Sources
|
||||
if srcs.RDNS && !ip.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
if !Context.subnetDetector.IsSpecialNetwork(ip) {
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
}
|
||||
@@ -413,7 +455,12 @@ func closeDNSServer() {
|
||||
}
|
||||
|
||||
if Context.stats != nil {
|
||||
Context.stats.Close()
|
||||
err := Context.stats.Close()
|
||||
if err != nil {
|
||||
log.Debug("closing stats: %s", err)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Find out if it's safe.
|
||||
Context.stats = nil
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ type homeContext struct {
|
||||
// --
|
||||
|
||||
clients clientsContainer // per-client-settings module
|
||||
stats stats.Stats // statistics module
|
||||
stats stats.Interface // statistics module
|
||||
queryLog querylog.QueryLog // query log module
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
@@ -66,8 +66,6 @@ type homeContext struct {
|
||||
|
||||
updater *updater.Updater
|
||||
|
||||
subnetDetector *aghnet.SubnetDetector
|
||||
|
||||
// mux is our custom http.ServeMux.
|
||||
mux *http.ServeMux
|
||||
|
||||
@@ -175,6 +173,11 @@ func setupContext(args options) {
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if !args.noEtcHosts && config.Clients.Sources.HostsFile {
|
||||
err = setupHostsContainer()
|
||||
fatalOnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
Context.mux = http.NewServeMux()
|
||||
@@ -287,33 +290,35 @@ func setupConfig(args options) (err error) {
|
||||
ConfName: config.getConfigFilename(),
|
||||
})
|
||||
|
||||
if !args.noEtcHosts {
|
||||
if err = setupHostsContainer(); err != nil {
|
||||
return err
|
||||
}
|
||||
var arpdb aghnet.ARPDB
|
||||
if config.Clients.Sources.ARP {
|
||||
arpdb = aghnet.NewARPDB()
|
||||
}
|
||||
|
||||
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts)
|
||||
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb)
|
||||
|
||||
if args.bindPort != 0 {
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
args.bindPort,
|
||||
config.BetaBindPort,
|
||||
config.DNS.Port,
|
||||
)
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(args.bindPort), tcpPort(config.BetaBindPort))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
|
||||
if config.TLS.Enabled {
|
||||
addPorts(
|
||||
uc,
|
||||
config.TLS.PortHTTPS,
|
||||
config.TLS.PortDNSOverTLS,
|
||||
config.TLS.PortDNSOverQUIC,
|
||||
config.TLS.PortDNSCrypt,
|
||||
tcpPorts,
|
||||
tcpPort(config.TLS.PortHTTPS),
|
||||
tcpPort(config.TLS.PortDNSOverTLS),
|
||||
tcpPort(config.TLS.PortDNSCrypt),
|
||||
)
|
||||
|
||||
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
|
||||
}
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return fmt.Errorf("validating ports: %w", err)
|
||||
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
} else if err = udpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
config.BindPort = args.bindPort
|
||||
@@ -390,9 +395,6 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
// configure log level and output
|
||||
configureLogger(args)
|
||||
|
||||
// Go memory hacks
|
||||
memoryUsage(args)
|
||||
|
||||
// Print the first message after logger is configured.
|
||||
log.Println(version.Full())
|
||||
log.Debug("current working directory is %s", Context.workDir)
|
||||
@@ -469,9 +471,6 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
Context.web, err = initWeb(args, clientBuildFS)
|
||||
fatalOnError(err)
|
||||
|
||||
Context.subnetDetector, err = aghnet.NewSubnetDetector()
|
||||
fatalOnError(err)
|
||||
|
||||
if !Context.firstRun {
|
||||
err = initDNSServer()
|
||||
fatalOnError(err)
|
||||
@@ -521,27 +520,15 @@ func StartMods() error {
|
||||
func checkPermissions() {
|
||||
log.Info("Checking if AdGuard Home has necessary permissions")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// On Windows we need to have admin rights to run properly
|
||||
|
||||
admin, _ := aghos.HaveAdminRights()
|
||||
if admin {
|
||||
return
|
||||
}
|
||||
|
||||
if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil {
|
||||
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
|
||||
}
|
||||
|
||||
// We should check if AdGuard Home is able to bind to port 53
|
||||
ok, err := aghnet.CanBindPort(53)
|
||||
|
||||
if ok {
|
||||
log.Info("AdGuard Home can bind to port 53")
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
msg := `Permission check failed.
|
||||
err := aghnet.CheckPort("tcp", net.IP{127, 0, 0, 1}, defaultPortDNS)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
log.Fatal(`Permission check failed.
|
||||
|
||||
AdGuard Home is not allowed to bind to privileged ports (for instance, port 53).
|
||||
Please note, that this is crucial for a server to be able to use privileged ports.
|
||||
@@ -549,16 +536,17 @@ Please note, that this is crucial for a server to be able to use privileged port
|
||||
You have two options:
|
||||
1. Run AdGuard Home with root privileges
|
||||
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
|
||||
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`
|
||||
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`)
|
||||
}
|
||||
|
||||
log.Fatal(msg)
|
||||
log.Info(
|
||||
"AdGuard failed to bind to port 53: %s\n\n"+
|
||||
"Please note, that this is crucial for a DNS server to be able to use that port.",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf(`AdGuard failed to bind to port 53 due to %v
|
||||
|
||||
Please note, that this is crucial for a DNS server to be able to use that port.`, err)
|
||||
|
||||
log.Info(msg)
|
||||
log.Info("AdGuard Home can bind to port 53")
|
||||
}
|
||||
|
||||
// Write PID to a file
|
||||
@@ -614,17 +602,17 @@ func configureLogger(args options) {
|
||||
ls.Verbose = true
|
||||
}
|
||||
if args.logFile != "" {
|
||||
ls.LogFile = args.logFile
|
||||
} else if config.LogFile != "" {
|
||||
ls.LogFile = config.LogFile
|
||||
ls.File = args.logFile
|
||||
} else if config.File != "" {
|
||||
ls.File = config.File
|
||||
}
|
||||
|
||||
// Handle default log settings overrides
|
||||
ls.LogCompress = config.LogCompress
|
||||
ls.LogLocalTime = config.LogLocalTime
|
||||
ls.LogMaxBackups = config.LogMaxBackups
|
||||
ls.LogMaxSize = config.LogMaxSize
|
||||
ls.LogMaxAge = config.LogMaxAge
|
||||
ls.Compress = config.Compress
|
||||
ls.LocalTime = config.LocalTime
|
||||
ls.MaxBackups = config.MaxBackups
|
||||
ls.MaxSize = config.MaxSize
|
||||
ls.MaxAge = config.MaxAge
|
||||
|
||||
// log.SetLevel(log.INFO) - default
|
||||
if ls.Verbose {
|
||||
@@ -635,27 +623,27 @@ func configureLogger(args options) {
|
||||
// happen pretty quickly.
|
||||
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
|
||||
|
||||
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
|
||||
if args.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if nothing
|
||||
// else is configured. Otherwise, we'll simply lose the log output.
|
||||
ls.LogFile = configSyslog
|
||||
ls.File = configSyslog
|
||||
}
|
||||
|
||||
// logs are written to stdout (default)
|
||||
if ls.LogFile == "" {
|
||||
if ls.File == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if ls.LogFile == configSyslog {
|
||||
if ls.File == configSyslog {
|
||||
// Use syslog where it is possible and eventlog on Windows
|
||||
err := aghos.ConfigureSyslog(serviceName)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot initialize syslog: %s", err)
|
||||
}
|
||||
} else {
|
||||
logFilePath := filepath.Join(Context.workDir, ls.LogFile)
|
||||
if filepath.IsAbs(ls.LogFile) {
|
||||
logFilePath = ls.LogFile
|
||||
logFilePath := filepath.Join(Context.workDir, ls.File)
|
||||
if filepath.IsAbs(ls.File) {
|
||||
logFilePath = ls.File
|
||||
}
|
||||
|
||||
_, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
|
||||
@@ -665,11 +653,11 @@ func configureLogger(args options) {
|
||||
|
||||
log.SetOutput(&lumberjack.Logger{
|
||||
Filename: logFilePath,
|
||||
Compress: ls.LogCompress, // disabled by default
|
||||
LocalTime: ls.LogLocalTime,
|
||||
MaxBackups: ls.LogMaxBackups,
|
||||
MaxSize: ls.LogMaxSize, // megabytes
|
||||
MaxAge: ls.LogMaxAge, // days
|
||||
Compress: ls.Compress, // disabled by default
|
||||
LocalTime: ls.LocalTime,
|
||||
MaxBackups: ls.MaxBackups,
|
||||
MaxSize: ls.MaxSize, // megabytes
|
||||
MaxAge: ls.MaxAge, // days
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
// TODO(a.garipov): Get rid of a global or generate from .twosky.json.
|
||||
var allowedLanguages = stringutil.NewSet(
|
||||
"ar",
|
||||
"be",
|
||||
"bg",
|
||||
"cs",
|
||||
@@ -50,7 +51,7 @@ var allowedLanguages = stringutil.NewSet(
|
||||
"zh-tw",
|
||||
)
|
||||
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
log.Printf("config.Language is %s", config.Language)
|
||||
_, err := fmt.Fprintf(w, "%s\n", config.Language)
|
||||
@@ -58,6 +59,7 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
msg := fmt.Sprintf("Unable to write response json: %s", err)
|
||||
log.Println(msg)
|
||||
http.Error(w, msg, http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -69,6 +71,7 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
msg := fmt.Sprintf("failed to read request body: %s", err)
|
||||
log.Println(msg)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// memoryUsage implements a couple of not really beautiful hacks which purpose is to
|
||||
// make OS reclaim the memory freed by AdGuard Home as soon as possible.
|
||||
// See this for the details on the performance hits & gains:
|
||||
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/2044#issuecomment-687042211
|
||||
func memoryUsage(args options) {
|
||||
if args.disableMemoryOptimization {
|
||||
log.Info("Memory optimization is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Makes Go allocate heap at a slower pace
|
||||
// By default we keep it at 50%
|
||||
debug.SetGCPercent(50)
|
||||
|
||||
// madvdontneed: setting madvdontneed=1 will use MADV_DONTNEED
|
||||
// instead of MADV_FREE on Linux when returning memory to the
|
||||
// kernel. This is less efficient, but causes RSS numbers to drop
|
||||
// more quickly.
|
||||
_ = os.Setenv("GODEBUG", "madvdontneed=1")
|
||||
|
||||
// periodically call "debug.FreeOSMemory" so
|
||||
// that the OS could reclaim the free memory
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
log.Debug("free os memory")
|
||||
debug.FreeOSMemory()
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"github.com/google/uuid"
|
||||
"howett.net/plist"
|
||||
)
|
||||
|
||||
@@ -47,9 +47,9 @@ type payloadContent struct {
|
||||
|
||||
PayloadType string
|
||||
PayloadIdentifier string
|
||||
PayloadUUID string
|
||||
PayloadDisplayName string
|
||||
PayloadDescription string
|
||||
PayloadUUID uuid.UUID
|
||||
PayloadVersion int
|
||||
}
|
||||
|
||||
@@ -63,18 +63,14 @@ const dnsSettingsPayloadType = "com.apple.dnsSettings.managed"
|
||||
type mobileConfig struct {
|
||||
PayloadDescription string
|
||||
PayloadDisplayName string
|
||||
PayloadIdentifier string
|
||||
PayloadType string
|
||||
PayloadUUID string
|
||||
PayloadContent []*payloadContent
|
||||
PayloadIdentifier uuid.UUID
|
||||
PayloadUUID uuid.UUID
|
||||
PayloadVersion int
|
||||
PayloadRemovalDisallowed bool
|
||||
}
|
||||
|
||||
func genUUIDv4() string {
|
||||
return uuid.NewV4().String()
|
||||
}
|
||||
|
||||
const (
|
||||
dnsProtoHTTPS = "HTTPS"
|
||||
dnsProtoTLS = "TLS"
|
||||
@@ -104,23 +100,23 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
|
||||
return nil, fmt.Errorf("bad dns protocol %q", proto)
|
||||
}
|
||||
|
||||
payloadID := fmt.Sprintf("%s.%s", dnsSettingsPayloadType, genUUIDv4())
|
||||
payloadID := fmt.Sprintf("%s.%s", dnsSettingsPayloadType, uuid.New())
|
||||
data := &mobileConfig{
|
||||
PayloadDescription: "Adds AdGuard Home to macOS Big Sur " +
|
||||
"and iOS 14 or newer systems",
|
||||
PayloadDescription: "Adds AdGuard Home to macOS Big Sur and iOS 14 or newer systems",
|
||||
PayloadDisplayName: dspName,
|
||||
PayloadIdentifier: genUUIDv4(),
|
||||
PayloadType: "Configuration",
|
||||
PayloadUUID: genUUIDv4(),
|
||||
PayloadContent: []*payloadContent{{
|
||||
DNSSettings: d,
|
||||
|
||||
PayloadType: dnsSettingsPayloadType,
|
||||
PayloadIdentifier: payloadID,
|
||||
PayloadUUID: genUUIDv4(),
|
||||
PayloadDisplayName: dspName,
|
||||
PayloadDescription: "Configures device to use AdGuard Home",
|
||||
PayloadUUID: uuid.New(),
|
||||
PayloadVersion: 1,
|
||||
DNSSettings: d,
|
||||
}},
|
||||
PayloadIdentifier: uuid.New(),
|
||||
PayloadUUID: uuid.New(),
|
||||
PayloadVersion: 1,
|
||||
PayloadRemovalDisallowed: false,
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// options passed from command-line arguments
|
||||
@@ -27,10 +28,6 @@ type options struct {
|
||||
// runningAsService flag is set to true when options are passed from the service runner
|
||||
runningAsService bool
|
||||
|
||||
// disableMemoryOptimization - disables memory optimization hacks
|
||||
// see memoryUsage() function for the details
|
||||
disableMemoryOptimization bool
|
||||
|
||||
glinetMode bool // Activate GL-Inet compatibility mode
|
||||
|
||||
// noEtcHosts flag should be provided when /etc/hosts file shouldn't be
|
||||
@@ -178,10 +175,14 @@ var noCheckUpdateArg = arg{
|
||||
}
|
||||
|
||||
var disableMemoryOptimizationArg = arg{
|
||||
"Disable memory optimization.",
|
||||
"Deprecated. Disable memory optimization.",
|
||||
"no-mem-optimization", "",
|
||||
nil, func(o options) (options, error) { o.disableMemoryOptimization = true; return o, nil }, nil,
|
||||
func(o options) []string { return boolSliceOrNil(o.disableMemoryOptimization) },
|
||||
nil, nil, func(_ options, _ string) (f effect, err error) {
|
||||
log.Info("warning: using --no-mem-optimization flag has no effect and is deprecated")
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
func(o options) []string { return nil },
|
||||
}
|
||||
|
||||
var verboseArg = arg{
|
||||
@@ -229,13 +230,19 @@ var helpArg = arg{
|
||||
}
|
||||
|
||||
var noEtcHostsArg = arg{
|
||||
description: "Do not use the OS-provided hosts.",
|
||||
description: "Deprecated. Do not use the OS-provided hosts.",
|
||||
longName: "no-etc-hosts",
|
||||
shortName: "",
|
||||
updateWithValue: nil,
|
||||
updateNoValue: func(o options) (options, error) { o.noEtcHosts = true; return o, nil },
|
||||
effect: nil,
|
||||
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) },
|
||||
effect: func(_ options, _ string) (f effect, err error) {
|
||||
log.Info(
|
||||
"warning: --no-etc-hosts flag is deprecated and will be removed in the future versions",
|
||||
)
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) },
|
||||
}
|
||||
|
||||
var localFrontendArg = arg{
|
||||
|
||||
@@ -101,9 +101,13 @@ func TestParseDisableUpdate(t *testing.T) {
|
||||
assert.True(t, testParseOK(t, "--no-check-update").disableUpdate, "--no-check-update is disable update")
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Remove after v0.108.0.
|
||||
func TestParseDisableMemoryOptimization(t *testing.T) {
|
||||
assert.False(t, testParseOK(t).disableMemoryOptimization, "empty is not disable update")
|
||||
assert.True(t, testParseOK(t, "--no-mem-optimization").disableMemoryOptimization, "--no-mem-optimization is disable update")
|
||||
o, eff, err := parse("", []string{"--no-mem-optimization"})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, eff)
|
||||
assert.Zero(t, o)
|
||||
}
|
||||
|
||||
func TestParseService(t *testing.T) {
|
||||
@@ -127,8 +131,6 @@ func TestParseUnknown(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSerialize(t *testing.T) {
|
||||
const reportFmt = "expected %s but got %s"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
opts options
|
||||
@@ -173,19 +175,14 @@ func TestSerialize(t *testing.T) {
|
||||
name: "glinet_mode",
|
||||
opts: options{glinetMode: true},
|
||||
ss: []string{"--glinet"},
|
||||
}, {
|
||||
name: "disable_mem_opt",
|
||||
opts: options{disableMemoryOptimization: true},
|
||||
ss: []string{"--no-mem-optimization"},
|
||||
}, {
|
||||
name: "multiple",
|
||||
opts: options{
|
||||
serviceControlAction: "run",
|
||||
configFilename: "config",
|
||||
workDir: "work",
|
||||
pidFile: "pid",
|
||||
disableUpdate: true,
|
||||
disableMemoryOptimization: true,
|
||||
serviceControlAction: "run",
|
||||
configFilename: "config",
|
||||
workDir: "work",
|
||||
pidFile: "pid",
|
||||
disableUpdate: true,
|
||||
},
|
||||
ss: []string{
|
||||
"-c", "config",
|
||||
@@ -193,18 +190,13 @@ func TestSerialize(t *testing.T) {
|
||||
"-s", "run",
|
||||
"--pidfile", "pid",
|
||||
"--no-check-update",
|
||||
"--no-mem-optimization",
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := serialize(tc.opts)
|
||||
require.Lenf(t, result, len(tc.ss), reportFmt, tc.ss, result)
|
||||
|
||||
for i, r := range result {
|
||||
assert.Equalf(t, tc.ss[i], r, reportFmt, tc.ss, result)
|
||||
}
|
||||
assert.ElementsMatch(t, tc.ss, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,18 +16,17 @@ type RDNS struct {
|
||||
exchanger dnsforward.RDNSExchanger
|
||||
clients *clientsContainer
|
||||
|
||||
// usePrivate is used to store the state of current private RDNS
|
||||
// resolving settings and to react to it's changes.
|
||||
// usePrivate is used to store the state of current private RDNS resolving
|
||||
// settings and to react to it's changes.
|
||||
usePrivate uint32
|
||||
|
||||
// ipCh used to pass client's IP to rDNS workerLoop.
|
||||
ipCh chan net.IP
|
||||
|
||||
// ipCache caches the IP addresses to be resolved by rDNS. The resolved
|
||||
// address stays here while it's inside clients. After leaving clients
|
||||
// the address will be resolved once again. If the address couldn't be
|
||||
// resolved, cache prevents further attempts to resolve it for some
|
||||
// time.
|
||||
// address stays here while it's inside clients. After leaving clients the
|
||||
// address will be resolved once again. If the address couldn't be
|
||||
// resolved, cache prevents further attempts to resolve it for some time.
|
||||
ipCache cache.Cache
|
||||
}
|
||||
|
||||
@@ -125,14 +124,12 @@ func (r *RDNS) workerLoop() {
|
||||
log.Debug("rdns: resolving %q: %s", ip, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
} else if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't handle any errors since AddHost doesn't return non-nil
|
||||
// errors for now.
|
||||
// Don't handle any errors since AddHost doesn't return non-nil errors
|
||||
// for now.
|
||||
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,15 +3,16 @@ package home
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
@@ -80,8 +81,10 @@ func TestRDNS_Begin(t *testing.T) {
|
||||
binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix()))
|
||||
|
||||
rdns := &RDNS{
|
||||
ipCache: ipCache,
|
||||
exchanger: &rDNSExchanger{},
|
||||
ipCache: ipCache,
|
||||
exchanger: &rDNSExchanger{
|
||||
ex: aghtest.NewErrorUpstream(),
|
||||
},
|
||||
clients: &clientsContainer{
|
||||
list: map[string]*Client{},
|
||||
idIndex: tc.cliIDIndex,
|
||||
@@ -108,16 +111,22 @@ func TestRDNS_Begin(t *testing.T) {
|
||||
|
||||
// rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests.
|
||||
type rDNSExchanger struct {
|
||||
ex aghtest.Exchanger
|
||||
ex upstream.Upstream
|
||||
usePrivate bool
|
||||
}
|
||||
|
||||
// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger.
|
||||
func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) {
|
||||
rev, err := netutil.IPToReversedAddr(ip)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reversing ip: %w", err)
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: ip.String(),
|
||||
Qtype: dns.TypePTR,
|
||||
Name: dns.Fqdn(rev),
|
||||
Qclass: dns.ClassINET,
|
||||
Qtype: dns.TypePTR,
|
||||
}},
|
||||
}
|
||||
|
||||
@@ -146,7 +155,9 @@ func TestRDNS_ensurePrivateCache(t *testing.T) {
|
||||
MaxCount: defaultRDNSCacheSize,
|
||||
})
|
||||
|
||||
ex := &rDNSExchanger{}
|
||||
ex := &rDNSExchanger{
|
||||
ex: aghtest.NewErrorUpstream(),
|
||||
}
|
||||
|
||||
rdns := &RDNS{
|
||||
ipCache: ipCache,
|
||||
@@ -167,15 +178,27 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
w := &bytes.Buffer{}
|
||||
aghtest.ReplaceLogWriter(t, w)
|
||||
|
||||
locUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"192.168.1.1": {"local.domain"},
|
||||
"2a00:1450:400c:c06::93": {"ipv6.domain"},
|
||||
localIP := net.IP{192, 168, 1, 1}
|
||||
revIPv4, err := netutil.IPToReversedAddr(localIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
|
||||
require.NoError(t, err)
|
||||
|
||||
locUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "local.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv4, "local.domain"),
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv6, "ipv6.domain"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
errUpstream := &aghtest.TestErrUpstream{
|
||||
Err: errors.Error("1234"),
|
||||
}
|
||||
|
||||
errUpstream := aghtest.NewErrorUpstream()
|
||||
|
||||
testCases := []struct {
|
||||
ups upstream.Upstream
|
||||
@@ -186,10 +209,10 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
ups: locUpstream,
|
||||
wantLog: "",
|
||||
name: "all_good",
|
||||
cliIP: net.IP{192, 168, 1, 1},
|
||||
cliIP: localIP,
|
||||
}, {
|
||||
ups: errUpstream,
|
||||
wantLog: `rdns: resolving "192.168.1.2": errupstream: 1234`,
|
||||
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
|
||||
name: "resolve_error",
|
||||
cliIP: net.IP{192, 168, 1, 2},
|
||||
}, {
|
||||
@@ -211,9 +234,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
ch := make(chan net.IP)
|
||||
rdns := &RDNS{
|
||||
exchanger: &rDNSExchanger{
|
||||
ex: aghtest.Exchanger{
|
||||
Ups: tc.ups,
|
||||
},
|
||||
ex: tc.ups,
|
||||
},
|
||||
clients: cc,
|
||||
ipCh: ch,
|
||||
|
||||
@@ -433,8 +433,11 @@ EnvironmentFile=-/etc/sysconfig/{{.Name}}
|
||||
WantedBy=multi-user.target
|
||||
`
|
||||
|
||||
// Note: we should keep it in sync with the template from service_sysv_linux.go file
|
||||
// Use "ps | grep -v grep | grep $(get_pid)" because "ps PID" may not work on OpenWrt
|
||||
// sysvScript is the source of the daemon script for SysV-based Linux systems.
|
||||
// Keep as close as possible to the https://github.com/kardianos/service/blob/29f8c79c511bc18422bb99992779f96e6bc33921/service_sysv_linux.go#L187.
|
||||
//
|
||||
// Use ps command instead of reading the procfs since it's a more
|
||||
// implementation-independent approach.
|
||||
const sysvScript = `#!/bin/sh
|
||||
# For RedHat and cousins:
|
||||
# chkconfig: - 99 01
|
||||
@@ -465,7 +468,7 @@ get_pid() {
|
||||
}
|
||||
|
||||
is_running() {
|
||||
[ -f "$pid_file" ] && ps | grep -v grep | grep $(get_pid) > /dev/null 2>&1
|
||||
[ -f "$pid_file" ] && ps -p "$(get_pid)" > /dev/null 2>&1
|
||||
}
|
||||
|
||||
case "$1" in
|
||||
@@ -609,7 +612,7 @@ command_args="-P ${pidfile} -p ${pidfile_child} -T ${name} -r {{.WorkingDirector
|
||||
run_rc_command "$1"
|
||||
`
|
||||
|
||||
const openBSDScript = `#!/bin/sh
|
||||
const openBSDScript = `#!/bin/ksh
|
||||
#
|
||||
# $OpenBSD: {{ .SvcInfo }}
|
||||
|
||||
|
||||
83
internal/home/service_linux.go
Normal file
83
internal/home/service_linux.go
Normal file
@@ -0,0 +1,83 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package home
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
func chooseSystem() {
|
||||
sys := service.ChosenSystem()
|
||||
// By default, package service uses the SysV system if it cannot detect
|
||||
// anything other, but the update-rc.d fix should not be applied on OpenWrt,
|
||||
// so exclude it explicitly.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4480 and
|
||||
// https://github.com/AdguardTeam/AdGuardHome/issues/4677.
|
||||
if sys.String() == "unix-systemv" && !aghos.IsOpenWrt() {
|
||||
service.ChooseSystem(sysvSystem{System: sys})
|
||||
}
|
||||
}
|
||||
|
||||
// sysvSystem is a wrapper for service.System that wraps the service.Service
|
||||
// while creating a new one.
|
||||
//
|
||||
// TODO(e.burkov): File a PR to github.com/kardianos/service.
|
||||
type sysvSystem struct {
|
||||
// System is expected to have an unexported type
|
||||
// *service.linuxSystemService.
|
||||
service.System
|
||||
}
|
||||
|
||||
// New returns a wrapped service.Service.
|
||||
func (sys sysvSystem) New(i service.Interface, c *service.Config) (s service.Service, err error) {
|
||||
s, err = sys.System.New(i, c)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
return sysvService{
|
||||
Service: s,
|
||||
name: c.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sysvService is a wrapper for a service.Service that also calls update-rc.d in
|
||||
// a proper way on installing and uninstalling.
|
||||
type sysvService struct {
|
||||
// Service is expected to have an unexported type *service.sysv.
|
||||
service.Service
|
||||
// name stores the name of the service to call updating script with it.
|
||||
name string
|
||||
}
|
||||
|
||||
// Install wraps service.Service.Install call with calling the updating script.
|
||||
func (svc sysvService) Install() (err error) {
|
||||
err = svc.Service.Install()
|
||||
if err != nil {
|
||||
// Don't wrap an error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "defaults")
|
||||
|
||||
// Don't wrap an error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
// Uninstall wraps service.Service.Uninstall call with calling the updating
|
||||
// script.
|
||||
func (svc sysvService) Uninstall() (err error) {
|
||||
err = svc.Service.Uninstall()
|
||||
if err != nil {
|
||||
// Don't wrap an error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "remove")
|
||||
|
||||
// Don't wrap an error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
@@ -160,7 +160,7 @@ rc_cmd $1
|
||||
|
||||
// template returns the script template to put into rc.d.
|
||||
func (s *openbsdRunComService) template() (t *template.Template) {
|
||||
tf := map[string]interface{}{
|
||||
tf := map[string]any{
|
||||
"args": func(sl []string) string {
|
||||
return `"` + strings.Join(sl, " ") + `"`
|
||||
},
|
||||
@@ -314,12 +314,13 @@ func (s *openbsdRunComService) runCom(cmd string) (out string, err error) {
|
||||
// TODO(e.burkov): It's possible that os.ErrNotExist is caused by
|
||||
// something different than the service script's non-existence. Keep it
|
||||
// in mind, when replace the aghos.RunCommand.
|
||||
_, out, err = aghos.RunCommand(scriptPath, cmd)
|
||||
var outData []byte
|
||||
_, outData, err = aghos.RunCommand(scriptPath, cmd)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return "", service.ErrNotInstalled
|
||||
}
|
||||
|
||||
return out, err
|
||||
return string(outData), err
|
||||
}
|
||||
|
||||
// Status implements service.Service interface for *openbsdRunComService.
|
||||
@@ -389,42 +390,42 @@ func newSysLogger(_ string, _ chan<- error) (service.Logger, error) {
|
||||
type sysLogger struct{}
|
||||
|
||||
// Error implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Error(v ...interface{}) error {
|
||||
func (sysLogger) Error(v ...any) error {
|
||||
log.Error(fmt.Sprint(v...))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Warning implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Warning(v ...interface{}) error {
|
||||
func (sysLogger) Warning(v ...any) error {
|
||||
log.Info("warning: %s", fmt.Sprint(v...))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Info implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Info(v ...interface{}) error {
|
||||
func (sysLogger) Info(v ...any) error {
|
||||
log.Info(fmt.Sprint(v...))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Errorf implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Errorf(format string, a ...interface{}) error {
|
||||
func (sysLogger) Errorf(format string, a ...any) error {
|
||||
log.Error(format, a...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Warningf implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Warningf(format string, a ...interface{}) error {
|
||||
func (sysLogger) Warningf(format string, a ...any) error {
|
||||
log.Info("warning: %s", fmt.Sprintf(format, a...))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Infof implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Infof(format string, a ...interface{}) error {
|
||||
func (sysLogger) Infof(format string, a ...any) error {
|
||||
log.Info(format, a...)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
//go:build !openbsd
|
||||
// +build !openbsd
|
||||
//go:build !(openbsd || linux)
|
||||
// +build !openbsd,!linux
|
||||
|
||||
package home
|
||||
|
||||
// chooseSystem checks the current system detected and substitutes it with local
|
||||
// implementation if needed.
|
||||
func chooseSystem() {}
|
||||
|
||||
@@ -250,21 +250,17 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if setts.Enabled {
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
config.BindPort,
|
||||
config.BetaBindPort,
|
||||
config.DNS.Port,
|
||||
setts.PortHTTPS,
|
||||
setts.PortDNSOverTLS,
|
||||
setts.PortDNSOverQUIC,
|
||||
setts.PortDNSCrypt,
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
tcpPort(setts.PortHTTPS),
|
||||
tcpPort(setts.PortDNSOverTLS),
|
||||
tcpPort(setts.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(setts.PortDNSOverQUIC),
|
||||
)
|
||||
|
||||
err = uc.Validate(aghalg.IntIsBefore)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "validating ports: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -343,19 +339,15 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if data.Enabled {
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
config.BindPort,
|
||||
config.BetaBindPort,
|
||||
config.DNS.Port,
|
||||
data.PortHTTPS,
|
||||
data.PortDNSOverTLS,
|
||||
data.PortDNSOverQUIC,
|
||||
data.PortDNSCrypt,
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
tcpPort(data.PortHTTPS),
|
||||
tcpPort(data.PortDNSOverTLS),
|
||||
tcpPort(data.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(data.PortDNSOverQUIC),
|
||||
)
|
||||
|
||||
err = uc.Validate(aghalg.IntIsBefore)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -421,6 +413,38 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
|
||||
// DNS protocols.
|
||||
func validatePorts(
|
||||
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
|
||||
dnsPort, doqPort udpPort,
|
||||
) (err error) {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(
|
||||
tcpPorts,
|
||||
tcpPort(bindPort),
|
||||
tcpPort(betaBindPort),
|
||||
tcpPort(dohPort),
|
||||
tcpPort(dotPort),
|
||||
tcpPort(dnscryptTCPPort),
|
||||
)
|
||||
|
||||
err = tcpPorts.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
}
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
|
||||
|
||||
err = udpPorts.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyCertChain(data *tlsConfigStatus, certChain, serverName string) error {
|
||||
log.Tracef("TLS: got certificate: %d bytes", len(certChain))
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -17,15 +18,14 @@ import (
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// currentSchemaVersion is the current schema version.
|
||||
const currentSchemaVersion = 13
|
||||
const currentSchemaVersion = 14
|
||||
|
||||
// These aliases are provided for convenience.
|
||||
type (
|
||||
any = interface{}
|
||||
yarr = []any
|
||||
yobj = map[any]any
|
||||
)
|
||||
@@ -86,6 +86,7 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
|
||||
upgradeSchema10to11,
|
||||
upgradeSchema11to12,
|
||||
upgradeSchema12to13,
|
||||
upgradeSchema13to14,
|
||||
}
|
||||
|
||||
n := 0
|
||||
@@ -104,16 +105,20 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
|
||||
return fmt.Errorf("unknown configuration schema version %d", oldVersion)
|
||||
}
|
||||
|
||||
body, err := yaml.Marshal(diskConf)
|
||||
buf := &bytes.Buffer{}
|
||||
enc := yaml.NewEncoder(buf)
|
||||
enc.SetIndent(2)
|
||||
|
||||
err = enc.Encode(diskConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating new config: %w", err)
|
||||
}
|
||||
|
||||
config.fileData = body
|
||||
config.fileData = buf.Bytes()
|
||||
confFile := config.getConfigFilename()
|
||||
err = maybe.WriteFile(confFile, body, 0o644)
|
||||
err = maybe.WriteFile(confFile, config.fileData, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving new config: %w", err)
|
||||
return fmt.Errorf("writing new config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -173,11 +178,11 @@ func upgradeSchema2to3(diskConf yobj) error {
|
||||
return fmt.Errorf("no DNS configuration in config file")
|
||||
}
|
||||
|
||||
// Convert interface{} to yobj
|
||||
// Convert any to yobj
|
||||
newDNSConfig := make(yobj)
|
||||
|
||||
switch v := dnsConfig.(type) {
|
||||
case map[interface{}]interface{}:
|
||||
case map[any]any:
|
||||
for k, v := range v {
|
||||
newDNSConfig[fmt.Sprint(k)] = v
|
||||
}
|
||||
@@ -213,12 +218,12 @@ func upgradeSchema3to4(diskConf yobj) error {
|
||||
}
|
||||
|
||||
switch arr := clients.(type) {
|
||||
case []interface{}:
|
||||
case []any:
|
||||
|
||||
for i := range arr {
|
||||
switch c := arr[i].(type) {
|
||||
|
||||
case map[interface{}]interface{}:
|
||||
case map[any]any:
|
||||
c["use_global_blocked_services"] = true
|
||||
|
||||
default:
|
||||
@@ -304,11 +309,11 @@ func upgradeSchema5to6(diskConf yobj) error {
|
||||
}
|
||||
|
||||
switch arr := clients.(type) {
|
||||
case []interface{}:
|
||||
case []any:
|
||||
for i := range arr {
|
||||
switch c := arr[i].(type) {
|
||||
case map[interface{}]interface{}:
|
||||
var ipVal interface{}
|
||||
case map[any]any:
|
||||
var ipVal any
|
||||
ipVal, ok = c["ip"]
|
||||
ids := []string{}
|
||||
if ok {
|
||||
@@ -323,7 +328,7 @@ func upgradeSchema5to6(diskConf yobj) error {
|
||||
}
|
||||
}
|
||||
|
||||
var macVal interface{}
|
||||
var macVal any
|
||||
macVal, ok = c["mac"]
|
||||
if ok {
|
||||
var mac string
|
||||
@@ -374,7 +379,7 @@ func upgradeSchema6to7(diskConf yobj) error {
|
||||
}
|
||||
|
||||
switch dhcp := dhcpVal.(type) {
|
||||
case map[interface{}]interface{}:
|
||||
case map[any]any:
|
||||
var str string
|
||||
str, ok = dhcp["gateway_ip"].(string)
|
||||
if !ok {
|
||||
@@ -726,7 +731,7 @@ func upgradeSchema12to13(diskConf yobj) (err error) {
|
||||
var dhcp yobj
|
||||
dhcp, ok = dhcpVal.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of dhcp: %T", dnsVal)
|
||||
return fmt.Errorf("unexpected type of dhcp: %T", dhcpVal)
|
||||
}
|
||||
|
||||
const field = "local_domain_name"
|
||||
@@ -737,6 +742,68 @@ func upgradeSchema12to13(diskConf yobj) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema13to14 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'clients':
|
||||
// - 'name': 'client-name'
|
||||
// # …
|
||||
//
|
||||
// # AFTER:
|
||||
// 'clients':
|
||||
// 'persistent':
|
||||
// - 'name': 'client-name'
|
||||
// # …
|
||||
// 'runtime_sources':
|
||||
// 'whois': true
|
||||
// 'arp': true
|
||||
// 'rdns': true
|
||||
// 'dhcp': true
|
||||
// 'hosts': true
|
||||
//
|
||||
func upgradeSchema13to14(diskConf yobj) (err error) {
|
||||
log.Printf("Upgrade yaml: 13 to 14")
|
||||
diskConf["schema_version"] = 14
|
||||
|
||||
clientsVal, ok := diskConf["clients"]
|
||||
if !ok {
|
||||
clientsVal = yarr{}
|
||||
}
|
||||
|
||||
var rdnsSrc bool
|
||||
if dnsVal, dok := diskConf["dns"]; dok {
|
||||
var dnsSettings yobj
|
||||
dnsSettings, ok = dnsVal.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of dns: %T", dnsVal)
|
||||
}
|
||||
|
||||
var rdnsSrcVal any
|
||||
rdnsSrcVal, ok = dnsSettings["resolve_clients"]
|
||||
if ok {
|
||||
rdnsSrc, ok = rdnsSrcVal.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of resolve_clients: %T", rdnsSrcVal)
|
||||
}
|
||||
|
||||
delete(dnsSettings, "resolve_clients")
|
||||
}
|
||||
}
|
||||
|
||||
diskConf["clients"] = yobj{
|
||||
"persistent": clientsVal,
|
||||
"runtime_sources": &clientSourcesConf{
|
||||
WHOIS: true,
|
||||
ARP: true,
|
||||
RDNS: rdnsSrc,
|
||||
DHCP: true,
|
||||
HostsFile: true,
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Replace with log.Output when we port it to our logging
|
||||
// package.
|
||||
func funcName() string {
|
||||
|
||||
@@ -190,7 +190,7 @@ func testDiskConf(schemaVersion int) (diskConf yobj) {
|
||||
return diskConf
|
||||
}
|
||||
|
||||
// testDNSConf creates a DNS config for test the way gopkg.in/yaml.v2 would
|
||||
// testDNSConf creates a DNS config for test the way gopkg.in/yaml.v3 would
|
||||
// unmarshal it. In YAML, keys aren't guaranteed to always only be strings.
|
||||
func testDNSConf(schemaVersion int) (dnsConf yobj) {
|
||||
dnsConf = yobj{
|
||||
@@ -500,7 +500,7 @@ func TestUpgradeSchema11to12(t *testing.T) {
|
||||
dnsVal, ok = dns.(yobj)
|
||||
require.True(t, ok)
|
||||
|
||||
var ivl interface{}
|
||||
var ivl any
|
||||
ivl, ok = dnsVal["querylog_interval"]
|
||||
require.True(t, ok)
|
||||
|
||||
@@ -513,46 +513,129 @@ func TestUpgradeSchema11to12(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpgradeSchema12to13(t *testing.T) {
|
||||
t.Run("no_dns", func(t *testing.T) {
|
||||
conf := yobj{}
|
||||
const newSchemaVer = 13
|
||||
|
||||
err := upgradeSchema12to13(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, conf["schema_version"], 13)
|
||||
})
|
||||
|
||||
t.Run("no_dhcp", func(t *testing.T) {
|
||||
conf := yobj{
|
||||
"dns": yobj{},
|
||||
}
|
||||
|
||||
err := upgradeSchema12to13(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, conf["schema_version"], 13)
|
||||
})
|
||||
|
||||
t.Run("good", func(t *testing.T) {
|
||||
conf := yobj{
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
}{{
|
||||
in: yobj{},
|
||||
want: yobj{"schema_version": newSchemaVer},
|
||||
name: "no_dns",
|
||||
}, {
|
||||
in: yobj{"dns": yobj{}},
|
||||
want: yobj{
|
||||
"dns": yobj{},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "no_dhcp",
|
||||
}, {
|
||||
in: yobj{
|
||||
"dns": yobj{
|
||||
"local_domain_name": "lan",
|
||||
},
|
||||
"dhcp": yobj{},
|
||||
"schema_version": 12,
|
||||
}
|
||||
|
||||
wantConf := yobj{
|
||||
"schema_version": newSchemaVer - 1,
|
||||
},
|
||||
want: yobj{
|
||||
"dns": yobj{},
|
||||
"dhcp": yobj{
|
||||
"local_domain_name": "lan",
|
||||
},
|
||||
"schema_version": 13,
|
||||
}
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "good",
|
||||
}}
|
||||
|
||||
err := upgradeSchema12to13(conf)
|
||||
require.NoError(t, err)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema12to13(tc.in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantConf, conf)
|
||||
})
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeSchema13to14(t *testing.T) {
|
||||
const newSchemaVer = 14
|
||||
|
||||
testClient := &clientObject{
|
||||
Name: "agh-client",
|
||||
IDs: []string{"id1"},
|
||||
UseGlobalSettings: true,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
}{{
|
||||
in: yobj{},
|
||||
want: yobj{
|
||||
"schema_version": newSchemaVer,
|
||||
// The clients field will be added anyway.
|
||||
"clients": yobj{
|
||||
"persistent": yarr{},
|
||||
"runtime_sources": &clientSourcesConf{
|
||||
WHOIS: true,
|
||||
ARP: true,
|
||||
RDNS: false,
|
||||
DHCP: true,
|
||||
HostsFile: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
name: "no_clients",
|
||||
}, {
|
||||
in: yobj{
|
||||
"clients": []*clientObject{testClient},
|
||||
},
|
||||
want: yobj{
|
||||
"schema_version": newSchemaVer,
|
||||
"clients": yobj{
|
||||
"persistent": []*clientObject{testClient},
|
||||
"runtime_sources": &clientSourcesConf{
|
||||
WHOIS: true,
|
||||
ARP: true,
|
||||
RDNS: false,
|
||||
DHCP: true,
|
||||
HostsFile: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
name: "no_dns",
|
||||
}, {
|
||||
in: yobj{
|
||||
"clients": []*clientObject{testClient},
|
||||
"dns": yobj{
|
||||
"resolve_clients": true,
|
||||
},
|
||||
},
|
||||
want: yobj{
|
||||
"schema_version": newSchemaVer,
|
||||
"clients": yobj{
|
||||
"persistent": []*clientObject{testClient},
|
||||
"runtime_sources": &clientSourcesConf{
|
||||
WHOIS: true,
|
||||
ARP: true,
|
||||
RDNS: true,
|
||||
DHCP: true,
|
||||
HostsFile: true,
|
||||
},
|
||||
},
|
||||
"dns": yobj{},
|
||||
},
|
||||
name: "good",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema13to14(tc.in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user