Merge branch 'master' into 4387-fix-openapi-schema

This commit is contained in:
Ainar Garipov
2022-08-17 20:21:32 +03:00
228 changed files with 15549 additions and 5499 deletions

View File

@@ -2,10 +2,9 @@ package home
import (
"bytes"
"encoding"
"fmt"
"net"
"os/exec"
"runtime"
"sort"
"strings"
"sync"
@@ -56,12 +55,49 @@ type clientSource uint
// Client sources. The order determines the priority.
const (
ClientSourceWHOIS clientSource = iota
ClientSourceRDNS
ClientSourceARP
ClientSourceRDNS
ClientSourceDHCP
ClientSourceHostsFile
)
var _ fmt.Stringer = clientSource(0)
// String returns a human-readable name of cs.
func (cs clientSource) String() (s string) {
switch cs {
case ClientSourceWHOIS:
return "WHOIS"
case ClientSourceARP:
return "ARP"
case ClientSourceRDNS:
return "rDNS"
case ClientSourceDHCP:
return "DHCP"
case ClientSourceHostsFile:
return "etc/hosts"
default:
return ""
}
}
var _ encoding.TextMarshaler = clientSource(0)
// MarshalText implements encoding.TextMarshaler for the clientSource.
func (cs clientSource) MarshalText() (text []byte, err error) {
return []byte(cs.String()), nil
}
// clientSourceConf is used to configure where the runtime clients will be
// obtained from.
type clientSourcesConf struct {
WHOIS bool `yaml:"whois"`
ARP bool `yaml:"arp"`
RDNS bool `yaml:"rdns"`
DHCP bool `yaml:"dhcp"`
HostsFile bool `yaml:"hosts"`
}
// RuntimeClient information
type RuntimeClient struct {
WHOISInfo *RuntimeClientWHOISInfo
@@ -99,6 +135,9 @@ type clientsContainer struct {
// hosts database.
etcHosts *aghnet.HostsContainer
// arpdb stores the neighbors retrieved from ARP.
arpdb aghnet.ARPDB
testing bool // if TRUE, this object is used for internal tests
}
@@ -109,6 +148,7 @@ func (clients *clientsContainer) Init(
objects []*clientObject,
dhcpServer *dhcpd.Server,
etcHosts *aghnet.HostsContainer,
arpdb aghnet.ARPDB,
) {
if clients.list != nil {
log.Fatal("clients.list != nil")
@@ -121,6 +161,7 @@ func (clients *clientsContainer) Init(
clients.dhcpServer = dhcpServer
clients.etcHosts = etcHosts
clients.arpdb = arpdb
clients.addFromConfig(objects)
if clients.testing {
@@ -132,14 +173,14 @@ func (clients *clientsContainer) Init(
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
}
go clients.handleHostsUpdates()
if clients.etcHosts != nil {
go clients.handleHostsUpdates()
}
}
func (clients *clientsContainer) handleHostsUpdates() {
if clients.etcHosts != nil {
for upd := range clients.etcHosts.Upd() {
clients.addFromHostsFile(upd)
}
for upd := range clients.etcHosts.Upd() {
clients.addFromHostsFile(upd)
}
}
@@ -156,7 +197,9 @@ func (clients *clientsContainer) Start() {
// Reload reloads runtime clients.
func (clients *clientsContainer) Reload() {
clients.addFromSystemARP()
if clients.arpdb != nil {
clients.addFromSystemARP()
}
}
type clientObject struct {
@@ -255,6 +298,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
}
func (clients *clientsContainer) periodicUpdate() {
defer log.OnPanic("clients container")
for {
clients.Reload()
time.Sleep(clientsUpdatePeriod)
@@ -380,6 +425,7 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
c.Tags = stringutil.CloneSlice(c.Tags)
c.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
c.Upstreams = stringutil.CloneSlice(c.Upstreams)
return c, true
}
@@ -476,7 +522,7 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
// findRuntimeClientLocked finds a runtime client by their IP address. For
// internal use only.
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
var v interface{}
var v any
v, ok = clients.ipToRC.Get(ip)
if !ok {
return nil, false
@@ -530,7 +576,7 @@ func (clients *clientsContainer) check(c *Client) (err error) {
} else if mac, err = net.ParseMAC(id); err == nil {
c.IDs[i] = mac.String()
} else if err = dnsforward.ValidateClientID(id); err == nil {
c.IDs[i] = id
c.IDs[i] = strings.ToLower(id)
} else {
return fmt.Errorf("invalid clientid at index %d: %q", i, id)
}
@@ -721,20 +767,18 @@ func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSourc
clients.lock.Lock()
defer clients.lock.Unlock()
ok = clients.addHostLocked(ip, host, src)
return ok, nil
return clients.addHostLocked(ip, host, src), nil
}
// addHostLocked adds a new IP-hostname pairing. For internal use only.
func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) {
var rc *RuntimeClient
rc, ok = clients.findRuntimeClientLocked(ip)
rc, ok := clients.findRuntimeClientLocked(ip)
if ok {
if rc.Source > src {
return false
}
rc.Host = host
rc.Source = src
} else {
rc = &RuntimeClient{
@@ -754,7 +798,7 @@ func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clien
// rmHostsBySrc removes all entries that match the specified source.
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
n := 0
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
clients.ipToRC.Range(func(ip net.IP, v any) (cont bool) {
rc, ok := v.(*RuntimeClient)
if !ok {
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
@@ -782,41 +826,38 @@ func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) {
clients.rmHostsBySrc(ClientSourceHostsFile)
n := 0
hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
hosts, ok := v.(*stringutil.Set)
hosts.Range(func(ip net.IP, v any) (cont bool) {
rec, ok := v.(*aghnet.HostsRecord)
if !ok {
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
return true
}
hosts.Range(func(name string) (cont bool) {
if clients.addHostLocked(ip, name, ClientSourceHostsFile) {
n++
}
return true
})
clients.addHostLocked(ip, rec.Canonical, ClientSourceHostsFile)
n++
return true
})
log.Debug("clients: added %d client aliases from system hosts-file", n)
log.Debug("clients: added %d client aliases from system hosts file", n)
}
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
// command.
func (clients *clientsContainer) addFromSystemARP() {
if runtime.GOOS == "windows" {
if err := clients.arpdb.Refresh(); err != nil {
log.Error("refreshing arp container: %s", err)
clients.arpdb = aghnet.EmptyARPDB{}
return
}
cmd := exec.Command("arp", "-a")
log.Tracef("executing %q %q", cmd.Path, cmd.Args)
data, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Debug("command %q has failed: %q code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
ns := clients.arpdb.Neighbors()
if len(ns) == 0 {
log.Debug("refreshing arp container: the update is empty")
return
}
@@ -825,36 +866,20 @@ func (clients *clientsContainer) addFromSystemARP() {
clients.rmHostsBySrc(ClientSourceARP)
n := 0
// TODO(a.garipov): Rewrite to use bufio.Scanner.
lines := strings.Split(string(data), "\n")
for _, ln := range lines {
lparen := strings.Index(ln, " (")
rparen := strings.Index(ln, ") ")
if lparen == -1 || rparen == -1 || lparen >= rparen {
continue
}
host := ln[:lparen]
ipStr := ln[lparen+2 : rparen]
ip := net.ParseIP(ipStr)
if netutil.ValidateDomainName(host) != nil || ip == nil {
continue
}
ok := clients.addHostLocked(ip, host, ClientSourceARP)
if ok {
n++
added := 0
for _, n := range ns {
if clients.addHostLocked(n.IP, n.Name, ClientSourceARP) {
added++
}
}
log.Debug("clients: added %d client aliases from 'arp -a' command output", n)
log.Debug("clients: added %d client aliases from arp neighborhood", added)
}
// updateFromDHCP adds the clients that have a non-empty hostname from the DHCP
// server.
func (clients *clientsContainer) updateFromDHCP(add bool) {
if clients.dhcpServer == nil {
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
return
}

View File

@@ -18,7 +18,7 @@ func TestClients(t *testing.T) {
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
clients.Init(nil, nil, nil, nil)
t.Run("add_success", func(t *testing.T) {
c := &Client{
@@ -194,7 +194,7 @@ func TestClientsWHOIS(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil)
clients.Init(nil, nil, nil, nil)
whois := &RuntimeClientWHOISInfo{
Country: "AU",
Orgname: "Example Org",
@@ -253,7 +253,7 @@ func TestClientsAddExisting(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil)
clients.Init(nil, nil, nil, nil)
t.Run("simple", func(t *testing.T) {
ip := net.IP{1, 1, 1, 1}
@@ -332,7 +332,7 @@ func TestClientsCustomUpstream(t *testing.T) {
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil)
clients.Init(nil, nil, nil, nil)
// Add client with upstreams.
ok, err := clients.Add(&Client{

View File

@@ -47,9 +47,9 @@ type clientJSON struct {
type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
Name string `json:"name"`
Source string `json:"source"`
IP net.IP `json:"ip"`
Name string `json:"name"`
Source clientSource `json:"source"`
IP net.IP `json:"ip"`
}
type clientListJSON struct {
@@ -70,7 +70,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
data.Clients = append(data.Clients, cj)
}
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
clients.ipToRC.Range(func(ip net.IP, v any) (cont bool) {
rc, ok := v.(*RuntimeClient)
if !ok {
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
@@ -81,20 +81,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
cj := runtimeClientJSON{
WHOISInfo: rc.WHOISInfo,
Name: rc.Host,
IP: ip,
}
cj.Source = "etc/hosts"
switch rc.Source {
case ClientSourceDHCP:
cj.Source = "DHCP"
case ClientSourceRDNS:
cj.Source = "rDNS"
case ClientSourceARP:
cj.Source = "ARP"
case ClientSourceWHOIS:
cj.Source = "WHOIS"
Name: rc.Host,
Source: rc.Source,
IP: ip,
}
data.RuntimeClients = append(data.RuntimeClients, cj)
@@ -107,13 +96,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data)
if e != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to encode to json: %v",
e,
)
aghhttp.Error(r, w, http.StatusInternalServerError, "failed to encode to json: %v", e)
return
}
@@ -279,9 +262,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
rc, ok := clients.FindRuntimeClient(ip)
if !ok {
// It is still possible that the IP used to be in the runtime
// clients list, but then the server was reloaded. So, check
// the DNS server's blocked IP list.
// It is still possible that the IP used to be in the runtime clients
// list, but then the server was reloaded. So, check the DNS server's
// blocked IP list.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)

View File

@@ -1,6 +1,7 @@
package home
import (
"bytes"
"fmt"
"net"
"os"
@@ -19,7 +20,7 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/maybe"
yaml "gopkg.in/yaml.v2"
yaml "gopkg.in/yaml.v3"
)
const (
@@ -27,15 +28,36 @@ const (
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
)
// logSettings
// logSettings are the logging settings part of the configuration file.
//
// TODO(a.garipov): Put them into a separate object.
type logSettings struct {
LogCompress bool `yaml:"log_compress"` // Compress determines if the rotated log files should be compressed using gzip (default: false)
LogLocalTime bool `yaml:"log_localtime"` // If the time used for formatting the timestamps in is the computer's local time (default: false [UTC])
LogMaxBackups int `yaml:"log_max_backups"` // Maximum number of old log files to retain (MaxAge may still cause them to get deleted)
LogMaxSize int `yaml:"log_max_size"` // Maximum size in megabytes of the log file before it gets rotated (default 100 MB)
LogMaxAge int `yaml:"log_max_age"` // MaxAge is the maximum number of days to retain old log files
LogFile string `yaml:"log_file"` // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
// File is the path to the log file. If empty, logs are written to stdout.
// If "syslog", logs are written to syslog.
File string `yaml:"log_file"`
// MaxBackups is the maximum number of old log files to retain.
//
// NOTE: MaxAge may still cause them to get deleted.
MaxBackups int `yaml:"log_max_backups"`
// MaxSize is the maximum size of the log file before it gets rotated, in
// megabytes. The default value is 100 MB.
MaxSize int `yaml:"log_max_size"`
// MaxAge is the maximum duration for retaining old log files, in days.
MaxAge int `yaml:"log_max_age"`
// Compress determines, if the rotated log files should be compressed using
// gzip.
Compress bool `yaml:"log_compress"`
// LocalTime determines, if the time used for formatting the timestamps in
// is the computer's local time.
LocalTime bool `yaml:"log_localtime"`
// Verbose determines, if verbose (aka debug) logging is enabled.
Verbose bool `yaml:"verbose"`
}
// osConfig contains OS-related configuration.
@@ -51,6 +73,13 @@ type osConfig struct {
RlimitNoFile uint64 `yaml:"rlimit_nofile"`
}
type clientsConfig struct {
// Sources defines the set of sources to fetch the runtime clients from.
Sources *clientSourcesConf `yaml:"runtime_sources"`
// Persistent are the configured clients.
Persistent []*clientObject `yaml:"persistent"`
}
// configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here
type configuration struct {
@@ -88,7 +117,7 @@ type configuration struct {
// Clients contains the YAML representations of the persistent clients.
// This field is only used for reading and writing persistent client data.
// Keep this field sorted to ensure consistent ordering.
Clients []*clientObject `yaml:"clients"`
Clients *clientsConfig `yaml:"clients"`
logSettings `yaml:",inline"`
@@ -123,8 +152,9 @@ type dnsConfig struct {
// UpstreamTimeout is the timeout for querying upstream servers.
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
// ResolveClients enables and disables resolving clients with RDNS.
ResolveClients bool `yaml:"resolve_clients"`
// PrivateNets is the set of IP networks for which the private reverse DNS
// resolver should be used.
PrivateNets []string `yaml:"private_networks"`
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
// locally-served networks should be resolved via private PTR resolvers.
@@ -179,6 +209,7 @@ var config = &configuration{
Ratelimit: 20,
RefuseAny: true,
AllServers: false,
HandleDDR: true,
FastestTimeout: timeutil.Duration{
Duration: fastip.DefaultPingWaitTimeout,
},
@@ -194,7 +225,6 @@ var config = &configuration{
FilteringEnabled: true, // whether or not use filter lists
FiltersUpdateIntervalHours: 24,
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
ResolveClients: true,
UsePrivateRDNS: true,
},
TLS: tlsConfigSettings{
@@ -205,12 +235,21 @@ var config = &configuration{
DHCP: &dhcpd.ServerConfig{
LocalDomainName: "lan",
},
Clients: &clientsConfig{
Sources: &clientSourcesConf{
WHOIS: true,
ARP: true,
RDNS: true,
DHCP: true,
HostsFile: true,
},
},
logSettings: logSettings{
LogCompress: false,
LogLocalTime: false,
LogMaxBackups: 0,
LogMaxSize: 100,
LogMaxAge: 3,
Compress: false,
LocalTime: false,
MaxBackups: 0,
MaxSize: 100,
MaxAge: 3,
},
OSConfig: &osConfig{},
SchemaVersion: currentSchemaVersion,
@@ -285,25 +324,28 @@ func parseConfig() (err error) {
return err
}
uc := aghalg.UniqChecker{}
addPorts(
uc,
config.BindPort,
config.BetaBindPort,
config.DNS.Port,
)
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(tcpPorts, tcpPort(config.BindPort), tcpPort(config.BetaBindPort))
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(config.DNS.Port))
if config.TLS.Enabled {
addPorts(
uc,
config.TLS.PortHTTPS,
config.TLS.PortDNSOverTLS,
config.TLS.PortDNSOverQUIC,
config.TLS.PortDNSCrypt,
tcpPorts,
tcpPort(config.TLS.PortHTTPS),
tcpPort(config.TLS.PortDNSOverTLS),
tcpPort(config.TLS.PortDNSCrypt),
)
// TODO(e.burkov): Consider adding a udpPort with the same value when
// we add support for HTTP/3 for web admin interface.
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
}
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
return fmt.Errorf("validating ports: %w", err)
if err = tcpPorts.Validate(); err != nil {
return fmt.Errorf("validating tcp ports: %w", err)
} else if err = udpPorts.Validate(); err != nil {
return fmt.Errorf("validating udp ports: %w", err)
}
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
@@ -317,8 +359,14 @@ func parseConfig() (err error) {
return nil
}
// addPorts is a helper for ports validation. It skips zero ports.
func addPorts(uc aghalg.UniqChecker, ports ...int) {
// udpPort is the port number for UDP protocol.
type udpPort int
// tcpPort is the port number for TCP protocol.
type tcpPort int
// addPorts is a helper for ports validation that skips zero ports.
func addPorts[T tcpPort | udpPort](uc aghalg.UniqChecker[T], ports ...T) {
for _, p := range ports {
if p != 0 {
uc.Add(p)
@@ -340,13 +388,14 @@ func readConfigFile() (fileData []byte, err error) {
}
// Saves configuration to the YAML file and also saves the user filter contents to a file
func (c *configuration) write() error {
func (c *configuration) write() (err error) {
c.Lock()
defer c.Unlock()
if Context.auth != nil {
config.Users = Context.auth.GetUsers()
}
if Context.tls != nil {
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
@@ -380,9 +429,7 @@ func (c *configuration) write() error {
s.WriteDiskConfig(&c)
dns := &config.DNS
dns.FilteringConfig = c
dns.LocalPTRResolvers,
dns.ResolveClients,
dns.UsePrivateRDNS = s.RDNSSettings()
dns.LocalPTRResolvers, config.Clients.Sources.RDNS, dns.UsePrivateRDNS = s.RDNSSettings()
}
if Context.dhcpServer != nil {
@@ -391,22 +438,23 @@ func (c *configuration) write() error {
config.DHCP = c
}
config.Clients = Context.clients.forConfig()
config.Clients.Persistent = Context.clients.forConfig()
configFile := config.getConfigFilename()
log.Debug("Writing YAML file: %s", configFile)
yamlText, err := yaml.Marshal(&config)
if err != nil {
log.Error("Couldn't generate YAML file: %s", err)
log.Debug("writing config file %q", configFile)
return err
buf := &bytes.Buffer{}
enc := yaml.NewEncoder(buf)
enc.SetIndent(2)
err = enc.Encode(config)
if err != nil {
return fmt.Errorf("generating config file: %w", err)
}
err = maybe.WriteFile(configFile, yamlText, 0o644)
err = maybe.WriteFile(configFile, buf.Bytes(), 0o644)
if err != nil {
log.Error("Couldn't save YAML config: %s", err)
return err
return fmt.Errorf("writing config file: %w", err)
}
return nil

View File

@@ -189,7 +189,7 @@ func registerControlHandlers() {
RegisterAuthHandlers()
}
func httpRegister(method, url string, handler func(http.ResponseWriter, *http.Request)) {
func httpRegister(method, url string, handler http.HandlerFunc) {
if method == "" {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
Context.mux.HandleFunc(url, postInstall(handler))

View File

@@ -105,19 +105,22 @@ type checkConfResp struct {
// validateWeb returns error is the web part if the initial configuration can't
// be set.
func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
port := req.Web.Port
addPorts(uc, config.BetaBindPort, port)
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
// Avoid duplicating the error into the status of DNS.
uc[port] = 1
portInt := req.Web.Port
port := tcpPort(portInt)
addPorts(tcpPorts, tcpPort(config.BetaBindPort), port)
if err = tcpPorts.Validate(); err != nil {
// Reset the value for the port to 1 to make sure that validateDNS
// doesn't throw the same error, unless the same TCP port is set there
// as well.
tcpPorts[port] = 1
return err
}
switch port {
switch portInt {
case 0, config.BindPort:
return nil
default:
@@ -125,21 +128,18 @@ func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
// unbound after install.
}
return aghnet.CheckPort("tcp", req.Web.IP, port)
return aghnet.CheckPort("tcp", req.Web.IP, portInt)
}
// validateDNS returns error if the DNS part of the initial configuration can't
// be set. canAutofix is true if the port can be unbound by AdGuard Home
// automatically.
func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, err error) {
func (req *checkConfReq) validateDNS(
tcpPorts aghalg.UniqChecker[tcpPort],
) (canAutofix bool, err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
port := req.DNS.Port
addPorts(uc, port)
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
return false, err
}
switch port {
case 0:
return false, nil
@@ -148,6 +148,11 @@ func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, er
// by AdGuard Home for web interface.
default:
// Check TCP as well.
addPorts(tcpPorts, tcpPort(port))
if err = tcpPorts.Validate(); err != nil {
return false, err
}
err = aghnet.CheckPort("tcp", req.DNS.IP, port)
if err != nil {
return false, err
@@ -185,13 +190,12 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
}
resp := &checkConfResp{}
uc := aghalg.UniqChecker{}
if err = req.validateWeb(uc); err != nil {
tcpPorts := aghalg.UniqChecker[tcpPort]{}
if err = req.validateWeb(tcpPorts); err != nil {
resp.Web.Status = err.Error()
}
if resp.DNS.CanAutofix, err = req.validateDNS(uc); err != nil {
if resp.DNS.CanAutofix, err = req.validateDNS(tcpPorts); err != nil {
resp.DNS.Status = err.Error()
} else if !req.DNS.IP.IsUnspecified() {
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
@@ -212,7 +216,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
func handleStaticIP(ip net.IP, set bool) staticIPJSON {
resp := staticIPJSON{}
interfaceName := aghnet.GetInterfaceByIP(ip)
interfaceName := aghnet.InterfaceByIP(ip)
resp.Static = "no"
if len(interfaceName) == 0 {

View File

@@ -3,14 +3,15 @@ package home
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
@@ -27,12 +28,16 @@ type temporaryError interface {
// Get the latest available version from the Internet
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
resp := &versionResponse{}
if Context.disableUpdate {
// w.Header().Set("Content-Type", "application/json")
resp.Disabled = true
_ = json.NewEncoder(w).Encode(resp)
// TODO(e.burkov): Add error handling and deal with headers.
err := json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
}
return
}
@@ -44,30 +49,48 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
if r.ContentLength != 0 {
err = json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "JSON parse: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "parsing request: %s", err)
return
}
}
err = requestVersionInfo(resp, req.Recheck)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.Error(r, w, http.StatusBadGateway, "%s", err)
return
}
err = resp.setAllowedToAutoUpdate()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
}
}
// requestVersionInfo sets the VersionInfo field of resp if it can reach the
// update server.
func requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
for i := 0; i != 3; i++ {
func() {
Context.controlLock.Lock()
defer Context.controlLock.Unlock()
resp.VersionInfo, err = Context.updater.VersionInfo(req.Recheck)
}()
resp.VersionInfo, err = Context.updater.VersionInfo(recheck)
if err != nil {
var terr temporaryError
if errors.As(err, &terr) && terr.Temporary() {
// Temporary network error. This case may happen while
// we're restarting our DNS server. Log and sleep for
// some time.
// Temporary network error. This case may happen while we're
// restarting our DNS server. Log and sleep for some time.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/934.
d := time.Duration(i) * time.Second
log.Info("temp net error: %q; sleeping for %s and retrying", err, d)
log.Info("update: temp net error: %q; sleeping for %s and retrying", err, d)
time.Sleep(d)
continue
@@ -76,29 +99,14 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
break
}
if err != nil {
vcu := Context.updater.VersionCheckURL()
// TODO(a.garipov): Figure out the purpose of %T verb.
aghhttp.Error(
r,
w,
http.StatusBadGateway,
"Couldn't get version check json from %s: %T %s\n",
vcu,
err,
err,
)
return
return fmt.Errorf("getting version info from %s: %s", vcu, err)
}
resp.confirmAutoUpdate()
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
return nil
}
// handleUpdate performs an update to the latest available version procedure.
@@ -109,7 +117,18 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
return
}
err := Context.updater.Update()
// Retain the current absolute path of the executable, since the updater is
// likely to change the position current one to the backup directory.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/4735.
execPath, err := os.Executable()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "getting path: %s", err)
return
}
err = Context.updater.Update()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
@@ -121,85 +140,88 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
f.Flush()
}
// The background context is used because the underlying functions wrap
// it with timeout and shut down the server, which handles current
// request. It also should be done in a separate goroutine due to the
// same reason.
go func() {
finishUpdate(context.Background())
}()
// The background context is used because the underlying functions wrap it
// with timeout and shut down the server, which handles current request. It
// also should be done in a separate goroutine for the same reason.
go finishUpdate(context.Background(), execPath)
}
// versionResponse is the response for /control/version.json endpoint.
type versionResponse struct {
Disabled bool `json:"disabled"`
updater.VersionInfo
Disabled bool `json:"disabled"`
}
// confirmAutoUpdate checks the real possibility of auto update.
func (vr *versionResponse) confirmAutoUpdate() {
if vr.CanAutoUpdate != nil && *vr.CanAutoUpdate {
canUpdate := true
var tlsConf *tlsConfigSettings
if runtime.GOOS != "windows" {
tlsConf = &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
}
if tlsConf != nil &&
((tlsConf.Enabled && (tlsConf.PortHTTPS < 1024 ||
tlsConf.PortDNSOverTLS < 1024 ||
tlsConf.PortDNSOverQUIC < 1024)) ||
config.BindPort < 1024 ||
config.DNS.Port < 1024) {
canUpdate, _ = aghnet.CanBindPrivilegedPorts()
}
vr.CanAutoUpdate = &canUpdate
// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
// allowed to perform an automatic update by the OS.
func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
if vr.CanAutoUpdate != aghalg.NBTrue {
return nil
}
tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
canUpdate := true
if tlsConfUsesPrivilegedPorts(tlsConf) || config.BindPort < 1024 || config.DNS.Port < 1024 {
canUpdate, err = aghnet.CanBindPrivilegedPorts()
if err != nil {
return fmt.Errorf("checking ability to bind privileged ports: %w", err)
}
}
vr.CanAutoUpdate = aghalg.BoolToNullBool(canUpdate)
return nil
}
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
// indicates that privileged ports are used.
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
}
// finishUpdate completes an update procedure.
func finishUpdate(ctx context.Context) {
log.Info("Stopping all tasks")
func finishUpdate(ctx context.Context, execPath string) {
var err error
log.Info("stopping all tasks")
cleanup(ctx)
cleanupAlways()
exeName := "AdGuardHome"
if runtime.GOOS == "windows" {
exeName = "AdGuardHome.exe"
}
curBinName := filepath.Join(Context.workDir, exeName)
if runtime.GOOS == "windows" {
if Context.runningAsService {
// Note:
// we can't restart the service via "kardianos/service" package - it kills the process first
// we can't start a new instance - Windows doesn't allow it
// NOTE: We can't restart the service via "kardianos/service"
// package, because it kills the process first we can't start a new
// instance, because Windows doesn't allow it.
//
// TODO(a.garipov): Recheck the claim above.
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
err := cmd.Start()
err = cmd.Start()
if err != nil {
log.Fatalf("exec.Command() failed: %s", err)
log.Fatalf("restarting: stopping: %s", err)
}
os.Exit(0)
}
cmd := exec.Command(curBinName, os.Args[1:]...)
log.Info("Restarting: %v", cmd.Args)
cmd := exec.Command(execPath, os.Args[1:]...)
log.Info("restarting: %q %q", execPath, os.Args[1:])
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
err = cmd.Start()
if err != nil {
log.Fatalf("exec.Command() failed: %s", err)
log.Fatalf("restarting:: %s", err)
}
os.Exit(0)
} else {
log.Info("Restarting: %v", os.Args)
err := syscall.Exec(curBinName, os.Args, os.Environ())
if err != nil {
log.Fatalf("syscall.Exec() failed: %s", err)
}
// Unreachable code
}
log.Info("restarting: %q %q", execPath, os.Args[1:])
err = syscall.Exec(execPath, os.Args, os.Environ())
if err != nil {
log.Fatalf("restarting: %s", err)
}
}

View File

@@ -17,7 +17,7 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/ameshkov/dnscrypt/v2"
yaml "gopkg.in/yaml.v2"
yaml "gopkg.in/yaml.v3"
)
// Default ports.
@@ -25,7 +25,7 @@ const (
defaultPortDNS = 53
defaultPortHTTP = 80
defaultPortHTTPS = 443
defaultPortQUIC = 784
defaultPortQUIC = 853
defaultPortTLS = 853
)
@@ -58,6 +58,7 @@ func initDNSServer() (err error) {
}
conf := querylog.Config{
Anonymizer: anonymizer,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
FindClient: Context.clients.findMultiple,
@@ -67,7 +68,6 @@ func initDNSServer() (err error) {
Enabled: config.DNS.QueryLogEnabled,
FileEnabled: config.DNS.QueryLogFileEnabled,
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
Anonymizer: anonymizer,
}
Context.queryLog = querylog.New(conf)
@@ -77,13 +77,36 @@ func initDNSServer() (err error) {
filterConf.HTTPRegister = httpRegister
Context.dnsFilter = filtering.New(&filterConf, nil)
var privateNets netutil.SubnetSet
switch len(config.DNS.PrivateNets) {
case 0:
// Use an optimized locally-served matcher.
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
case 1:
var n *net.IPNet
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = n
default:
var nets []*net.IPNet
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = netutil.SliceSubnetSet(nets)
}
p := dnsforward.DNSCreateParams{
DNSFilter: Context.dnsFilter,
Stats: Context.stats,
QueryLog: Context.queryLog,
SubnetDetector: Context.subnetDetector,
Anonymizer: anonymizer,
LocalDomain: config.DHCP.LocalDomainName,
DNSFilter: Context.dnsFilter,
Stats: Context.stats,
QueryLog: Context.queryLog,
PrivateNets: privateNets,
Anonymizer: anonymizer,
LocalDomain: config.DHCP.LocalDomainName,
}
if Context.dhcpServer != nil {
p.DHCPServer = Context.dhcpServer
@@ -112,8 +135,13 @@ func initDNSServer() (err error) {
return fmt.Errorf("dnsServer.Prepare: %w", err)
}
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
Context.whois = initWHOIS(&Context.clients)
if config.Clients.Sources.RDNS {
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
}
if config.Clients.Sources.WHOIS {
Context.whois = initWHOIS(&Context.clients)
}
Context.filters.Init()
return nil
@@ -130,10 +158,11 @@ func onDNSRequest(pctx *proxy.DNSContext) {
return
}
if config.DNS.ResolveClients && !ip.IsLoopback() {
srcs := config.Clients.Sources
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if !Context.subnetDetector.IsSpecialNetwork(ip) {
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
Context.whois.Begin(ip)
}
}
@@ -192,6 +221,10 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
newConf.TLSConfig = tlsConf.TLSConfig
newConf.TLSConfig.ServerName = tlsConf.ServerName
if tlsConf.PortHTTPS != 0 {
newConf.HTTPSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortHTTPS)
}
if tlsConf.PortDNSOverTLS != 0 {
newConf.TLSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortDNSOverTLS)
}
@@ -216,7 +249,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
newConf.FilterHandler = applyAdditionalFiltering
newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams
newConf.ResolveClients = dnsConf.ResolveClients
newConf.ResolveClients = config.Clients.Sources.RDNS
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers
newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration
@@ -301,24 +334,28 @@ func getDNSEncryption() (de dnsEncryption) {
// applyAdditionalFiltering adds additional client information and settings if
// the client has them.
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *filtering.Settings) {
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
if clientAddr == nil {
log.Debug("looking up settings for client with ip %s and clientid %q", clientIP, clientID)
if clientIP == nil {
return
}
setts.ClientIP = clientAddr
setts.ClientIP = clientIP
c, ok := Context.clients.Find(clientID)
if !ok {
c, ok = Context.clients.Find(clientAddr.String())
c, ok = Context.clients.Find(clientIP.String())
if !ok {
log.Debug("client with ip %s and clientid %q not found", clientIP, clientID)
return
}
}
log.Debug("using settings for client %s with ip %s and clientid %q", c.Name, clientAddr, clientID)
log.Debug("using settings for client %q with ip %s and clientid %q", c.Name, clientIP, clientID)
if c.UseOwnBlockedServices {
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
@@ -359,11 +396,16 @@ func startDNSServer() error {
Context.queryLog.Start()
const topClientsNumber = 100 // the number of clients to get
for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) {
if config.DNS.ResolveClients && !ip.IsLoopback() {
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
if ip == nil {
continue
}
srcs := config.Clients.Sources
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if !Context.subnetDetector.IsSpecialNetwork(ip) {
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
Context.whois.Begin(ip)
}
}
@@ -413,7 +455,12 @@ func closeDNSServer() {
}
if Context.stats != nil {
Context.stats.Close()
err := Context.stats.Close()
if err != nil {
log.Debug("closing stats: %s", err)
}
// TODO(e.burkov): Find out if it's safe.
Context.stats = nil
}

View File

@@ -47,7 +47,7 @@ type homeContext struct {
// --
clients clientsContainer // per-client-settings module
stats stats.Stats // statistics module
stats stats.Interface // statistics module
queryLog querylog.QueryLog // query log module
dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module
@@ -66,8 +66,6 @@ type homeContext struct {
updater *updater.Updater
subnetDetector *aghnet.SubnetDetector
// mux is our custom http.ServeMux.
mux *http.ServeMux
@@ -175,6 +173,11 @@ func setupContext(args options) {
os.Exit(0)
}
if !args.noEtcHosts && config.Clients.Sources.HostsFile {
err = setupHostsContainer()
fatalOnError(err)
}
}
Context.mux = http.NewServeMux()
@@ -287,33 +290,35 @@ func setupConfig(args options) (err error) {
ConfName: config.getConfigFilename(),
})
if !args.noEtcHosts {
if err = setupHostsContainer(); err != nil {
return err
}
var arpdb aghnet.ARPDB
if config.Clients.Sources.ARP {
arpdb = aghnet.NewARPDB()
}
Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts)
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb)
if args.bindPort != 0 {
uc := aghalg.UniqChecker{}
addPorts(
uc,
args.bindPort,
config.BetaBindPort,
config.DNS.Port,
)
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(tcpPorts, tcpPort(args.bindPort), tcpPort(config.BetaBindPort))
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(config.DNS.Port))
if config.TLS.Enabled {
addPorts(
uc,
config.TLS.PortHTTPS,
config.TLS.PortDNSOverTLS,
config.TLS.PortDNSOverQUIC,
config.TLS.PortDNSCrypt,
tcpPorts,
tcpPort(config.TLS.PortHTTPS),
tcpPort(config.TLS.PortDNSOverTLS),
tcpPort(config.TLS.PortDNSCrypt),
)
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
}
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
return fmt.Errorf("validating ports: %w", err)
if err = tcpPorts.Validate(); err != nil {
return fmt.Errorf("validating tcp ports: %w", err)
} else if err = udpPorts.Validate(); err != nil {
return fmt.Errorf("validating udp ports: %w", err)
}
config.BindPort = args.bindPort
@@ -390,9 +395,6 @@ func run(args options, clientBuildFS fs.FS) {
// configure log level and output
configureLogger(args)
// Go memory hacks
memoryUsage(args)
// Print the first message after logger is configured.
log.Println(version.Full())
log.Debug("current working directory is %s", Context.workDir)
@@ -469,9 +471,6 @@ func run(args options, clientBuildFS fs.FS) {
Context.web, err = initWeb(args, clientBuildFS)
fatalOnError(err)
Context.subnetDetector, err = aghnet.NewSubnetDetector()
fatalOnError(err)
if !Context.firstRun {
err = initDNSServer()
fatalOnError(err)
@@ -521,27 +520,15 @@ func StartMods() error {
func checkPermissions() {
log.Info("Checking if AdGuard Home has necessary permissions")
if runtime.GOOS == "windows" {
// On Windows we need to have admin rights to run properly
admin, _ := aghos.HaveAdminRights()
if admin {
return
}
if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil {
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
}
// We should check if AdGuard Home is able to bind to port 53
ok, err := aghnet.CanBindPort(53)
if ok {
log.Info("AdGuard Home can bind to port 53")
return
}
if errors.Is(err, os.ErrPermission) {
msg := `Permission check failed.
err := aghnet.CheckPort("tcp", net.IP{127, 0, 0, 1}, defaultPortDNS)
if err != nil {
if errors.Is(err, os.ErrPermission) {
log.Fatal(`Permission check failed.
AdGuard Home is not allowed to bind to privileged ports (for instance, port 53).
Please note, that this is crucial for a server to be able to use privileged ports.
@@ -549,16 +536,17 @@ Please note, that this is crucial for a server to be able to use privileged port
You have two options:
1. Run AdGuard Home with root privileges
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`)
}
log.Fatal(msg)
log.Info(
"AdGuard failed to bind to port 53: %s\n\n"+
"Please note, that this is crucial for a DNS server to be able to use that port.",
err,
)
}
msg := fmt.Sprintf(`AdGuard failed to bind to port 53 due to %v
Please note, that this is crucial for a DNS server to be able to use that port.`, err)
log.Info(msg)
log.Info("AdGuard Home can bind to port 53")
}
// Write PID to a file
@@ -614,17 +602,17 @@ func configureLogger(args options) {
ls.Verbose = true
}
if args.logFile != "" {
ls.LogFile = args.logFile
} else if config.LogFile != "" {
ls.LogFile = config.LogFile
ls.File = args.logFile
} else if config.File != "" {
ls.File = config.File
}
// Handle default log settings overrides
ls.LogCompress = config.LogCompress
ls.LogLocalTime = config.LogLocalTime
ls.LogMaxBackups = config.LogMaxBackups
ls.LogMaxSize = config.LogMaxSize
ls.LogMaxAge = config.LogMaxAge
ls.Compress = config.Compress
ls.LocalTime = config.LocalTime
ls.MaxBackups = config.MaxBackups
ls.MaxSize = config.MaxSize
ls.MaxAge = config.MaxAge
// log.SetLevel(log.INFO) - default
if ls.Verbose {
@@ -635,27 +623,27 @@ func configureLogger(args options) {
// happen pretty quickly.
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
if args.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if nothing
// else is configured. Otherwise, we'll simply lose the log output.
ls.LogFile = configSyslog
ls.File = configSyslog
}
// logs are written to stdout (default)
if ls.LogFile == "" {
if ls.File == "" {
return
}
if ls.LogFile == configSyslog {
if ls.File == configSyslog {
// Use syslog where it is possible and eventlog on Windows
err := aghos.ConfigureSyslog(serviceName)
if err != nil {
log.Fatalf("cannot initialize syslog: %s", err)
}
} else {
logFilePath := filepath.Join(Context.workDir, ls.LogFile)
if filepath.IsAbs(ls.LogFile) {
logFilePath = ls.LogFile
logFilePath := filepath.Join(Context.workDir, ls.File)
if filepath.IsAbs(ls.File) {
logFilePath = ls.File
}
_, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
@@ -665,11 +653,11 @@ func configureLogger(args options) {
log.SetOutput(&lumberjack.Logger{
Filename: logFilePath,
Compress: ls.LogCompress, // disabled by default
LocalTime: ls.LogLocalTime,
MaxBackups: ls.LogMaxBackups,
MaxSize: ls.LogMaxSize, // megabytes
MaxAge: ls.LogMaxAge, // days
Compress: ls.Compress, // disabled by default
LocalTime: ls.LocalTime,
MaxBackups: ls.MaxBackups,
MaxSize: ls.MaxSize, // megabytes
MaxAge: ls.MaxAge, // days
})
}
}

View File

@@ -13,6 +13,7 @@ import (
// TODO(a.garipov): Get rid of a global or generate from .twosky.json.
var allowedLanguages = stringutil.NewSet(
"ar",
"be",
"bg",
"cs",
@@ -50,7 +51,7 @@ var allowedLanguages = stringutil.NewSet(
"zh-tw",
)
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
func handleI18nCurrentLanguage(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain")
log.Printf("config.Language is %s", config.Language)
_, err := fmt.Fprintf(w, "%s\n", config.Language)
@@ -58,6 +59,7 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
msg := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}
}
@@ -69,6 +71,7 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
msg := fmt.Sprintf("failed to read request body: %s", err)
log.Println(msg)
http.Error(w, msg, http.StatusBadRequest)
return
}

View File

@@ -1,40 +0,0 @@
package home
import (
"os"
"runtime/debug"
"time"
"github.com/AdguardTeam/golibs/log"
)
// memoryUsage implements a couple of not really beautiful hacks which purpose is to
// make OS reclaim the memory freed by AdGuard Home as soon as possible.
// See this for the details on the performance hits & gains:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/2044#issuecomment-687042211
func memoryUsage(args options) {
if args.disableMemoryOptimization {
log.Info("Memory optimization is disabled")
return
}
// Makes Go allocate heap at a slower pace
// By default we keep it at 50%
debug.SetGCPercent(50)
// madvdontneed: setting madvdontneed=1 will use MADV_DONTNEED
// instead of MADV_FREE on Linux when returning memory to the
// kernel. This is less efficient, but causes RSS numbers to drop
// more quickly.
_ = os.Setenv("GODEBUG", "madvdontneed=1")
// periodically call "debug.FreeOSMemory" so
// that the OS could reclaim the free memory
go func() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
log.Debug("free os memory")
debug.FreeOSMemory()
}
}()
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
uuid "github.com/satori/go.uuid"
"github.com/google/uuid"
"howett.net/plist"
)
@@ -47,9 +47,9 @@ type payloadContent struct {
PayloadType string
PayloadIdentifier string
PayloadUUID string
PayloadDisplayName string
PayloadDescription string
PayloadUUID uuid.UUID
PayloadVersion int
}
@@ -63,18 +63,14 @@ const dnsSettingsPayloadType = "com.apple.dnsSettings.managed"
type mobileConfig struct {
PayloadDescription string
PayloadDisplayName string
PayloadIdentifier string
PayloadType string
PayloadUUID string
PayloadContent []*payloadContent
PayloadIdentifier uuid.UUID
PayloadUUID uuid.UUID
PayloadVersion int
PayloadRemovalDisallowed bool
}
func genUUIDv4() string {
return uuid.NewV4().String()
}
const (
dnsProtoHTTPS = "HTTPS"
dnsProtoTLS = "TLS"
@@ -104,23 +100,23 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
return nil, fmt.Errorf("bad dns protocol %q", proto)
}
payloadID := fmt.Sprintf("%s.%s", dnsSettingsPayloadType, genUUIDv4())
payloadID := fmt.Sprintf("%s.%s", dnsSettingsPayloadType, uuid.New())
data := &mobileConfig{
PayloadDescription: "Adds AdGuard Home to macOS Big Sur " +
"and iOS 14 or newer systems",
PayloadDescription: "Adds AdGuard Home to macOS Big Sur and iOS 14 or newer systems",
PayloadDisplayName: dspName,
PayloadIdentifier: genUUIDv4(),
PayloadType: "Configuration",
PayloadUUID: genUUIDv4(),
PayloadContent: []*payloadContent{{
DNSSettings: d,
PayloadType: dnsSettingsPayloadType,
PayloadIdentifier: payloadID,
PayloadUUID: genUUIDv4(),
PayloadDisplayName: dspName,
PayloadDescription: "Configures device to use AdGuard Home",
PayloadUUID: uuid.New(),
PayloadVersion: 1,
DNSSettings: d,
}},
PayloadIdentifier: uuid.New(),
PayloadUUID: uuid.New(),
PayloadVersion: 1,
PayloadRemovalDisallowed: false,
}

View File

@@ -7,6 +7,7 @@ import (
"strconv"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log"
)
// options passed from command-line arguments
@@ -27,10 +28,6 @@ type options struct {
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
// disableMemoryOptimization - disables memory optimization hacks
// see memoryUsage() function for the details
disableMemoryOptimization bool
glinetMode bool // Activate GL-Inet compatibility mode
// noEtcHosts flag should be provided when /etc/hosts file shouldn't be
@@ -178,10 +175,14 @@ var noCheckUpdateArg = arg{
}
var disableMemoryOptimizationArg = arg{
"Disable memory optimization.",
"Deprecated. Disable memory optimization.",
"no-mem-optimization", "",
nil, func(o options) (options, error) { o.disableMemoryOptimization = true; return o, nil }, nil,
func(o options) []string { return boolSliceOrNil(o.disableMemoryOptimization) },
nil, nil, func(_ options, _ string) (f effect, err error) {
log.Info("warning: using --no-mem-optimization flag has no effect and is deprecated")
return nil, nil
},
func(o options) []string { return nil },
}
var verboseArg = arg{
@@ -229,13 +230,19 @@ var helpArg = arg{
}
var noEtcHostsArg = arg{
description: "Do not use the OS-provided hosts.",
description: "Deprecated. Do not use the OS-provided hosts.",
longName: "no-etc-hosts",
shortName: "",
updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.noEtcHosts = true; return o, nil },
effect: nil,
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) },
effect: func(_ options, _ string) (f effect, err error) {
log.Info(
"warning: --no-etc-hosts flag is deprecated and will be removed in the future versions",
)
return nil, nil
},
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) },
}
var localFrontendArg = arg{

View File

@@ -101,9 +101,13 @@ func TestParseDisableUpdate(t *testing.T) {
assert.True(t, testParseOK(t, "--no-check-update").disableUpdate, "--no-check-update is disable update")
}
// TODO(e.burkov): Remove after v0.108.0.
func TestParseDisableMemoryOptimization(t *testing.T) {
assert.False(t, testParseOK(t).disableMemoryOptimization, "empty is not disable update")
assert.True(t, testParseOK(t, "--no-mem-optimization").disableMemoryOptimization, "--no-mem-optimization is disable update")
o, eff, err := parse("", []string{"--no-mem-optimization"})
require.NoError(t, err)
assert.Nil(t, eff)
assert.Zero(t, o)
}
func TestParseService(t *testing.T) {
@@ -127,8 +131,6 @@ func TestParseUnknown(t *testing.T) {
}
func TestSerialize(t *testing.T) {
const reportFmt = "expected %s but got %s"
testCases := []struct {
name string
opts options
@@ -173,19 +175,14 @@ func TestSerialize(t *testing.T) {
name: "glinet_mode",
opts: options{glinetMode: true},
ss: []string{"--glinet"},
}, {
name: "disable_mem_opt",
opts: options{disableMemoryOptimization: true},
ss: []string{"--no-mem-optimization"},
}, {
name: "multiple",
opts: options{
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
disableMemoryOptimization: true,
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
},
ss: []string{
"-c", "config",
@@ -193,18 +190,13 @@ func TestSerialize(t *testing.T) {
"-s", "run",
"--pidfile", "pid",
"--no-check-update",
"--no-mem-optimization",
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := serialize(tc.opts)
require.Lenf(t, result, len(tc.ss), reportFmt, tc.ss, result)
for i, r := range result {
assert.Equalf(t, tc.ss[i], r, reportFmt, tc.ss, result)
}
assert.ElementsMatch(t, tc.ss, result)
})
}
}

View File

@@ -16,18 +16,17 @@ type RDNS struct {
exchanger dnsforward.RDNSExchanger
clients *clientsContainer
// usePrivate is used to store the state of current private RDNS
// resolving settings and to react to it's changes.
// usePrivate is used to store the state of current private RDNS resolving
// settings and to react to it's changes.
usePrivate uint32
// ipCh used to pass client's IP to rDNS workerLoop.
ipCh chan net.IP
// ipCache caches the IP addresses to be resolved by rDNS. The resolved
// address stays here while it's inside clients. After leaving clients
// the address will be resolved once again. If the address couldn't be
// resolved, cache prevents further attempts to resolve it for some
// time.
// address stays here while it's inside clients. After leaving clients the
// address will be resolved once again. If the address couldn't be
// resolved, cache prevents further attempts to resolve it for some time.
ipCache cache.Cache
}
@@ -125,14 +124,12 @@ func (r *RDNS) workerLoop() {
log.Debug("rdns: resolving %q: %s", ip, err)
continue
}
if host == "" {
} else if host == "" {
continue
}
// Don't handle any errors since AddHost doesn't return non-nil
// errors for now.
// Don't handle any errors since AddHost doesn't return non-nil errors
// for now.
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
}
}

View File

@@ -3,15 +3,16 @@ package home
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
@@ -80,8 +81,10 @@ func TestRDNS_Begin(t *testing.T) {
binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix()))
rdns := &RDNS{
ipCache: ipCache,
exchanger: &rDNSExchanger{},
ipCache: ipCache,
exchanger: &rDNSExchanger{
ex: aghtest.NewErrorUpstream(),
},
clients: &clientsContainer{
list: map[string]*Client{},
idIndex: tc.cliIDIndex,
@@ -108,16 +111,22 @@ func TestRDNS_Begin(t *testing.T) {
// rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests.
type rDNSExchanger struct {
ex aghtest.Exchanger
ex upstream.Upstream
usePrivate bool
}
// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger.
func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) {
rev, err := netutil.IPToReversedAddr(ip)
if err != nil {
return "", fmt.Errorf("reversing ip: %w", err)
}
req := &dns.Msg{
Question: []dns.Question{{
Name: ip.String(),
Qtype: dns.TypePTR,
Name: dns.Fqdn(rev),
Qclass: dns.ClassINET,
Qtype: dns.TypePTR,
}},
}
@@ -146,7 +155,9 @@ func TestRDNS_ensurePrivateCache(t *testing.T) {
MaxCount: defaultRDNSCacheSize,
})
ex := &rDNSExchanger{}
ex := &rDNSExchanger{
ex: aghtest.NewErrorUpstream(),
}
rdns := &RDNS{
ipCache: ipCache,
@@ -167,15 +178,27 @@ func TestRDNS_WorkerLoop(t *testing.T) {
w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w)
locUpstream := &aghtest.Upstream{
Reverse: map[string][]string{
"192.168.1.1": {"local.domain"},
"2a00:1450:400c:c06::93": {"ipv6.domain"},
localIP := net.IP{192, 168, 1, 1}
revIPv4, err := netutil.IPToReversedAddr(localIP)
require.NoError(t, err)
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
require.NoError(t, err)
locUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "local.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = aghalg.Coalesce(
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv4, "local.domain"),
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv6, "ipv6.domain"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
return resp, nil
},
}
errUpstream := &aghtest.TestErrUpstream{
Err: errors.Error("1234"),
}
errUpstream := aghtest.NewErrorUpstream()
testCases := []struct {
ups upstream.Upstream
@@ -186,10 +209,10 @@ func TestRDNS_WorkerLoop(t *testing.T) {
ups: locUpstream,
wantLog: "",
name: "all_good",
cliIP: net.IP{192, 168, 1, 1},
cliIP: localIP,
}, {
ups: errUpstream,
wantLog: `rdns: resolving "192.168.1.2": errupstream: 1234`,
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
name: "resolve_error",
cliIP: net.IP{192, 168, 1, 2},
}, {
@@ -211,9 +234,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
ch := make(chan net.IP)
rdns := &RDNS{
exchanger: &rDNSExchanger{
ex: aghtest.Exchanger{
Ups: tc.ups,
},
ex: tc.ups,
},
clients: cc,
ipCh: ch,

View File

@@ -433,8 +433,11 @@ EnvironmentFile=-/etc/sysconfig/{{.Name}}
WantedBy=multi-user.target
`
// Note: we should keep it in sync with the template from service_sysv_linux.go file
// Use "ps | grep -v grep | grep $(get_pid)" because "ps PID" may not work on OpenWrt
// sysvScript is the source of the daemon script for SysV-based Linux systems.
// Keep as close as possible to the https://github.com/kardianos/service/blob/29f8c79c511bc18422bb99992779f96e6bc33921/service_sysv_linux.go#L187.
//
// Use ps command instead of reading the procfs since it's a more
// implementation-independent approach.
const sysvScript = `#!/bin/sh
# For RedHat and cousins:
# chkconfig: - 99 01
@@ -465,7 +468,7 @@ get_pid() {
}
is_running() {
[ -f "$pid_file" ] && ps | grep -v grep | grep $(get_pid) > /dev/null 2>&1
[ -f "$pid_file" ] && ps -p "$(get_pid)" > /dev/null 2>&1
}
case "$1" in
@@ -609,7 +612,7 @@ command_args="-P ${pidfile} -p ${pidfile_child} -T ${name} -r {{.WorkingDirector
run_rc_command "$1"
`
const openBSDScript = `#!/bin/sh
const openBSDScript = `#!/bin/ksh
#
# $OpenBSD: {{ .SvcInfo }}

View File

@@ -0,0 +1,83 @@
//go:build linux
// +build linux
package home
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/kardianos/service"
)
func chooseSystem() {
sys := service.ChosenSystem()
// By default, package service uses the SysV system if it cannot detect
// anything other, but the update-rc.d fix should not be applied on OpenWrt,
// so exclude it explicitly.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/4480 and
// https://github.com/AdguardTeam/AdGuardHome/issues/4677.
if sys.String() == "unix-systemv" && !aghos.IsOpenWrt() {
service.ChooseSystem(sysvSystem{System: sys})
}
}
// sysvSystem is a wrapper for service.System that wraps the service.Service
// while creating a new one.
//
// TODO(e.burkov): File a PR to github.com/kardianos/service.
type sysvSystem struct {
// System is expected to have an unexported type
// *service.linuxSystemService.
service.System
}
// New returns a wrapped service.Service.
func (sys sysvSystem) New(i service.Interface, c *service.Config) (s service.Service, err error) {
s, err = sys.System.New(i, c)
if err != nil {
return s, err
}
return sysvService{
Service: s,
name: c.Name,
}, nil
}
// sysvService is a wrapper for a service.Service that also calls update-rc.d in
// a proper way on installing and uninstalling.
type sysvService struct {
// Service is expected to have an unexported type *service.sysv.
service.Service
// name stores the name of the service to call updating script with it.
name string
}
// Install wraps service.Service.Install call with calling the updating script.
func (svc sysvService) Install() (err error) {
err = svc.Service.Install()
if err != nil {
// Don't wrap an error since it's informative enough as is.
return err
}
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "defaults")
// Don't wrap an error since it's informative enough as is.
return err
}
// Uninstall wraps service.Service.Uninstall call with calling the updating
// script.
func (svc sysvService) Uninstall() (err error) {
err = svc.Service.Uninstall()
if err != nil {
// Don't wrap an error since it's informative enough as is.
return err
}
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "remove")
// Don't wrap an error since it's informative enough as is.
return err
}

View File

@@ -160,7 +160,7 @@ rc_cmd $1
// template returns the script template to put into rc.d.
func (s *openbsdRunComService) template() (t *template.Template) {
tf := map[string]interface{}{
tf := map[string]any{
"args": func(sl []string) string {
return `"` + strings.Join(sl, " ") + `"`
},
@@ -314,12 +314,13 @@ func (s *openbsdRunComService) runCom(cmd string) (out string, err error) {
// TODO(e.burkov): It's possible that os.ErrNotExist is caused by
// something different than the service script's non-existence. Keep it
// in mind, when replace the aghos.RunCommand.
_, out, err = aghos.RunCommand(scriptPath, cmd)
var outData []byte
_, outData, err = aghos.RunCommand(scriptPath, cmd)
if errors.Is(err, os.ErrNotExist) {
return "", service.ErrNotInstalled
}
return out, err
return string(outData), err
}
// Status implements service.Service interface for *openbsdRunComService.
@@ -389,42 +390,42 @@ func newSysLogger(_ string, _ chan<- error) (service.Logger, error) {
type sysLogger struct{}
// Error implements service.Logger interface for sysLogger.
func (sysLogger) Error(v ...interface{}) error {
func (sysLogger) Error(v ...any) error {
log.Error(fmt.Sprint(v...))
return nil
}
// Warning implements service.Logger interface for sysLogger.
func (sysLogger) Warning(v ...interface{}) error {
func (sysLogger) Warning(v ...any) error {
log.Info("warning: %s", fmt.Sprint(v...))
return nil
}
// Info implements service.Logger interface for sysLogger.
func (sysLogger) Info(v ...interface{}) error {
func (sysLogger) Info(v ...any) error {
log.Info(fmt.Sprint(v...))
return nil
}
// Errorf implements service.Logger interface for sysLogger.
func (sysLogger) Errorf(format string, a ...interface{}) error {
func (sysLogger) Errorf(format string, a ...any) error {
log.Error(format, a...)
return nil
}
// Warningf implements service.Logger interface for sysLogger.
func (sysLogger) Warningf(format string, a ...interface{}) error {
func (sysLogger) Warningf(format string, a ...any) error {
log.Info("warning: %s", fmt.Sprintf(format, a...))
return nil
}
// Infof implements service.Logger interface for sysLogger.
func (sysLogger) Infof(format string, a ...interface{}) error {
func (sysLogger) Infof(format string, a ...any) error {
log.Info(format, a...)
return nil

View File

@@ -1,6 +1,8 @@
//go:build !openbsd
// +build !openbsd
//go:build !(openbsd || linux)
// +build !openbsd,!linux
package home
// chooseSystem checks the current system detected and substitutes it with local
// implementation if needed.
func chooseSystem() {}

View File

@@ -250,21 +250,17 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
}
if setts.Enabled {
uc := aghalg.UniqChecker{}
addPorts(
uc,
config.BindPort,
config.BetaBindPort,
config.DNS.Port,
setts.PortHTTPS,
setts.PortDNSOverTLS,
setts.PortDNSOverQUIC,
setts.PortDNSCrypt,
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.BetaBindPort),
tcpPort(setts.PortHTTPS),
tcpPort(setts.PortDNSOverTLS),
tcpPort(setts.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(setts.PortDNSOverQUIC),
)
err = uc.Validate(aghalg.IntIsBefore)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "validating ports: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
@@ -343,19 +339,15 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
}
if data.Enabled {
uc := aghalg.UniqChecker{}
addPorts(
uc,
config.BindPort,
config.BetaBindPort,
config.DNS.Port,
data.PortHTTPS,
data.PortDNSOverTLS,
data.PortDNSOverQUIC,
data.PortDNSCrypt,
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.BetaBindPort),
tcpPort(data.PortHTTPS),
tcpPort(data.PortDNSOverTLS),
tcpPort(data.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(data.PortDNSOverQUIC),
)
err = uc.Validate(aghalg.IntIsBefore)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -421,6 +413,38 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
}
}
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
// DNS protocols.
func validatePorts(
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
dnsPort, doqPort udpPort,
) (err error) {
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(
tcpPorts,
tcpPort(bindPort),
tcpPort(betaBindPort),
tcpPort(dohPort),
tcpPort(dotPort),
tcpPort(dnscryptTCPPort),
)
err = tcpPorts.Validate()
if err != nil {
return fmt.Errorf("validating tcp ports: %w", err)
}
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
err = udpPorts.Validate()
if err != nil {
return fmt.Errorf("validating udp ports: %w", err)
}
return nil
}
func verifyCertChain(data *tlsConfigStatus, certChain, serverName string) error {
log.Tracef("TLS: got certificate: %d bytes", len(certChain))

View File

@@ -1,6 +1,7 @@
package home
import (
"bytes"
"fmt"
"net/url"
"os"
@@ -17,15 +18,14 @@ import (
"github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/maybe"
"golang.org/x/crypto/bcrypt"
yaml "gopkg.in/yaml.v2"
yaml "gopkg.in/yaml.v3"
)
// currentSchemaVersion is the current schema version.
const currentSchemaVersion = 13
const currentSchemaVersion = 14
// These aliases are provided for convenience.
type (
any = interface{}
yarr = []any
yobj = map[any]any
)
@@ -86,6 +86,7 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
upgradeSchema10to11,
upgradeSchema11to12,
upgradeSchema12to13,
upgradeSchema13to14,
}
n := 0
@@ -104,16 +105,20 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
return fmt.Errorf("unknown configuration schema version %d", oldVersion)
}
body, err := yaml.Marshal(diskConf)
buf := &bytes.Buffer{}
enc := yaml.NewEncoder(buf)
enc.SetIndent(2)
err = enc.Encode(diskConf)
if err != nil {
return fmt.Errorf("generating new config: %w", err)
}
config.fileData = body
config.fileData = buf.Bytes()
confFile := config.getConfigFilename()
err = maybe.WriteFile(confFile, body, 0o644)
err = maybe.WriteFile(confFile, config.fileData, 0o644)
if err != nil {
return fmt.Errorf("saving new config: %w", err)
return fmt.Errorf("writing new config: %w", err)
}
return nil
@@ -173,11 +178,11 @@ func upgradeSchema2to3(diskConf yobj) error {
return fmt.Errorf("no DNS configuration in config file")
}
// Convert interface{} to yobj
// Convert any to yobj
newDNSConfig := make(yobj)
switch v := dnsConfig.(type) {
case map[interface{}]interface{}:
case map[any]any:
for k, v := range v {
newDNSConfig[fmt.Sprint(k)] = v
}
@@ -213,12 +218,12 @@ func upgradeSchema3to4(diskConf yobj) error {
}
switch arr := clients.(type) {
case []interface{}:
case []any:
for i := range arr {
switch c := arr[i].(type) {
case map[interface{}]interface{}:
case map[any]any:
c["use_global_blocked_services"] = true
default:
@@ -304,11 +309,11 @@ func upgradeSchema5to6(diskConf yobj) error {
}
switch arr := clients.(type) {
case []interface{}:
case []any:
for i := range arr {
switch c := arr[i].(type) {
case map[interface{}]interface{}:
var ipVal interface{}
case map[any]any:
var ipVal any
ipVal, ok = c["ip"]
ids := []string{}
if ok {
@@ -323,7 +328,7 @@ func upgradeSchema5to6(diskConf yobj) error {
}
}
var macVal interface{}
var macVal any
macVal, ok = c["mac"]
if ok {
var mac string
@@ -374,7 +379,7 @@ func upgradeSchema6to7(diskConf yobj) error {
}
switch dhcp := dhcpVal.(type) {
case map[interface{}]interface{}:
case map[any]any:
var str string
str, ok = dhcp["gateway_ip"].(string)
if !ok {
@@ -726,7 +731,7 @@ func upgradeSchema12to13(diskConf yobj) (err error) {
var dhcp yobj
dhcp, ok = dhcpVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of dhcp: %T", dnsVal)
return fmt.Errorf("unexpected type of dhcp: %T", dhcpVal)
}
const field = "local_domain_name"
@@ -737,6 +742,68 @@ func upgradeSchema12to13(diskConf yobj) (err error) {
return nil
}
// upgradeSchema13to14 performs the following changes:
//
// # BEFORE:
// 'clients':
// - 'name': 'client-name'
// # …
//
// # AFTER:
// 'clients':
// 'persistent':
// - 'name': 'client-name'
// # …
// 'runtime_sources':
// 'whois': true
// 'arp': true
// 'rdns': true
// 'dhcp': true
// 'hosts': true
//
func upgradeSchema13to14(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 13 to 14")
diskConf["schema_version"] = 14
clientsVal, ok := diskConf["clients"]
if !ok {
clientsVal = yarr{}
}
var rdnsSrc bool
if dnsVal, dok := diskConf["dns"]; dok {
var dnsSettings yobj
dnsSettings, ok = dnsVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of dns: %T", dnsVal)
}
var rdnsSrcVal any
rdnsSrcVal, ok = dnsSettings["resolve_clients"]
if ok {
rdnsSrc, ok = rdnsSrcVal.(bool)
if !ok {
return fmt.Errorf("unexpected type of resolve_clients: %T", rdnsSrcVal)
}
delete(dnsSettings, "resolve_clients")
}
}
diskConf["clients"] = yobj{
"persistent": clientsVal,
"runtime_sources": &clientSourcesConf{
WHOIS: true,
ARP: true,
RDNS: rdnsSrc,
DHCP: true,
HostsFile: true,
},
}
return nil
}
// TODO(a.garipov): Replace with log.Output when we port it to our logging
// package.
func funcName() string {

View File

@@ -190,7 +190,7 @@ func testDiskConf(schemaVersion int) (diskConf yobj) {
return diskConf
}
// testDNSConf creates a DNS config for test the way gopkg.in/yaml.v2 would
// testDNSConf creates a DNS config for test the way gopkg.in/yaml.v3 would
// unmarshal it. In YAML, keys aren't guaranteed to always only be strings.
func testDNSConf(schemaVersion int) (dnsConf yobj) {
dnsConf = yobj{
@@ -500,7 +500,7 @@ func TestUpgradeSchema11to12(t *testing.T) {
dnsVal, ok = dns.(yobj)
require.True(t, ok)
var ivl interface{}
var ivl any
ivl, ok = dnsVal["querylog_interval"]
require.True(t, ok)
@@ -513,46 +513,129 @@ func TestUpgradeSchema11to12(t *testing.T) {
}
func TestUpgradeSchema12to13(t *testing.T) {
t.Run("no_dns", func(t *testing.T) {
conf := yobj{}
const newSchemaVer = 13
err := upgradeSchema12to13(conf)
require.NoError(t, err)
assert.Equal(t, conf["schema_version"], 13)
})
t.Run("no_dhcp", func(t *testing.T) {
conf := yobj{
"dns": yobj{},
}
err := upgradeSchema12to13(conf)
require.NoError(t, err)
assert.Equal(t, conf["schema_version"], 13)
})
t.Run("good", func(t *testing.T) {
conf := yobj{
testCases := []struct {
in yobj
want yobj
name string
}{{
in: yobj{},
want: yobj{"schema_version": newSchemaVer},
name: "no_dns",
}, {
in: yobj{"dns": yobj{}},
want: yobj{
"dns": yobj{},
"schema_version": newSchemaVer,
},
name: "no_dhcp",
}, {
in: yobj{
"dns": yobj{
"local_domain_name": "lan",
},
"dhcp": yobj{},
"schema_version": 12,
}
wantConf := yobj{
"schema_version": newSchemaVer - 1,
},
want: yobj{
"dns": yobj{},
"dhcp": yobj{
"local_domain_name": "lan",
},
"schema_version": 13,
}
"schema_version": newSchemaVer,
},
name: "good",
}}
err := upgradeSchema12to13(conf)
require.NoError(t, err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema12to13(tc.in)
require.NoError(t, err)
assert.Equal(t, wantConf, conf)
})
assert.Equal(t, tc.want, tc.in)
})
}
}
func TestUpgradeSchema13to14(t *testing.T) {
const newSchemaVer = 14
testClient := &clientObject{
Name: "agh-client",
IDs: []string{"id1"},
UseGlobalSettings: true,
}
testCases := []struct {
in yobj
want yobj
name string
}{{
in: yobj{},
want: yobj{
"schema_version": newSchemaVer,
// The clients field will be added anyway.
"clients": yobj{
"persistent": yarr{},
"runtime_sources": &clientSourcesConf{
WHOIS: true,
ARP: true,
RDNS: false,
DHCP: true,
HostsFile: true,
},
},
},
name: "no_clients",
}, {
in: yobj{
"clients": []*clientObject{testClient},
},
want: yobj{
"schema_version": newSchemaVer,
"clients": yobj{
"persistent": []*clientObject{testClient},
"runtime_sources": &clientSourcesConf{
WHOIS: true,
ARP: true,
RDNS: false,
DHCP: true,
HostsFile: true,
},
},
},
name: "no_dns",
}, {
in: yobj{
"clients": []*clientObject{testClient},
"dns": yobj{
"resolve_clients": true,
},
},
want: yobj{
"schema_version": newSchemaVer,
"clients": yobj{
"persistent": []*clientObject{testClient},
"runtime_sources": &clientSourcesConf{
WHOIS: true,
ARP: true,
RDNS: true,
DHCP: true,
HostsFile: true,
},
},
"dns": yobj{},
},
name: "good",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema13to14(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}