Merge branch 'master' into 4990-custom-ciphers
This commit is contained in:
@@ -12,16 +12,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func TestNewSessionToken(t *testing.T) {
|
||||
// Successful case.
|
||||
token, err := newSessionToken()
|
||||
|
||||
@@ -119,6 +119,8 @@ type clientsContainer struct {
|
||||
idIndex map[string]*Client // ID -> client
|
||||
|
||||
// ipToRC is the IP address to *RuntimeClient map.
|
||||
//
|
||||
// TODO(e.burkov): Use map[netip.Addr]struct{} instead.
|
||||
ipToRC *netutil.IPMap
|
||||
|
||||
lock sync.Mutex
|
||||
|
||||
@@ -2,6 +2,7 @@ package home
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
@@ -287,10 +288,10 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
DBFilePath: "leases.db",
|
||||
Conf4: dhcpd.V4ServerConf{
|
||||
Enabled: true,
|
||||
GatewayIP: net.IP{1, 2, 3, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
RangeStart: net.IP{1, 2, 3, 2},
|
||||
RangeEnd: net.IP{1, 2, 3, 10},
|
||||
GatewayIP: netip.MustParseAddr("1.2.3.1"),
|
||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
RangeStart: netip.MustParseAddr("1.2.3.2"),
|
||||
RangeEnd: netip.MustParseAddr("1.2.3.10"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package home
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
@@ -85,19 +85,28 @@ type configuration struct {
|
||||
// It's reset after config is parsed
|
||||
fileData []byte
|
||||
|
||||
BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
||||
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
||||
BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client
|
||||
Users []webUser `yaml:"users"` // Users that can access HTTP server
|
||||
// BindHost is the address for the web interface server to listen on.
|
||||
BindHost netip.Addr `yaml:"bind_host"`
|
||||
// BindPort is the port for the web interface server to listen on.
|
||||
BindPort int `yaml:"bind_port"`
|
||||
// BetaBindPort is the port for the new client's web interface server to
|
||||
// listen on.
|
||||
BetaBindPort int `yaml:"beta_bind_port"`
|
||||
|
||||
// Users are the clients capable for accessing the web interface.
|
||||
Users []webUser `yaml:"users"`
|
||||
// AuthAttempts is the maximum number of failed login attempts a user
|
||||
// can do before being blocked.
|
||||
AuthAttempts uint `yaml:"auth_attempts"`
|
||||
// AuthBlockMin is the duration, in minutes, of the block of new login
|
||||
// attempts after AuthAttempts unsuccessful login attempts.
|
||||
AuthBlockMin uint `yaml:"block_auth_min"`
|
||||
ProxyURL string `yaml:"http_proxy"` // Proxy address for our HTTP client
|
||||
Language string `yaml:"language"` // two-letter ISO 639-1 language code
|
||||
DebugPProf bool `yaml:"debug_pprof"` // Enable pprof HTTP server on port 6060
|
||||
AuthBlockMin uint `yaml:"block_auth_min"`
|
||||
// ProxyURL is the address of proxy server for the internal HTTP client.
|
||||
ProxyURL string `yaml:"http_proxy"`
|
||||
// Language is a two-letter ISO 639-1 language code.
|
||||
Language string `yaml:"language"`
|
||||
// DebugPProf defines if the profiling HTTP handler will listen on :6060.
|
||||
DebugPProf bool `yaml:"debug_pprof"`
|
||||
|
||||
// TTL for a web session (in hours)
|
||||
// An active session is automatically refreshed once a day.
|
||||
@@ -112,7 +121,7 @@ type configuration struct {
|
||||
//
|
||||
// TODO(e.burkov): Move all the filtering configuration fields into the
|
||||
// only configuration subsection covering the changes with a single
|
||||
// migration.
|
||||
// migration. Also keep the blocked services in mind.
|
||||
Filters []filtering.FilterYAML `yaml:"filters"`
|
||||
WhitelistFilters []filtering.FilterYAML `yaml:"whitelist_filters"`
|
||||
UserRules []string `yaml:"user_rules"`
|
||||
@@ -135,18 +144,26 @@ type configuration struct {
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type dnsConfig struct {
|
||||
BindHosts []net.IP `yaml:"bind_hosts"`
|
||||
Port int `yaml:"port"`
|
||||
BindHosts []netip.Addr `yaml:"bind_hosts"`
|
||||
Port int `yaml:"port"`
|
||||
|
||||
// time interval for statistics (in days)
|
||||
// StatsInterval is the time interval for flushing statistics to the disk in
|
||||
// days.
|
||||
StatsInterval uint32 `yaml:"statistics_interval"`
|
||||
|
||||
QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled
|
||||
QueryLogFileEnabled bool `yaml:"querylog_file_enabled"` // if true, query log will be written to a file
|
||||
// QueryLogEnabled defines if the query log is enabled.
|
||||
QueryLogEnabled bool `yaml:"querylog_enabled"`
|
||||
// QueryLogFileEnabled defines, if the query log is written to the file.
|
||||
QueryLogFileEnabled bool `yaml:"querylog_file_enabled"`
|
||||
// QueryLogInterval is the interval for query log's files rotation.
|
||||
QueryLogInterval timeutil.Duration `yaml:"querylog_interval"`
|
||||
QueryLogMemSize uint32 `yaml:"querylog_size_memory"` // number of entries kept in memory before they are flushed to disk
|
||||
AnonymizeClientIP bool `yaml:"anonymize_client_ip"` // anonymize clients' IP addresses in logs and stats
|
||||
QueryLogInterval timeutil.Duration `yaml:"querylog_interval"`
|
||||
// QueryLogMemSize is the number of entries kept in memory before they are
|
||||
// flushed to disk.
|
||||
QueryLogMemSize uint32 `yaml:"querylog_size_memory"`
|
||||
|
||||
// AnonymizeClientIP defines if clients' IP addresses should be anonymized
|
||||
// in query log and statistics.
|
||||
AnonymizeClientIP bool `yaml:"anonymize_client_ip"`
|
||||
|
||||
dnsforward.FilteringConfig `yaml:",inline"`
|
||||
|
||||
@@ -211,12 +228,12 @@ type tlsConfigSettings struct {
|
||||
var config = &configuration{
|
||||
BindPort: 3000,
|
||||
BetaBindPort: 0,
|
||||
BindHost: net.IP{0, 0, 0, 0},
|
||||
BindHost: netip.IPv4Unspecified(),
|
||||
AuthAttempts: 5,
|
||||
AuthBlockMin: 15,
|
||||
WebSessionTTLHours: 30 * 24,
|
||||
DNS: dnsConfig{
|
||||
BindHosts: []net.IP{{0, 0, 0, 0}},
|
||||
BindHosts: []netip.Addr{netip.IPv4Unspecified()},
|
||||
Port: defaultPortDNS,
|
||||
StatsInterval: 1,
|
||||
QueryLogEnabled: true,
|
||||
@@ -236,6 +253,7 @@ var config = &configuration{
|
||||
},
|
||||
|
||||
TrustedProxies: []string{"127.0.0.0/8", "::1/128"},
|
||||
CacheSize: 4 * 1024 * 1024,
|
||||
|
||||
// set default maximum concurrent queries to 300
|
||||
// we introduced a default limit due to this:
|
||||
|
||||
@@ -2,8 +2,8 @@ package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -20,11 +20,11 @@ import (
|
||||
|
||||
// appendDNSAddrs is a convenient helper for appending a formatted form of DNS
|
||||
// addresses to a slice of strings.
|
||||
func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) {
|
||||
func appendDNSAddrs(dst []string, addrs ...netip.Addr) (res []string) {
|
||||
for _, addr := range addrs {
|
||||
var hostport string
|
||||
if config.DNS.Port != defaultPortDNS {
|
||||
hostport = netutil.JoinHostPort(addr.String(), config.DNS.Port)
|
||||
hostport = netip.AddrPortFrom(addr, uint16(config.DNS.Port)).String()
|
||||
} else {
|
||||
hostport = addr.String()
|
||||
}
|
||||
@@ -38,7 +38,7 @@ func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) {
|
||||
// appendDNSAddrsWithIfaces formats and appends all DNS addresses from src to
|
||||
// dst. It also adds the IP addresses of all network interfaces if src contains
|
||||
// an unspecified IP address.
|
||||
func appendDNSAddrsWithIfaces(dst []string, src []net.IP) (res []string, err error) {
|
||||
func appendDNSAddrsWithIfaces(dst []string, src []netip.Addr) (res []string, err error) {
|
||||
ifacesAdded := false
|
||||
for _, h := range src {
|
||||
if !h.IsUnspecified() {
|
||||
@@ -71,7 +71,9 @@ func appendDNSAddrsWithIfaces(dst []string, src []net.IP) (res []string, err err
|
||||
// on, including the addresses on all interfaces in cases of unspecified IPs.
|
||||
func collectDNSAddresses() (addrs []string, err error) {
|
||||
if hosts := config.DNS.BindHosts; len(hosts) == 0 {
|
||||
addrs = appendDNSAddrs(addrs, net.IP{127, 0, 0, 1})
|
||||
addr := aghnet.IPv4Localhost()
|
||||
|
||||
addrs = appendDNSAddrs(addrs, addr)
|
||||
} else {
|
||||
addrs, err = appendDNSAddrsWithIfaces(addrs, hosts)
|
||||
if err != nil {
|
||||
@@ -320,6 +322,28 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
var serveHTTP3 bool
|
||||
var portHTTPS int
|
||||
func() {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
|
||||
serveHTTP3, portHTTPS = config.DNS.ServeHTTP3, config.TLS.PortHTTPS
|
||||
}()
|
||||
|
||||
respHdr := w.Header()
|
||||
|
||||
// Let the browser know that server supports HTTP/3.
|
||||
//
|
||||
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Alt-Svc.
|
||||
//
|
||||
// TODO(a.garipov): Consider adding a configurable max-age. Currently, the
|
||||
// default is 24 hours.
|
||||
if serveHTTP3 {
|
||||
altSvc := fmt.Sprintf(`h3=":%d"`, portHTTPS)
|
||||
respHdr.Set(aghhttp.HdrNameAltSvc, altSvc)
|
||||
}
|
||||
|
||||
if r.TLS == nil && web.forceHTTPS {
|
||||
hostPort := host
|
||||
if port := web.conf.PortHTTPS; port != defaultPortHTTPS {
|
||||
@@ -346,8 +370,9 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
Scheme: aghhttp.SchemeHTTP,
|
||||
Host: r.Host,
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", originURL.String())
|
||||
w.Header().Set("Vary", "Origin")
|
||||
|
||||
respHdr.Set(aghhttp.HdrNameAccessControlAllowOrigin, originURL.String())
|
||||
respHdr.Set(aghhttp.HdrNameVary, aghhttp.HdrNameOrigin)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -64,9 +64,9 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
type checkConfReqEnt struct {
|
||||
IP net.IP `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Autofix bool `json:"autofix"`
|
||||
IP netip.Addr `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Autofix bool `json:"autofix"`
|
||||
}
|
||||
|
||||
type checkConfReq struct {
|
||||
@@ -117,7 +117,7 @@ func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err
|
||||
// unbound after install.
|
||||
}
|
||||
|
||||
return aghnet.CheckPort("tcp", req.Web.IP, portInt)
|
||||
return aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(portInt)))
|
||||
}
|
||||
|
||||
// validateDNS returns error if the DNS part of the initial configuration can't
|
||||
@@ -142,13 +142,13 @@ func (req *checkConfReq) validateDNS(
|
||||
return false, err
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("tcp", req.DNS.IP, port)
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", req.DNS.IP, port)
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
|
||||
if !aghnet.IsAddrInUse(err) {
|
||||
return false, err
|
||||
}
|
||||
@@ -160,7 +160,7 @@ func (req *checkConfReq) validateDNS(
|
||||
log.Error("disabling DNSStubListener: %s", err)
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", req.DNS.IP, port)
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
|
||||
canAutofix = false
|
||||
}
|
||||
|
||||
@@ -196,7 +196,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
|
||||
// handleStaticIP - handles static IP request
|
||||
// It either checks if we have a static IP
|
||||
// Or if set=true, it tries to set it
|
||||
func handleStaticIP(ip net.IP, set bool) staticIPJSON {
|
||||
func handleStaticIP(ip netip.Addr, set bool) staticIPJSON {
|
||||
resp := staticIPJSON{}
|
||||
|
||||
interfaceName := aghnet.InterfaceByIP(ip)
|
||||
@@ -304,8 +304,8 @@ func disableDNSStubListener() error {
|
||||
}
|
||||
|
||||
type applyConfigReqEnt struct {
|
||||
IP net.IP `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
IP netip.Addr `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
type applyConfigReq struct {
|
||||
@@ -397,14 +397,14 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", req.DNS.IP, req.DNS.Port)
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(req.DNS.Port)))
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("tcp", req.DNS.IP, req.DNS.Port)
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, uint16(req.DNS.Port)))
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -417,14 +417,14 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
Context.firstRun = false
|
||||
config.BindHost = req.Web.IP
|
||||
config.BindPort = req.Web.Port
|
||||
config.DNS.BindHosts = []net.IP{req.DNS.IP}
|
||||
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
|
||||
config.DNS.Port = req.DNS.Port
|
||||
|
||||
// TODO(e.burkov): StartMods() should be put in a separate goroutine at the
|
||||
// moment we'll allow setting up TLS in the initial configuration or the
|
||||
// configuration itself will use HTTPS protocol, because the underlying
|
||||
// functions potentially restart the HTTPS server.
|
||||
err = StartMods()
|
||||
err = startMods()
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
@@ -490,9 +490,9 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
|
||||
return nil, false, errors.Error("ports cannot be 0")
|
||||
}
|
||||
|
||||
restartHTTP = !config.BindHost.Equal(req.Web.IP) || config.BindPort != req.Web.Port
|
||||
restartHTTP = config.BindHost != req.Web.IP || config.BindPort != req.Web.Port
|
||||
if restartHTTP {
|
||||
err = aghnet.CheckPort("tcp", req.Web.IP, req.Web.Port)
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port)))
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf(
|
||||
"checking address %s:%d: %w",
|
||||
@@ -518,9 +518,9 @@ func (web *Web) registerInstallHandlers() {
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default checkConfigReqEnt.
|
||||
type checkConfigReqEntBeta struct {
|
||||
IP []net.IP `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Autofix bool `json:"autofix"`
|
||||
IP []netip.Addr `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Autofix bool `json:"autofix"`
|
||||
}
|
||||
|
||||
// checkConfigReqBeta is a struct representing new client's config check request
|
||||
@@ -590,8 +590,8 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default applyConfigReqEnt.
|
||||
type applyConfigReqEntBeta struct {
|
||||
IP []net.IP `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
IP []netip.Addr `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
// applyConfigReqBeta is a struct representing new client's config setting
|
||||
|
||||
@@ -3,6 +3,7 @@ package home
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -164,33 +165,27 @@ func onDNSRequest(pctx *proxy.DNSContext) {
|
||||
}
|
||||
}
|
||||
|
||||
func ipsToTCPAddrs(ips []net.IP, port int) (tcpAddrs []*net.TCPAddr) {
|
||||
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
|
||||
if ips == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tcpAddrs = make([]*net.TCPAddr, len(ips))
|
||||
for i, ip := range ips {
|
||||
tcpAddrs[i] = &net.TCPAddr{
|
||||
IP: ip,
|
||||
Port: port,
|
||||
}
|
||||
tcpAddrs = make([]*net.TCPAddr, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
tcpAddrs = append(tcpAddrs, net.TCPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16(port))))
|
||||
}
|
||||
|
||||
return tcpAddrs
|
||||
}
|
||||
|
||||
func ipsToUDPAddrs(ips []net.IP, port int) (udpAddrs []*net.UDPAddr) {
|
||||
func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
|
||||
if ips == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
udpAddrs = make([]*net.UDPAddr, len(ips))
|
||||
for i, ip := range ips {
|
||||
udpAddrs[i] = &net.UDPAddr{
|
||||
IP: ip,
|
||||
Port: port,
|
||||
}
|
||||
udpAddrs = make([]*net.UDPAddr, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
udpAddrs = append(udpAddrs, net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16(port))))
|
||||
}
|
||||
|
||||
return udpAddrs
|
||||
@@ -200,7 +195,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
dnsConf := config.DNS
|
||||
hosts := dnsConf.BindHosts
|
||||
if len(hosts) == 0 {
|
||||
hosts = []net.IP{{127, 0, 0, 1}}
|
||||
hosts = []netip.Addr{aghnet.IPv4Localhost()}
|
||||
}
|
||||
|
||||
newConf = dnsforward.ServerConfig{
|
||||
@@ -257,7 +252,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
return newConf, nil
|
||||
}
|
||||
|
||||
func newDNSCrypt(hosts []net.IP, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) {
|
||||
func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) {
|
||||
if tlsConf.DNSCryptConfigFile == "" {
|
||||
return dnscc, errors.Error("no dnscrypt_config_file")
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -58,7 +59,7 @@ type homeContext struct {
|
||||
auth *Auth // HTTP authentication module
|
||||
filters *filtering.DNSFilter // DNS filtering module
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
tls *TLSMod // TLS module
|
||||
tls *tlsManager // TLS module
|
||||
// etcHosts is an IP-hostname pairs set taken from system configuration
|
||||
// (e.g. /etc/hosts) files.
|
||||
etcHosts *aghnet.HostsContainer
|
||||
@@ -97,9 +98,15 @@ var Context homeContext
|
||||
|
||||
// Main is the entry point
|
||||
func Main(clientBuildFS fs.FS) {
|
||||
// config can be specified, which reads options from there, but other command line flags have to override config values
|
||||
// therefore, we must do it manually instead of using a lib
|
||||
args := loadOptions()
|
||||
initCmdLineOpts()
|
||||
|
||||
// The configuration file path can be overridden, but other command-line
|
||||
// options have to override config values. Therefore, do it manually
|
||||
// instead of using package flag.
|
||||
//
|
||||
// TODO(a.garipov): The comment above is most likely false. Replace with
|
||||
// package flag.
|
||||
opts := loadCmdLineOpts()
|
||||
|
||||
Context.appSignalChannel = make(chan os.Signal)
|
||||
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
|
||||
@@ -110,7 +117,7 @@ func Main(clientBuildFS fs.FS) {
|
||||
switch sig {
|
||||
case syscall.SIGHUP:
|
||||
Context.clients.Reload()
|
||||
Context.tls.Reload()
|
||||
Context.tls.reload()
|
||||
|
||||
default:
|
||||
cleanup(context.Background())
|
||||
@@ -120,26 +127,18 @@ func Main(clientBuildFS fs.FS) {
|
||||
}
|
||||
}()
|
||||
|
||||
if args.serviceControlAction != "" {
|
||||
handleServiceControlAction(args, clientBuildFS)
|
||||
if opts.serviceControlAction != "" {
|
||||
handleServiceControlAction(opts, clientBuildFS)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// run the protection
|
||||
run(args, clientBuildFS)
|
||||
run(opts, clientBuildFS)
|
||||
}
|
||||
|
||||
func setupContext(args options) {
|
||||
Context.runningAsService = args.runningAsService
|
||||
Context.disableUpdate = args.disableUpdate ||
|
||||
version.Channel() == version.ChannelDevelopment
|
||||
|
||||
Context.firstRun = detectFirstRun()
|
||||
if Context.firstRun {
|
||||
log.Info("This is the first time AdGuard Home is launched")
|
||||
checkPermissions()
|
||||
}
|
||||
func setupContext(opts options) {
|
||||
setupContextFlags(opts)
|
||||
|
||||
switch version.Channel() {
|
||||
case version.ChannelEdge, version.ChannelDevelopment:
|
||||
@@ -148,7 +147,7 @@ func setupContext(args options) {
|
||||
// Go on.
|
||||
}
|
||||
|
||||
Context.tlsRoots = LoadSystemRootCAs()
|
||||
Context.tlsRoots = aghtls.SystemRootCAs()
|
||||
Context.transport = &http.Transport{
|
||||
DialContext: customDialContext,
|
||||
Proxy: getHTTPProxy,
|
||||
@@ -174,13 +173,13 @@ func setupContext(args options) {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if args.checkConfig {
|
||||
if opts.checkConfig {
|
||||
log.Info("configuration file is ok")
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if !args.noEtcHosts && config.Clients.Sources.HostsFile {
|
||||
if !opts.noEtcHosts && config.Clients.Sources.HostsFile {
|
||||
err = setupHostsContainer()
|
||||
fatalOnError(err)
|
||||
}
|
||||
@@ -189,6 +188,24 @@ func setupContext(args options) {
|
||||
Context.mux = http.NewServeMux()
|
||||
}
|
||||
|
||||
// setupContextFlags sets global flags and prints their status to the log.
|
||||
func setupContextFlags(opts options) {
|
||||
Context.firstRun = detectFirstRun()
|
||||
if Context.firstRun {
|
||||
log.Info("This is the first time AdGuard Home is launched")
|
||||
checkPermissions()
|
||||
}
|
||||
|
||||
Context.runningAsService = opts.runningAsService
|
||||
// Don't print the runningAsService flag, since that has already been done
|
||||
// in [run].
|
||||
|
||||
Context.disableUpdate = opts.disableUpdate || version.Channel() == version.ChannelDevelopment
|
||||
if Context.disableUpdate {
|
||||
log.Info("AdGuard Home updates are disabled")
|
||||
}
|
||||
}
|
||||
|
||||
// logIfUnsupported logs a formatted warning if the error is one of the
|
||||
// unsupported errors and returns nil. If err is nil, logIfUnsupported returns
|
||||
// nil. Otherwise, it returns err.
|
||||
@@ -270,7 +287,7 @@ func setupHostsContainer() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupConfig(args options) (err error) {
|
||||
func setupConfig(opts options) (err error) {
|
||||
config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts
|
||||
config.DNS.DnsfilterConf.ConfigModified = onConfigModified
|
||||
config.DNS.DnsfilterConf.HTTPRegister = httpRegister
|
||||
@@ -312,9 +329,9 @@ func setupConfig(args options) (err error) {
|
||||
|
||||
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb)
|
||||
|
||||
if args.bindPort != 0 {
|
||||
if opts.bindPort != 0 {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(args.bindPort), tcpPort(config.BetaBindPort))
|
||||
addPorts(tcpPorts, tcpPort(opts.bindPort), tcpPort(config.BetaBindPort))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
@@ -336,23 +353,23 @@ func setupConfig(args options) (err error) {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
config.BindPort = args.bindPort
|
||||
config.BindPort = opts.bindPort
|
||||
}
|
||||
|
||||
// override bind host/port from the console
|
||||
if args.bindHost != nil {
|
||||
config.BindHost = args.bindHost
|
||||
if opts.bindHost.IsValid() {
|
||||
config.BindHost = opts.bindHost
|
||||
}
|
||||
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
|
||||
Context.pidFileName = args.pidFile
|
||||
if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
|
||||
Context.pidFileName = opts.pidFile
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func initWeb(args options, clientBuildFS fs.FS) (web *Web, err error) {
|
||||
func initWeb(opts options, clientBuildFS fs.FS) (web *Web, err error) {
|
||||
var clientFS, clientBetaFS fs.FS
|
||||
if args.localFrontend {
|
||||
if opts.localFrontend {
|
||||
log.Info("warning: using local frontend files")
|
||||
|
||||
clientFS = os.DirFS("build/static")
|
||||
@@ -406,24 +423,24 @@ func fatalOnError(err error) {
|
||||
}
|
||||
|
||||
// run configures and starts AdGuard Home.
|
||||
func run(args options, clientBuildFS fs.FS) {
|
||||
func run(opts options, clientBuildFS fs.FS) {
|
||||
// configure config filename
|
||||
initConfigFilename(args)
|
||||
initConfigFilename(opts)
|
||||
|
||||
// configure working dir and config path
|
||||
initWorkingDir(args)
|
||||
initWorkingDir(opts)
|
||||
|
||||
// configure log level and output
|
||||
configureLogger(args)
|
||||
configureLogger(opts)
|
||||
|
||||
// Print the first message after logger is configured.
|
||||
log.Info(version.Full())
|
||||
log.Debug("current working directory is %s", Context.workDir)
|
||||
if args.runningAsService {
|
||||
if opts.runningAsService {
|
||||
log.Info("AdGuard Home is running as a service")
|
||||
}
|
||||
|
||||
setupContext(args)
|
||||
setupContext(opts)
|
||||
|
||||
err := configureOS(config)
|
||||
fatalOnError(err)
|
||||
@@ -433,7 +450,7 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
// but also avoid relying on automatic Go init() function
|
||||
filtering.InitModule()
|
||||
|
||||
err = setupConfig(args)
|
||||
err = setupConfig(opts)
|
||||
fatalOnError(err)
|
||||
|
||||
if !Context.firstRun {
|
||||
@@ -462,7 +479,7 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
}
|
||||
|
||||
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
|
||||
GLMode = args.glinetMode
|
||||
GLMode = opts.glinetMode
|
||||
var rateLimiter *authRateLimiter
|
||||
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
|
||||
rateLimiter = newAuthRateLimiter(
|
||||
@@ -484,19 +501,19 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
}
|
||||
config.Users = nil
|
||||
|
||||
Context.tls = tlsCreate(config.TLS)
|
||||
if Context.tls == nil {
|
||||
log.Fatalf("Can't initialize TLS module")
|
||||
Context.tls, err = newTLSManager(config.TLS)
|
||||
if err != nil {
|
||||
log.Fatalf("initializing tls: %s", err)
|
||||
}
|
||||
|
||||
Context.web, err = initWeb(args, clientBuildFS)
|
||||
Context.web, err = initWeb(opts, clientBuildFS)
|
||||
fatalOnError(err)
|
||||
|
||||
if !Context.firstRun {
|
||||
err = initDNSServer()
|
||||
fatalOnError(err)
|
||||
|
||||
Context.tls.Start()
|
||||
Context.tls.start()
|
||||
|
||||
go func() {
|
||||
serr := startDNSServer()
|
||||
@@ -520,20 +537,22 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
select {}
|
||||
}
|
||||
|
||||
// StartMods initializes and starts the DNS server after installation.
|
||||
func StartMods() error {
|
||||
// startMods initializes and starts the DNS server after installation.
|
||||
func startMods() error {
|
||||
err := initDNSServer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Context.tls.Start()
|
||||
Context.tls.start()
|
||||
|
||||
err = startDNSServer()
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -546,7 +565,7 @@ func checkPermissions() {
|
||||
}
|
||||
|
||||
// We should check if AdGuard Home is able to bind to port 53
|
||||
err := aghnet.CheckPort("tcp", net.IP{127, 0, 0, 1}, defaultPortDNS)
|
||||
err := aghnet.CheckPort("tcp", netip.AddrPortFrom(aghnet.IPv4Localhost(), defaultPortDNS))
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
log.Fatal(`Permission check failed.
|
||||
@@ -581,10 +600,10 @@ func writePIDFile(fn string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func initConfigFilename(args options) {
|
||||
func initConfigFilename(opts options) {
|
||||
// config file path can be overridden by command-line arguments:
|
||||
if args.configFilename != "" {
|
||||
Context.configFilename = args.configFilename
|
||||
if opts.confFilename != "" {
|
||||
Context.configFilename = opts.confFilename
|
||||
} else {
|
||||
// Default config file name
|
||||
Context.configFilename = "AdGuardHome.yaml"
|
||||
@@ -593,15 +612,15 @@ func initConfigFilename(args options) {
|
||||
|
||||
// initWorkingDir initializes the workDir
|
||||
// if no command-line arguments specified, we use the directory where our binary file is located
|
||||
func initWorkingDir(args options) {
|
||||
func initWorkingDir(opts options) {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if args.workDir != "" {
|
||||
if opts.workDir != "" {
|
||||
// If there is a custom config file, use it's directory as our working dir
|
||||
Context.workDir = args.workDir
|
||||
Context.workDir = opts.workDir
|
||||
} else {
|
||||
Context.workDir = filepath.Dir(execPath)
|
||||
}
|
||||
@@ -615,15 +634,15 @@ func initWorkingDir(args options) {
|
||||
}
|
||||
|
||||
// configureLogger configures logger level and output
|
||||
func configureLogger(args options) {
|
||||
func configureLogger(opts options) {
|
||||
ls := getLogSettings()
|
||||
|
||||
// command-line arguments can override config settings
|
||||
if args.verbose || config.Verbose {
|
||||
if opts.verbose || config.Verbose {
|
||||
ls.Verbose = true
|
||||
}
|
||||
if args.logFile != "" {
|
||||
ls.File = args.logFile
|
||||
if opts.logFile != "" {
|
||||
ls.File = opts.logFile
|
||||
} else if config.File != "" {
|
||||
ls.File = config.File
|
||||
}
|
||||
@@ -644,7 +663,7 @@ func configureLogger(args options) {
|
||||
// happen pretty quickly.
|
||||
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
|
||||
|
||||
if args.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
||||
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if nothing
|
||||
// else is configured. Otherwise, we'll simply lose the log output.
|
||||
ls.File = configSyslog
|
||||
@@ -717,7 +736,6 @@ func cleanup(ctx context.Context) {
|
||||
}
|
||||
|
||||
if Context.tls != nil {
|
||||
Context.tls.Close()
|
||||
Context.tls = nil
|
||||
}
|
||||
}
|
||||
@@ -727,32 +745,37 @@ func cleanupAlways() {
|
||||
if len(Context.pidFileName) != 0 {
|
||||
_ = os.Remove(Context.pidFileName)
|
||||
}
|
||||
log.Info("Stopped")
|
||||
|
||||
log.Info("stopped")
|
||||
}
|
||||
|
||||
func exitWithError() {
|
||||
os.Exit(64)
|
||||
}
|
||||
|
||||
// loadOptions reads command line arguments and initializes configuration
|
||||
func loadOptions() options {
|
||||
o, f, err := parse(os.Args[0], os.Args[1:])
|
||||
|
||||
// loadCmdLineOpts reads command line arguments and initializes configuration
|
||||
// from them. If there is an error or an effect, loadCmdLineOpts processes them
|
||||
// and exits.
|
||||
func loadCmdLineOpts() (opts options) {
|
||||
opts, eff, err := parseCmdOpts(os.Args[0], os.Args[1:])
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
_ = printHelp(os.Args[0])
|
||||
printHelp(os.Args[0])
|
||||
|
||||
exitWithError()
|
||||
} else if f != nil {
|
||||
err = f()
|
||||
}
|
||||
|
||||
if eff != nil {
|
||||
err = eff()
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
exitWithError()
|
||||
} else {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
return o
|
||||
return opts
|
||||
}
|
||||
|
||||
// printWebAddrs prints addresses built from proto, addr, and an appropriate
|
||||
@@ -901,6 +924,6 @@ func getTLSCiphers() (cipherIds []uint16, err error) {
|
||||
return aghtls.SaferCipherSuites(), nil
|
||||
} else {
|
||||
log.Info("Overriding TLS Ciphers : %s", config.TLS.OverrideTLSCiphers)
|
||||
return aghtls.ParseCipherIDs(config.TLS.OverrideTLSCiphers)
|
||||
return aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
|
||||
}
|
||||
}
|
||||
|
||||
12
internal/home/home_test.go
Normal file
12
internal/home/home_test.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
initCmdLineOpts()
|
||||
}
|
||||
@@ -3,12 +3,11 @@ package home
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"howett.net/plist"
|
||||
@@ -28,12 +27,12 @@ func setupDNSIPs(t testing.TB) {
|
||||
|
||||
config = &configuration{
|
||||
DNS: dnsConfig{
|
||||
BindHosts: []net.IP{netutil.IPv4Zero()},
|
||||
BindHosts: []netip.Addr{netip.IPv4Unspecified()},
|
||||
Port: defaultPortDNS,
|
||||
},
|
||||
}
|
||||
|
||||
Context.tls = &TLSMod{}
|
||||
Context.tls = &tlsManager{}
|
||||
}
|
||||
|
||||
func TestHandleMobileConfigDoH(t *testing.T) {
|
||||
@@ -66,7 +65,7 @@ func TestHandleMobileConfigDoH(t *testing.T) {
|
||||
oldTLSConf := Context.tls
|
||||
t.Cleanup(func() { Context.tls = oldTLSConf })
|
||||
|
||||
Context.tls = &TLSMod{conf: tlsConfigSettings{}}
|
||||
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
|
||||
require.NoError(t, err)
|
||||
@@ -138,7 +137,7 @@ func TestHandleMobileConfigDoT(t *testing.T) {
|
||||
oldTLSConf := Context.tls
|
||||
t.Cleanup(func() { Context.tls = oldTLSConf })
|
||||
|
||||
Context.tls = &TLSMod{conf: tlsConfigSettings{}}
|
||||
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2,33 +2,63 @@ package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// options passed from command-line arguments
|
||||
type options struct {
|
||||
verbose bool // is verbose logging enabled
|
||||
configFilename string // path to the config file
|
||||
workDir string // path to the working directory where we will store the filters data and the querylog
|
||||
bindHost net.IP // host address to bind HTTP server on
|
||||
bindPort int // port to serve HTTP pages on
|
||||
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
|
||||
pidFile string // File name to save PID to
|
||||
checkConfig bool // Check configuration and exit
|
||||
disableUpdate bool // If set, don't check for updates
|
||||
// TODO(a.garipov): Replace with package flag.
|
||||
|
||||
// service control action (see service.ControlAction array + "status" command)
|
||||
// options represents the command-line options.
|
||||
type options struct {
|
||||
// confFilename is the path to the configuration file.
|
||||
confFilename string
|
||||
|
||||
// workDir is the path to the working directory where AdGuard Home stores
|
||||
// filter data, the query log, and other data.
|
||||
workDir string
|
||||
|
||||
// logFile is the path to the log file. If empty, AdGuard Home writes to
|
||||
// stdout; if "syslog", to syslog.
|
||||
logFile string
|
||||
|
||||
// pidFile is the file name for the file to which the PID is saved.
|
||||
pidFile string
|
||||
|
||||
// serviceControlAction is the service action to perform. See
|
||||
// [service.ControlAction] and [handleServiceControlAction].
|
||||
serviceControlAction string
|
||||
|
||||
// runningAsService flag is set to true when options are passed from the service runner
|
||||
// bindHost is the address on which to serve the HTTP UI.
|
||||
bindHost netip.Addr
|
||||
|
||||
// bindPort is the port on which to serve the HTTP UI.
|
||||
bindPort int
|
||||
|
||||
// checkConfig is true if the current invocation is only required to check
|
||||
// the configuration file and exit.
|
||||
checkConfig bool
|
||||
|
||||
// disableUpdate, if set, makes AdGuard Home not check for updates.
|
||||
disableUpdate bool
|
||||
|
||||
// verbose shows if verbose logging is enabled.
|
||||
verbose bool
|
||||
|
||||
// runningAsService flag is set to true when options are passed from the
|
||||
// service runner
|
||||
//
|
||||
// TODO(a.garipov): Perhaps this could be determined by a non-empty
|
||||
// serviceControlAction?
|
||||
runningAsService bool
|
||||
|
||||
glinetMode bool // Activate GL-Inet compatibility mode
|
||||
// glinetMode shows if the GL-Inet compatibility mode is enabled.
|
||||
glinetMode bool
|
||||
|
||||
// noEtcHosts flag should be provided when /etc/hosts file shouldn't be
|
||||
// used.
|
||||
@@ -39,88 +69,86 @@ type options struct {
|
||||
localFrontend bool
|
||||
}
|
||||
|
||||
// functions used for their side-effects
|
||||
type effect func() error
|
||||
|
||||
type arg struct {
|
||||
description string // a short, English description of the argument
|
||||
longName string // the name of the argument used after '--'
|
||||
shortName string // the name of the argument used after '-'
|
||||
|
||||
// only one of updateWithValue, updateNoValue, and effect should be present
|
||||
|
||||
updateWithValue func(o options, v string) (options, error) // the mutator for arguments with parameters
|
||||
updateNoValue func(o options) (options, error) // the mutator for arguments without parameters
|
||||
effect func(o options, exec string) (f effect, err error) // the side-effect closure generator
|
||||
|
||||
serialize func(o options) []string // the re-serialization function back to arguments (return nil for omit)
|
||||
// initCmdLineOpts completes initialization of the global command-line option
|
||||
// slice. It must only be called once.
|
||||
func initCmdLineOpts() {
|
||||
// The --help option cannot be put directly into cmdLineOpts, because that
|
||||
// causes initialization cycle due to printHelp referencing cmdLineOpts.
|
||||
cmdLineOpts = append(cmdLineOpts, cmdLineOpt{
|
||||
updateWithValue: nil,
|
||||
updateNoValue: nil,
|
||||
effect: func(o options, exec string) (effect, error) {
|
||||
return func() error { printHelp(exec); exitWithError(); return nil }, nil
|
||||
},
|
||||
serialize: func(o options) (val string, ok bool) { return "", false },
|
||||
description: "Print this help.",
|
||||
longName: "help",
|
||||
shortName: "",
|
||||
})
|
||||
}
|
||||
|
||||
// {type}SliceOrNil functions check their parameter of type {type}
|
||||
// against its zero value and return nil if the parameter value is
|
||||
// zero otherwise they return a string slice of the parameter
|
||||
// effect is the type for functions used for their side-effects.
|
||||
type effect func() (err error)
|
||||
|
||||
func ipSliceOrNil(ip net.IP) []string {
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
// cmdLineOpt contains the data for a single command-line option. Only one of
|
||||
// updateWithValue, updateNoValue, and effect must be present.
|
||||
type cmdLineOpt struct {
|
||||
updateWithValue func(o options, v string) (updated options, err error)
|
||||
updateNoValue func(o options) (updated options, err error)
|
||||
effect func(o options, exec string) (eff effect, err error)
|
||||
|
||||
return []string{ip.String()}
|
||||
// serialize is a function that encodes the option into a slice of
|
||||
// command-line arguments, if necessary. If ok is false, this option should
|
||||
// be skipped.
|
||||
serialize func(o options) (val string, ok bool)
|
||||
|
||||
description string
|
||||
longName string
|
||||
shortName string
|
||||
}
|
||||
|
||||
func stringSliceOrNil(s string) []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
// cmdLineOpts are all command-line options of AdGuard Home.
|
||||
var cmdLineOpts = []cmdLineOpt{{
|
||||
updateWithValue: func(o options, v string) (options, error) {
|
||||
o.confFilename = v
|
||||
return o, nil
|
||||
},
|
||||
updateNoValue: nil,
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) {
|
||||
return o.confFilename, o.confFilename != ""
|
||||
},
|
||||
description: "Path to the config file.",
|
||||
longName: "config",
|
||||
shortName: "c",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (options, error) { o.workDir = v; return o, nil },
|
||||
updateNoValue: nil,
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) { return o.workDir, o.workDir != "" },
|
||||
description: "Path to the working directory.",
|
||||
longName: "work-dir",
|
||||
shortName: "w",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (oo options, err error) {
|
||||
o.bindHost, err = netip.ParseAddr(v)
|
||||
|
||||
return []string{s}
|
||||
}
|
||||
return o, err
|
||||
},
|
||||
updateNoValue: nil,
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) {
|
||||
if !o.bindHost.IsValid() {
|
||||
return "", false
|
||||
}
|
||||
|
||||
func intSliceOrNil(i int) []string {
|
||||
if i == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []string{strconv.Itoa(i)}
|
||||
}
|
||||
|
||||
func boolSliceOrNil(b bool) []string {
|
||||
if b {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var args []arg
|
||||
|
||||
var configArg = arg{
|
||||
"Path to the config file.",
|
||||
"config", "c",
|
||||
func(o options, v string) (options, error) { o.configFilename = v; return o, nil },
|
||||
nil,
|
||||
nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.configFilename) },
|
||||
}
|
||||
|
||||
var workDirArg = arg{
|
||||
"Path to the working directory.",
|
||||
"work-dir", "w",
|
||||
func(o options, v string) (options, error) { o.workDir = v; return o, nil }, nil, nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.workDir) },
|
||||
}
|
||||
|
||||
var hostArg = arg{
|
||||
"Host address to bind HTTP server on.",
|
||||
"host", "h",
|
||||
func(o options, v string) (options, error) { o.bindHost = net.ParseIP(v); return o, nil }, nil, nil,
|
||||
func(o options) []string { return ipSliceOrNil(o.bindHost) },
|
||||
}
|
||||
|
||||
var portArg = arg{
|
||||
"Port to serve HTTP pages on.",
|
||||
"port", "p",
|
||||
func(o options, v string) (options, error) {
|
||||
return o.bindHost.String(), true
|
||||
},
|
||||
description: "Host address to bind HTTP server on.",
|
||||
longName: "host",
|
||||
shortName: "h",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (options, error) {
|
||||
var err error
|
||||
var p int
|
||||
minPort, maxPort := 0, 1<<16-1
|
||||
@@ -131,108 +159,81 @@ var portArg = arg{
|
||||
} else {
|
||||
o.bindPort = p
|
||||
}
|
||||
return o, err
|
||||
}, nil, nil,
|
||||
func(o options) []string { return intSliceOrNil(o.bindPort) },
|
||||
}
|
||||
|
||||
var serviceArg = arg{
|
||||
"Service control action: status, install, uninstall, start, stop, restart, reload (configuration).",
|
||||
"service", "s",
|
||||
func(o options, v string) (options, error) {
|
||||
return o, err
|
||||
},
|
||||
updateNoValue: nil,
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) {
|
||||
if o.bindPort == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return strconv.Itoa(o.bindPort), true
|
||||
},
|
||||
description: "Port to serve HTTP pages on.",
|
||||
longName: "port",
|
||||
shortName: "p",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (options, error) {
|
||||
o.serviceControlAction = v
|
||||
return o, nil
|
||||
}, nil, nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.serviceControlAction) },
|
||||
}
|
||||
|
||||
var logfileArg = arg{
|
||||
"Path to log file. If empty: write to stdout; if 'syslog': write to system log.",
|
||||
"logfile", "l",
|
||||
func(o options, v string) (options, error) { o.logFile = v; return o, nil }, nil, nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.logFile) },
|
||||
}
|
||||
|
||||
var pidfileArg = arg{
|
||||
"Path to a file where PID is stored.",
|
||||
"pidfile", "",
|
||||
func(o options, v string) (options, error) { o.pidFile = v; return o, nil }, nil, nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.pidFile) },
|
||||
}
|
||||
|
||||
var checkConfigArg = arg{
|
||||
"Check configuration and exit.",
|
||||
"check-config", "",
|
||||
nil, func(o options) (options, error) { o.checkConfig = true; return o, nil }, nil,
|
||||
func(o options) []string { return boolSliceOrNil(o.checkConfig) },
|
||||
}
|
||||
|
||||
var noCheckUpdateArg = arg{
|
||||
"Don't check for updates.",
|
||||
"no-check-update", "",
|
||||
nil, func(o options) (options, error) { o.disableUpdate = true; return o, nil }, nil,
|
||||
func(o options) []string { return boolSliceOrNil(o.disableUpdate) },
|
||||
}
|
||||
|
||||
var disableMemoryOptimizationArg = arg{
|
||||
"Deprecated. Disable memory optimization.",
|
||||
"no-mem-optimization", "",
|
||||
nil, nil, func(_ options, _ string) (f effect, err error) {
|
||||
},
|
||||
updateNoValue: nil,
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) {
|
||||
return o.serviceControlAction, o.serviceControlAction != ""
|
||||
},
|
||||
description: `Service control action: status, install (as a service), ` +
|
||||
`uninstall (as a service), start, stop, restart, reload (configuration).`,
|
||||
longName: "service",
|
||||
shortName: "s",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (options, error) { o.logFile = v; return o, nil },
|
||||
updateNoValue: nil,
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) { return o.logFile, o.logFile != "" },
|
||||
description: `Path to log file. If empty, write to stdout; ` +
|
||||
`if "syslog", write to system log.`,
|
||||
longName: "logfile",
|
||||
shortName: "l",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (options, error) { o.pidFile = v; return o, nil },
|
||||
updateNoValue: nil,
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) { return o.pidFile, o.pidFile != "" },
|
||||
description: "Path to a file where PID is stored.",
|
||||
longName: "pidfile",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: func(o options) (options, error) { o.checkConfig = true; return o, nil },
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) { return "", o.checkConfig },
|
||||
description: "Check configuration and exit.",
|
||||
longName: "check-config",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: func(o options) (options, error) { o.disableUpdate = true; return o, nil },
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) { return "", o.disableUpdate },
|
||||
description: "Don't check for updates.",
|
||||
longName: "no-check-update",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: nil,
|
||||
effect: func(_ options, _ string) (f effect, err error) {
|
||||
log.Info("warning: using --no-mem-optimization flag has no effect and is deprecated")
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
func(o options) []string { return nil },
|
||||
}
|
||||
|
||||
var verboseArg = arg{
|
||||
"Enable verbose output.",
|
||||
"verbose", "v",
|
||||
nil, func(o options) (options, error) { o.verbose = true; return o, nil }, nil,
|
||||
func(o options) []string { return boolSliceOrNil(o.verbose) },
|
||||
}
|
||||
|
||||
var glinetArg = arg{
|
||||
"Run in GL-Inet compatibility mode.",
|
||||
"glinet", "",
|
||||
nil, func(o options) (options, error) { o.glinetMode = true; return o, nil }, nil,
|
||||
func(o options) []string { return boolSliceOrNil(o.glinetMode) },
|
||||
}
|
||||
|
||||
var versionArg = arg{
|
||||
description: "Show the version and exit. Show more detailed version description with -v.",
|
||||
longName: "version",
|
||||
shortName: "",
|
||||
updateWithValue: nil,
|
||||
updateNoValue: nil,
|
||||
effect: func(o options, exec string) (effect, error) {
|
||||
return func() error {
|
||||
if o.verbose {
|
||||
fmt.Println(version.Verbose())
|
||||
} else {
|
||||
fmt.Println(version.Full())
|
||||
}
|
||||
os.Exit(0)
|
||||
|
||||
return nil
|
||||
}, nil
|
||||
},
|
||||
serialize: func(o options) []string { return nil },
|
||||
}
|
||||
|
||||
var helpArg = arg{
|
||||
"Print this help.",
|
||||
"help", "",
|
||||
nil, nil, func(o options, exec string) (effect, error) {
|
||||
return func() error { _ = printHelp(exec); os.Exit(64); return nil }, nil
|
||||
},
|
||||
func(o options) []string { return nil },
|
||||
}
|
||||
|
||||
var noEtcHostsArg = arg{
|
||||
description: "Deprecated. Do not use the OS-provided hosts.",
|
||||
longName: "no-etc-hosts",
|
||||
shortName: "",
|
||||
serialize: func(o options) (val string, ok bool) { return "", false },
|
||||
description: "Deprecated. Disable memory optimization.",
|
||||
longName: "no-mem-optimization",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: func(o options) (options, error) { o.noEtcHosts = true; return o, nil },
|
||||
effect: func(_ options, _ string) (f effect, err error) {
|
||||
@@ -242,146 +243,216 @@ var noEtcHostsArg = arg{
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) },
|
||||
}
|
||||
|
||||
var localFrontendArg = arg{
|
||||
description: "Use local frontend directories.",
|
||||
longName: "local-frontend",
|
||||
shortName: "",
|
||||
serialize: func(o options) (val string, ok bool) { return "", o.noEtcHosts },
|
||||
description: "Deprecated. Do not use the OS-provided hosts.",
|
||||
longName: "no-etc-hosts",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: func(o options) (options, error) { o.localFrontend = true; return o, nil },
|
||||
effect: nil,
|
||||
serialize: func(o options) []string { return boolSliceOrNil(o.localFrontend) },
|
||||
}
|
||||
serialize: func(o options) (val string, ok bool) { return "", o.localFrontend },
|
||||
description: "Use local frontend directories.",
|
||||
longName: "local-frontend",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: func(o options) (options, error) { o.verbose = true; return o, nil },
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) { return "", o.verbose },
|
||||
description: "Enable verbose output.",
|
||||
longName: "verbose",
|
||||
shortName: "v",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: func(o options) (options, error) { o.glinetMode = true; return o, nil },
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) { return "", o.glinetMode },
|
||||
description: "Run in GL-Inet compatibility mode.",
|
||||
longName: "glinet",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: nil,
|
||||
effect: func(o options, exec string) (effect, error) {
|
||||
return func() error {
|
||||
if o.verbose {
|
||||
fmt.Println(version.Verbose())
|
||||
} else {
|
||||
fmt.Println(version.Full())
|
||||
}
|
||||
|
||||
func init() {
|
||||
args = []arg{
|
||||
configArg,
|
||||
workDirArg,
|
||||
hostArg,
|
||||
portArg,
|
||||
serviceArg,
|
||||
logfileArg,
|
||||
pidfileArg,
|
||||
checkConfigArg,
|
||||
noCheckUpdateArg,
|
||||
disableMemoryOptimizationArg,
|
||||
noEtcHostsArg,
|
||||
localFrontendArg,
|
||||
verboseArg,
|
||||
glinetArg,
|
||||
versionArg,
|
||||
helpArg,
|
||||
}
|
||||
}
|
||||
os.Exit(0)
|
||||
|
||||
func getUsageLines(exec string, args []arg) []string {
|
||||
usage := []string{
|
||||
"Usage:",
|
||||
"",
|
||||
fmt.Sprintf("%s [options]", exec),
|
||||
"",
|
||||
"Options:",
|
||||
}
|
||||
for _, arg := range args {
|
||||
return nil
|
||||
}, nil
|
||||
},
|
||||
serialize: func(o options) (val string, ok bool) { return "", false },
|
||||
description: "Show the version and exit. Show more detailed version description with -v.",
|
||||
longName: "version",
|
||||
shortName: "",
|
||||
}}
|
||||
|
||||
// printHelp prints the entire help message. It exits with an error code if
|
||||
// there are any I/O errors.
|
||||
func printHelp(exec string) {
|
||||
b := &strings.Builder{}
|
||||
|
||||
stringutil.WriteToBuilder(
|
||||
b,
|
||||
"Usage:\n\n",
|
||||
fmt.Sprintf("%s [options]\n\n", exec),
|
||||
"Options:\n",
|
||||
)
|
||||
|
||||
var err error
|
||||
for _, opt := range cmdLineOpts {
|
||||
val := ""
|
||||
if arg.updateWithValue != nil {
|
||||
if opt.updateWithValue != nil {
|
||||
val = " VALUE"
|
||||
}
|
||||
if arg.shortName != "" {
|
||||
usage = append(usage, fmt.Sprintf(" -%s, %-30s %s",
|
||||
arg.shortName,
|
||||
"--"+arg.longName+val,
|
||||
arg.description))
|
||||
|
||||
longDesc := opt.longName + val
|
||||
if opt.shortName != "" {
|
||||
_, err = fmt.Fprintf(b, " -%s, --%-28s %s\n", opt.shortName, longDesc, opt.description)
|
||||
} else {
|
||||
usage = append(usage, fmt.Sprintf(" %-34s %s",
|
||||
"--"+arg.longName+val,
|
||||
arg.description))
|
||||
_, err = fmt.Fprintf(b, " --%-32s %s\n", longDesc, opt.description)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// The only error here can be from incorrect Fprintf usage, which is
|
||||
// a programmer error.
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return usage
|
||||
|
||||
_, err = fmt.Print(b)
|
||||
if err != nil {
|
||||
// Exit immediately, since not being able to print out a help message
|
||||
// essentially means that the I/O is very broken at the moment.
|
||||
exitWithError()
|
||||
}
|
||||
}
|
||||
|
||||
func printHelp(exec string) error {
|
||||
for _, line := range getUsageLines(exec, args) {
|
||||
_, err := fmt.Println(line)
|
||||
// parseCmdOpts parses the command-line arguments into options and effects.
|
||||
func parseCmdOpts(cmdName string, args []string) (o options, eff effect, err error) {
|
||||
// Don't use range since the loop changes the loop variable.
|
||||
argsLen := len(args)
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
isKnown := false
|
||||
for _, opt := range cmdLineOpts {
|
||||
isKnown = argMatches(opt, arg)
|
||||
if !isKnown {
|
||||
continue
|
||||
}
|
||||
|
||||
if opt.updateWithValue != nil {
|
||||
i++
|
||||
if i >= argsLen {
|
||||
return o, eff, fmt.Errorf("got %s without argument", arg)
|
||||
}
|
||||
|
||||
o, err = opt.updateWithValue(o, args[i])
|
||||
} else {
|
||||
o, eff, err = updateOptsNoValue(o, eff, opt, cmdName)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return o, eff, fmt.Errorf("applying option %s: %w", arg, err)
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if !isKnown {
|
||||
return o, eff, fmt.Errorf("unknown option %s", arg)
|
||||
}
|
||||
}
|
||||
|
||||
return o, eff, err
|
||||
}
|
||||
|
||||
// argMatches returns true if arg matches command-line option opt.
|
||||
func argMatches(opt cmdLineOpt, arg string) (ok bool) {
|
||||
if arg == "" || arg[0] != '-' {
|
||||
return false
|
||||
}
|
||||
|
||||
arg = arg[1:]
|
||||
if arg == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return (opt.shortName != "" && arg == opt.shortName) ||
|
||||
(arg[0] == '-' && arg[1:] == opt.longName)
|
||||
}
|
||||
|
||||
// updateOptsNoValue sets values or effects from opt into o or prev.
|
||||
func updateOptsNoValue(
|
||||
o options,
|
||||
prev effect,
|
||||
opt cmdLineOpt,
|
||||
cmdName string,
|
||||
) (updated options, chained effect, err error) {
|
||||
if opt.updateNoValue != nil {
|
||||
o, err = opt.updateNoValue(o)
|
||||
if err != nil {
|
||||
return o, prev, err
|
||||
}
|
||||
|
||||
return o, prev, nil
|
||||
}
|
||||
|
||||
next, err := opt.effect(o, cmdName)
|
||||
if err != nil {
|
||||
return o, prev, err
|
||||
}
|
||||
|
||||
chained = chainEffect(prev, next)
|
||||
|
||||
return o, chained, nil
|
||||
}
|
||||
|
||||
// chainEffect chans the next effect after the prev one. If prev is nil, eff
|
||||
// only calls next. If next is nil, eff is prev; if prev is nil, eff is next.
|
||||
func chainEffect(prev, next effect) (eff effect) {
|
||||
if prev == nil {
|
||||
return next
|
||||
} else if next == nil {
|
||||
return prev
|
||||
}
|
||||
|
||||
eff = func() (err error) {
|
||||
err = prev()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return next()
|
||||
}
|
||||
return nil
|
||||
|
||||
return eff
|
||||
}
|
||||
|
||||
func argMatches(a arg, v string) bool {
|
||||
return v == "--"+a.longName || (a.shortName != "" && v == "-"+a.shortName)
|
||||
}
|
||||
|
||||
func parse(exec string, ss []string) (o options, f effect, err error) {
|
||||
for i := 0; i < len(ss); i++ {
|
||||
v := ss[i]
|
||||
knownParam := false
|
||||
for _, arg := range args {
|
||||
if argMatches(arg, v) {
|
||||
if arg.updateWithValue != nil {
|
||||
if i+1 >= len(ss) {
|
||||
return o, f, fmt.Errorf("got %s without argument", v)
|
||||
}
|
||||
i++
|
||||
o, err = arg.updateWithValue(o, ss[i])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if arg.updateNoValue != nil {
|
||||
o, err = arg.updateNoValue(o)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if arg.effect != nil {
|
||||
var eff effect
|
||||
eff, err = arg.effect(o, exec)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if eff != nil {
|
||||
prevf := f
|
||||
f = func() (ferr error) {
|
||||
if prevf != nil {
|
||||
ferr = prevf()
|
||||
}
|
||||
if ferr == nil {
|
||||
ferr = eff()
|
||||
}
|
||||
return ferr
|
||||
}
|
||||
}
|
||||
}
|
||||
knownParam = true
|
||||
break
|
||||
}
|
||||
// optsToArgs converts command line options into a list of arguments.
|
||||
func optsToArgs(o options) (args []string) {
|
||||
for _, opt := range cmdLineOpts {
|
||||
val, ok := opt.serialize(o)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if !knownParam {
|
||||
return o, f, fmt.Errorf("unknown option %v", v)
|
||||
|
||||
if opt.shortName != "" {
|
||||
args = append(args, "-"+opt.shortName)
|
||||
} else {
|
||||
args = append(args, "--"+opt.longName)
|
||||
}
|
||||
|
||||
if val != "" {
|
||||
args = append(args, val)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func shortestFlag(a arg) string {
|
||||
if a.shortName != "" {
|
||||
return "-" + a.shortName
|
||||
}
|
||||
return "--" + a.longName
|
||||
}
|
||||
|
||||
func serialize(o options) []string {
|
||||
ss := []string{}
|
||||
for _, arg := range args {
|
||||
s := arg.serialize(o)
|
||||
if s != nil {
|
||||
ss = append(ss, append([]string{shortestFlag(arg)}, s...)...)
|
||||
}
|
||||
}
|
||||
return ss
|
||||
return args
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
func testParseOK(t *testing.T, ss ...string) options {
|
||||
t.Helper()
|
||||
|
||||
o, _, err := parse("", ss)
|
||||
o, _, err := parseCmdOpts("", ss)
|
||||
require.NoError(t, err)
|
||||
|
||||
return o
|
||||
@@ -21,7 +21,7 @@ func testParseOK(t *testing.T, ss ...string) options {
|
||||
func testParseErr(t *testing.T, descr string, ss ...string) {
|
||||
t.Helper()
|
||||
|
||||
_, _, err := parse("", ss)
|
||||
_, _, err := parseCmdOpts("", ss)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -38,11 +38,11 @@ func TestParseVerbose(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseConfigFilename(t *testing.T) {
|
||||
assert.Equal(t, "", testParseOK(t).configFilename, "empty is no config filename")
|
||||
assert.Equal(t, "path", testParseOK(t, "-c", "path").configFilename, "-c is config filename")
|
||||
assert.Equal(t, "", testParseOK(t).confFilename, "empty is no config filename")
|
||||
assert.Equal(t, "path", testParseOK(t, "-c", "path").confFilename, "-c is config filename")
|
||||
testParseParamMissing(t, "-c")
|
||||
|
||||
assert.Equal(t, "path", testParseOK(t, "--config", "path").configFilename, "--config is config filename")
|
||||
assert.Equal(t, "path", testParseOK(t, "--config", "path").confFilename, "--config is config filename")
|
||||
testParseParamMissing(t, "--config")
|
||||
}
|
||||
|
||||
@@ -56,11 +56,13 @@ func TestParseWorkDir(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseBindHost(t *testing.T) {
|
||||
assert.Nil(t, testParseOK(t).bindHost, "empty is not host")
|
||||
assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "-h", "1.2.3.4").bindHost, "-h is host")
|
||||
wantAddr := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
assert.Zero(t, testParseOK(t).bindHost, "empty is not host")
|
||||
assert.Equal(t, wantAddr, testParseOK(t, "-h", "1.2.3.4").bindHost, "-h is host")
|
||||
testParseParamMissing(t, "-h")
|
||||
|
||||
assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "--host", "1.2.3.4").bindHost, "--host is host")
|
||||
assert.Equal(t, wantAddr, testParseOK(t, "--host", "1.2.3.4").bindHost, "--host is host")
|
||||
testParseParamMissing(t, "--host")
|
||||
}
|
||||
|
||||
@@ -103,7 +105,7 @@ func TestParseDisableUpdate(t *testing.T) {
|
||||
|
||||
// TODO(e.burkov): Remove after v0.108.0.
|
||||
func TestParseDisableMemoryOptimization(t *testing.T) {
|
||||
o, eff, err := parse("", []string{"--no-mem-optimization"})
|
||||
o, eff, err := parseCmdOpts("", []string{"--no-mem-optimization"})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, eff)
|
||||
@@ -130,73 +132,73 @@ func TestParseUnknown(t *testing.T) {
|
||||
testParseErr(t, "unknown dash", "-")
|
||||
}
|
||||
|
||||
func TestSerialize(t *testing.T) {
|
||||
func TestOptsToArgs(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
opts options
|
||||
ss []string
|
||||
}{{
|
||||
name: "empty",
|
||||
args: []string{},
|
||||
opts: options{},
|
||||
ss: []string{},
|
||||
}, {
|
||||
name: "config_filename",
|
||||
opts: options{configFilename: "path"},
|
||||
ss: []string{"-c", "path"},
|
||||
args: []string{"-c", "path"},
|
||||
opts: options{confFilename: "path"},
|
||||
}, {
|
||||
name: "work_dir",
|
||||
args: []string{"-w", "path"},
|
||||
opts: options{workDir: "path"},
|
||||
ss: []string{"-w", "path"},
|
||||
}, {
|
||||
name: "bind_host",
|
||||
opts: options{bindHost: net.IP{1, 2, 3, 4}},
|
||||
ss: []string{"-h", "1.2.3.4"},
|
||||
opts: options{bindHost: netip.MustParseAddr("1.2.3.4")},
|
||||
args: []string{"-h", "1.2.3.4"},
|
||||
}, {
|
||||
name: "bind_port",
|
||||
args: []string{"-p", "666"},
|
||||
opts: options{bindPort: 666},
|
||||
ss: []string{"-p", "666"},
|
||||
}, {
|
||||
name: "log_file",
|
||||
args: []string{"-l", "path"},
|
||||
opts: options{logFile: "path"},
|
||||
ss: []string{"-l", "path"},
|
||||
}, {
|
||||
name: "pid_file",
|
||||
args: []string{"--pidfile", "path"},
|
||||
opts: options{pidFile: "path"},
|
||||
ss: []string{"--pidfile", "path"},
|
||||
}, {
|
||||
name: "disable_update",
|
||||
args: []string{"--no-check-update"},
|
||||
opts: options{disableUpdate: true},
|
||||
ss: []string{"--no-check-update"},
|
||||
}, {
|
||||
name: "control_action",
|
||||
args: []string{"-s", "run"},
|
||||
opts: options{serviceControlAction: "run"},
|
||||
ss: []string{"-s", "run"},
|
||||
}, {
|
||||
name: "glinet_mode",
|
||||
args: []string{"--glinet"},
|
||||
opts: options{glinetMode: true},
|
||||
ss: []string{"--glinet"},
|
||||
}, {
|
||||
name: "multiple",
|
||||
opts: options{
|
||||
serviceControlAction: "run",
|
||||
configFilename: "config",
|
||||
workDir: "work",
|
||||
pidFile: "pid",
|
||||
disableUpdate: true,
|
||||
},
|
||||
ss: []string{
|
||||
args: []string{
|
||||
"-c", "config",
|
||||
"-w", "work",
|
||||
"-s", "run",
|
||||
"--pidfile", "pid",
|
||||
"--no-check-update",
|
||||
},
|
||||
opts: options{
|
||||
serviceControlAction: "run",
|
||||
confFilename: "config",
|
||||
workDir: "work",
|
||||
pidFile: "pid",
|
||||
disableUpdate: true,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := serialize(tc.opts)
|
||||
assert.ElementsMatch(t, tc.ss, result)
|
||||
result := optsToArgs(tc.opts)
|
||||
assert.ElementsMatch(t, tc.args, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,7 +197,7 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
|
||||
DisplayName: serviceDisplayName,
|
||||
Description: serviceDescription,
|
||||
WorkingDirectory: pwd,
|
||||
Arguments: serialize(runOpts),
|
||||
Arguments: optsToArgs(runOpts),
|
||||
}
|
||||
configureService(svcConfig)
|
||||
|
||||
|
||||
@@ -14,8 +14,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -28,216 +26,256 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
var tlsWebHandlersRegistered = false
|
||||
// tlsManager contains the current configuration and state of AdGuard Home TLS
|
||||
// encryption.
|
||||
type tlsManager struct {
|
||||
// status is the current status of the configuration. It is never nil.
|
||||
status *tlsConfigStatus
|
||||
|
||||
// TLSMod - TLS module object
|
||||
type TLSMod struct {
|
||||
certLastMod time.Time // last modification time of the certificate file
|
||||
status tlsConfigStatus
|
||||
confLock sync.Mutex
|
||||
conf tlsConfigSettings
|
||||
// certLastMod is the last modification time of the certificate file.
|
||||
certLastMod time.Time
|
||||
|
||||
confLock sync.Mutex
|
||||
conf tlsConfigSettings
|
||||
}
|
||||
|
||||
// Create TLS module
|
||||
func tlsCreate(conf tlsConfigSettings) *TLSMod {
|
||||
t := &TLSMod{}
|
||||
t.conf = conf
|
||||
if t.conf.Enabled {
|
||||
if !t.load() {
|
||||
// Something is not valid - return an empty TLS config
|
||||
return &TLSMod{conf: tlsConfigSettings{
|
||||
Enabled: conf.Enabled,
|
||||
ServerName: conf.ServerName,
|
||||
PortHTTPS: conf.PortHTTPS,
|
||||
PortDNSOverTLS: conf.PortDNSOverTLS,
|
||||
PortDNSOverQUIC: conf.PortDNSOverQUIC,
|
||||
AllowUnencryptedDoH: conf.AllowUnencryptedDoH,
|
||||
}}
|
||||
// newTLSManager initializes the TLS configuration.
|
||||
func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
|
||||
m = &tlsManager{
|
||||
status: &tlsConfigStatus{},
|
||||
conf: conf,
|
||||
}
|
||||
|
||||
if m.conf.Enabled {
|
||||
err = m.load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.setCertFileTime()
|
||||
|
||||
m.setCertFileTime()
|
||||
}
|
||||
return t
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (t *TLSMod) load() bool {
|
||||
if !tlsLoadConfig(&t.conf, &t.status) {
|
||||
log.Error("failed to load TLS config: %s", t.status.WarningValidation)
|
||||
return false
|
||||
// load reloads the TLS configuration from files or data from the config file.
|
||||
func (m *tlsManager) load() (err error) {
|
||||
err = loadTLSConf(&m.conf, m.status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
// validate current TLS config and update warnings (it could have been loaded from file)
|
||||
data := validateCertificates(string(t.conf.CertificateChainData), string(t.conf.PrivateKeyData), t.conf.ServerName)
|
||||
if !data.ValidPair {
|
||||
log.Error("failed to validate certificate: %s", data.WarningValidation)
|
||||
return false
|
||||
}
|
||||
t.status = data
|
||||
return true
|
||||
}
|
||||
|
||||
// Close - close module
|
||||
func (t *TLSMod) Close() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write config
|
||||
func (t *TLSMod) WriteDiskConfig(conf *tlsConfigSettings) {
|
||||
t.confLock.Lock()
|
||||
*conf = t.conf
|
||||
t.confLock.Unlock()
|
||||
func (m *tlsManager) WriteDiskConfig(conf *tlsConfigSettings) {
|
||||
m.confLock.Lock()
|
||||
*conf = m.conf
|
||||
m.confLock.Unlock()
|
||||
}
|
||||
|
||||
func (t *TLSMod) setCertFileTime() {
|
||||
if len(t.conf.CertificatePath) == 0 {
|
||||
// setCertFileTime sets t.certLastMod from the certificate. If there are
|
||||
// errors, setCertFileTime logs them.
|
||||
func (m *tlsManager) setCertFileTime() {
|
||||
if len(m.conf.CertificatePath) == 0 {
|
||||
return
|
||||
}
|
||||
fi, err := os.Stat(t.conf.CertificatePath)
|
||||
|
||||
fi, err := os.Stat(m.conf.CertificatePath)
|
||||
if err != nil {
|
||||
log.Error("TLS: %s", err)
|
||||
log.Error("tls: looking up certificate path: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
t.certLastMod = fi.ModTime().UTC()
|
||||
|
||||
m.certLastMod = fi.ModTime().UTC()
|
||||
}
|
||||
|
||||
// Start updates the configuration of TLSMod and starts it.
|
||||
func (t *TLSMod) Start() {
|
||||
if !tlsWebHandlersRegistered {
|
||||
tlsWebHandlersRegistered = true
|
||||
t.registerWebHandlers()
|
||||
}
|
||||
// start updates the configuration of t and starts it.
|
||||
func (m *tlsManager) start() {
|
||||
m.registerWebHandlers()
|
||||
|
||||
t.confLock.Lock()
|
||||
tlsConf := t.conf
|
||||
t.confLock.Unlock()
|
||||
m.confLock.Lock()
|
||||
tlsConf := m.conf
|
||||
m.confLock.Unlock()
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps
|
||||
// context with timeout on its own and shuts down the server, which
|
||||
// handles current request.
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
||||
}
|
||||
|
||||
// Reload updates the configuration of TLSMod and restarts it.
|
||||
func (t *TLSMod) Reload() {
|
||||
t.confLock.Lock()
|
||||
tlsConf := t.conf
|
||||
t.confLock.Unlock()
|
||||
// reload updates the configuration and restarts t.
|
||||
func (m *tlsManager) reload() {
|
||||
m.confLock.Lock()
|
||||
tlsConf := m.conf
|
||||
m.confLock.Unlock()
|
||||
|
||||
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
fi, err := os.Stat(tlsConf.CertificatePath)
|
||||
if err != nil {
|
||||
log.Error("TLS: %s", err)
|
||||
return
|
||||
}
|
||||
if fi.ModTime().UTC().Equal(t.certLastMod) {
|
||||
log.Debug("TLS: certificate file isn't modified")
|
||||
return
|
||||
}
|
||||
log.Debug("TLS: certificate file is modified")
|
||||
log.Error("tls: %s", err)
|
||||
|
||||
t.confLock.Lock()
|
||||
r := t.load()
|
||||
t.confLock.Unlock()
|
||||
if !r {
|
||||
return
|
||||
}
|
||||
|
||||
t.certLastMod = fi.ModTime().UTC()
|
||||
if fi.ModTime().UTC().Equal(m.certLastMod) {
|
||||
log.Debug("tls: certificate file isn't modified")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("tls: certificate file is modified")
|
||||
|
||||
m.confLock.Lock()
|
||||
err = m.load()
|
||||
m.confLock.Unlock()
|
||||
if err != nil {
|
||||
log.Error("tls: reloading: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
m.certLastMod = fi.ModTime().UTC()
|
||||
|
||||
_ = reconfigureDNSServer()
|
||||
|
||||
t.confLock.Lock()
|
||||
tlsConf = t.conf
|
||||
t.confLock.Unlock()
|
||||
// The background context is used because the TLSConfigChanged wraps
|
||||
// context with timeout on its own and shuts down the server, which
|
||||
// handles current request.
|
||||
m.confLock.Lock()
|
||||
tlsConf = m.conf
|
||||
m.confLock.Unlock()
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
||||
}
|
||||
|
||||
// Set certificate and private key data
|
||||
func tlsLoadConfig(tls *tlsConfigSettings, status *tlsConfigStatus) bool {
|
||||
tls.CertificateChainData = []byte(tls.CertificateChain)
|
||||
tls.PrivateKeyData = []byte(tls.PrivateKey)
|
||||
|
||||
var err error
|
||||
if tls.CertificatePath != "" {
|
||||
if tls.CertificateChain != "" {
|
||||
status.WarningValidation = "certificate data and file can't be set together"
|
||||
return false
|
||||
}
|
||||
tls.CertificateChainData, err = os.ReadFile(tls.CertificatePath)
|
||||
// loadTLSConf loads and validates the TLS configuration. The returned error is
|
||||
// also set in status.WarningValidation.
|
||||
func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
status.WarningValidation = err.Error()
|
||||
return false
|
||||
}
|
||||
}()
|
||||
|
||||
tlsConf.CertificateChainData = []byte(tlsConf.CertificateChain)
|
||||
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
|
||||
|
||||
if tlsConf.CertificatePath != "" {
|
||||
if tlsConf.CertificateChain != "" {
|
||||
return errors.Error("certificate data and file can't be set together")
|
||||
}
|
||||
|
||||
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading cert file: %w", err)
|
||||
}
|
||||
|
||||
status.ValidCert = true
|
||||
}
|
||||
|
||||
if tls.PrivateKeyPath != "" {
|
||||
if tls.PrivateKey != "" {
|
||||
status.WarningValidation = "private key data and file can't be set together"
|
||||
return false
|
||||
if tlsConf.PrivateKeyPath != "" {
|
||||
if tlsConf.PrivateKey != "" {
|
||||
return errors.Error("private key data and file can't be set together")
|
||||
}
|
||||
tls.PrivateKeyData, err = os.ReadFile(tls.PrivateKeyPath)
|
||||
|
||||
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
|
||||
if err != nil {
|
||||
status.WarningValidation = err.Error()
|
||||
return false
|
||||
return fmt.Errorf("reading key file: %w", err)
|
||||
}
|
||||
|
||||
status.ValidKey = true
|
||||
}
|
||||
|
||||
return true
|
||||
err = validateCertificates(
|
||||
status,
|
||||
tlsConf.CertificateChainData,
|
||||
tlsConf.PrivateKeyData,
|
||||
tlsConf.ServerName,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating certificate pair: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tlsConfigStatus contains the status of a certificate chain and key pair.
|
||||
type tlsConfigStatus struct {
|
||||
ValidCert bool `json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates
|
||||
ValidChain bool `json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA
|
||||
Subject string `json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
|
||||
Issuer string `json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
|
||||
NotBefore time.Time `json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
|
||||
NotAfter time.Time `json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
|
||||
DNSNames []string `json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
|
||||
// Subject is the subject of the first certificate in the chain.
|
||||
Subject string `json:"subject,omitempty"`
|
||||
|
||||
// key status
|
||||
ValidKey bool `json:"valid_key"` // ValidKey is true if the key is a valid private key
|
||||
KeyType string `json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
|
||||
// Issuer is the issuer of the first certificate in the chain.
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
|
||||
// is usable? set by validator
|
||||
ValidPair bool `json:"valid_pair"` // ValidPair is true if both certificate and private key are correct
|
||||
// KeyType is the type of the private key.
|
||||
KeyType string `json:"key_type,omitempty"`
|
||||
|
||||
// warnings
|
||||
WarningValidation string `json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description
|
||||
// NotBefore is the NotBefore field of the first certificate in the chain.
|
||||
NotBefore time.Time `json:"not_before,omitempty"`
|
||||
|
||||
// NotAfter is the NotAfter field of the first certificate in the chain.
|
||||
NotAfter time.Time `json:"not_after,omitempty"`
|
||||
|
||||
// WarningValidation is a validation warning message with the issue
|
||||
// description.
|
||||
WarningValidation string `json:"warning_validation,omitempty"`
|
||||
|
||||
// DNSNames is the value of SubjectAltNames field of the first certificate
|
||||
// in the chain.
|
||||
DNSNames []string `json:"dns_names"`
|
||||
|
||||
// ValidCert is true if the specified certificate chain is a valid chain of
|
||||
// X509 certificates.
|
||||
ValidCert bool `json:"valid_cert"`
|
||||
|
||||
// ValidChain is true if the specified certificate chain is verified and
|
||||
// issued by a known CA.
|
||||
ValidChain bool `json:"valid_chain"`
|
||||
|
||||
// ValidKey is true if the key is a valid private key.
|
||||
ValidKey bool `json:"valid_key"`
|
||||
|
||||
// ValidPair is true if both certificate and private key are correct for
|
||||
// each other.
|
||||
ValidPair bool `json:"valid_pair"`
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
// tlsConfig is the TLS configuration and status response.
|
||||
type tlsConfig struct {
|
||||
tlsConfigStatus `json:",inline"`
|
||||
*tlsConfigStatus `json:",inline"`
|
||||
tlsConfigSettingsExt `json:",inline"`
|
||||
}
|
||||
|
||||
// tlsConfigSettingsExt is used to (un)marshal PrivateKeySaved to ensure that
|
||||
// clients don't send and receive previously saved private keys.
|
||||
// tlsConfigSettingsExt is used to (un)marshal the PrivateKeySaved field to
|
||||
// ensure that clients don't send and receive previously saved private keys.
|
||||
type tlsConfigSettingsExt struct {
|
||||
tlsConfigSettings `json:",inline"`
|
||||
// If private key saved as a string, we set this flag to true
|
||||
// and omit key from answer.
|
||||
|
||||
// PrivateKeySaved is true if the private key is saved as a string and omit
|
||||
// key from answer.
|
||||
PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"`
|
||||
}
|
||||
|
||||
func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
t.confLock.Lock()
|
||||
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
m.confLock.Lock()
|
||||
data := tlsConfig{
|
||||
tlsConfigSettingsExt: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: t.conf,
|
||||
tlsConfigSettings: m.conf,
|
||||
},
|
||||
tlsConfigStatus: t.status,
|
||||
tlsConfigStatus: m.status,
|
||||
}
|
||||
t.confLock.Unlock()
|
||||
m.confLock.Unlock()
|
||||
|
||||
marshalTLS(w, r, data)
|
||||
}
|
||||
|
||||
func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
setts, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
@@ -246,7 +284,7 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if setts.PrivateKeySaved {
|
||||
setts.PrivateKey = t.conf.PrivateKey
|
||||
setts.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
|
||||
if setts.Enabled {
|
||||
@@ -278,75 +316,74 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
status := tlsConfigStatus{}
|
||||
if tlsLoadConfig(&setts.tlsConfigSettings, &status) {
|
||||
status = validateCertificates(string(setts.CertificateChainData), string(setts.PrivateKeyData), setts.ServerName)
|
||||
}
|
||||
|
||||
data := tlsConfig{
|
||||
// Skip the error check, since we are only interested in the value of
|
||||
// status.WarningValidation.
|
||||
status := &tlsConfigStatus{}
|
||||
_ = loadTLSConf(&setts.tlsConfigSettings, status)
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: setts,
|
||||
tlsConfigStatus: status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, data)
|
||||
marshalTLS(w, r, resp)
|
||||
}
|
||||
|
||||
func (t *TLSMod) setConfig(newConf tlsConfigSettings, status tlsConfigStatus) (restartHTTPS bool) {
|
||||
t.confLock.Lock()
|
||||
defer t.confLock.Unlock()
|
||||
func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatus) (restartHTTPS bool) {
|
||||
m.confLock.Lock()
|
||||
defer m.confLock.Unlock()
|
||||
|
||||
// Reset the DNSCrypt data before comparing, since we currently do not
|
||||
// accept these from the frontend.
|
||||
//
|
||||
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
|
||||
newConf.DNSCryptConfigFile = t.conf.DNSCryptConfigFile
|
||||
newConf.PortDNSCrypt = t.conf.PortDNSCrypt
|
||||
if !cmp.Equal(t.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
|
||||
newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
|
||||
newConf.PortDNSCrypt = m.conf.PortDNSCrypt
|
||||
if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
|
||||
log.Info("tls config has changed, restarting https server")
|
||||
restartHTTPS = true
|
||||
} else {
|
||||
log.Info("tls config has not changed")
|
||||
log.Info("tls: config has not changed")
|
||||
}
|
||||
|
||||
// Note: don't do just `t.conf = data` because we must preserve all other members of t.conf
|
||||
t.conf.Enabled = newConf.Enabled
|
||||
t.conf.ServerName = newConf.ServerName
|
||||
t.conf.ForceHTTPS = newConf.ForceHTTPS
|
||||
t.conf.PortHTTPS = newConf.PortHTTPS
|
||||
t.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
|
||||
t.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
|
||||
t.conf.CertificateChain = newConf.CertificateChain
|
||||
t.conf.CertificatePath = newConf.CertificatePath
|
||||
t.conf.CertificateChainData = newConf.CertificateChainData
|
||||
t.conf.PrivateKey = newConf.PrivateKey
|
||||
t.conf.PrivateKeyPath = newConf.PrivateKeyPath
|
||||
t.conf.PrivateKeyData = newConf.PrivateKeyData
|
||||
t.status = status
|
||||
m.conf.Enabled = newConf.Enabled
|
||||
m.conf.ServerName = newConf.ServerName
|
||||
m.conf.ForceHTTPS = newConf.ForceHTTPS
|
||||
m.conf.PortHTTPS = newConf.PortHTTPS
|
||||
m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
|
||||
m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
|
||||
m.conf.CertificateChain = newConf.CertificateChain
|
||||
m.conf.CertificatePath = newConf.CertificatePath
|
||||
m.conf.CertificateChainData = newConf.CertificateChainData
|
||||
m.conf.PrivateKey = newConf.PrivateKey
|
||||
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
|
||||
m.conf.PrivateKeyData = newConf.PrivateKeyData
|
||||
m.status = status
|
||||
|
||||
return restartHTTPS
|
||||
}
|
||||
|
||||
func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := unmarshalTLS(r)
|
||||
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if data.PrivateKeySaved {
|
||||
data.PrivateKey = t.conf.PrivateKey
|
||||
if req.PrivateKeySaved {
|
||||
req.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
|
||||
if data.Enabled {
|
||||
if req.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
tcpPort(data.PortHTTPS),
|
||||
tcpPort(data.PortDNSOverTLS),
|
||||
tcpPort(data.PortDNSCrypt),
|
||||
tcpPort(req.PortHTTPS),
|
||||
tcpPort(req.PortDNSOverTLS),
|
||||
tcpPort(req.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(data.PortDNSOverQUIC),
|
||||
udpPort(req.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
@@ -356,33 +393,33 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Investigate and perhaps check other ports.
|
||||
if !webCheckPortAvailable(data.PortHTTPS) {
|
||||
if !webCheckPortAvailable(req.PortHTTPS) {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"port %d is not available, cannot enable HTTPS on it",
|
||||
data.PortHTTPS,
|
||||
"port %d is not available, cannot enable https on it",
|
||||
req.PortHTTPS,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
status := tlsConfigStatus{}
|
||||
if !tlsLoadConfig(&data.tlsConfigSettings, &status) {
|
||||
data2 := tlsConfig{
|
||||
tlsConfigSettingsExt: data,
|
||||
tlsConfigStatus: t.status,
|
||||
status := &tlsConfigStatus{}
|
||||
err = loadTLSConf(&req.tlsConfigSettings, status)
|
||||
if err != nil {
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: req,
|
||||
tlsConfigStatus: status,
|
||||
}
|
||||
marshalTLS(w, r, data2)
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName)
|
||||
|
||||
restartHTTPS := t.setConfig(data.tlsConfigSettings, status)
|
||||
t.setCertFileTime()
|
||||
restartHTTPS := m.setConfig(req.tlsConfigSettings, status)
|
||||
m.setCertFileTime()
|
||||
onConfigModified()
|
||||
|
||||
err = reconfigureDNSServer()
|
||||
@@ -392,12 +429,12 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
data2 := tlsConfig{
|
||||
tlsConfigSettingsExt: data,
|
||||
tlsConfigStatus: t.status,
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: req,
|
||||
tlsConfigStatus: m.status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, data2)
|
||||
marshalTLS(w, r, resp)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
@@ -408,7 +445,7 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
// same reason.
|
||||
if restartHTTPS {
|
||||
go func() {
|
||||
Context.web.TLSConfigChanged(context.Background(), data.tlsConfigSettings)
|
||||
Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings)
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -445,89 +482,105 @@ func validatePorts(
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyCertChain(data *tlsConfigStatus, certChain, serverName string) error {
|
||||
log.Tracef("TLS: got certificate: %d bytes", len(certChain))
|
||||
// validateCertChain validates the certificate chain and sets data in status.
|
||||
// The returned error is also set in status.WarningValidation.
|
||||
func validateCertChain(status *tlsConfigStatus, certChain []byte, serverName string) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
status.WarningValidation = err.Error()
|
||||
}
|
||||
}()
|
||||
|
||||
// now do a more extended validation
|
||||
var certs []*pem.Block // PEM-encoded certificates
|
||||
log.Debug("tls: got certificate chain: %d bytes", len(certChain))
|
||||
|
||||
pemblock := []byte(certChain)
|
||||
var certs []*pem.Block
|
||||
pemblock := certChain
|
||||
for {
|
||||
var decoded *pem.Block
|
||||
decoded, pemblock = pem.Decode(pemblock)
|
||||
if decoded == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if decoded.Type == "CERTIFICATE" {
|
||||
certs = append(certs, decoded)
|
||||
}
|
||||
}
|
||||
|
||||
var parsedCerts []*x509.Certificate
|
||||
|
||||
for _, cert := range certs {
|
||||
parsed, err := x509.ParseCertificate(cert.Bytes)
|
||||
if err != nil {
|
||||
data.WarningValidation = fmt.Sprintf("Failed to parse certificate: %s", err)
|
||||
return errors.Error(data.WarningValidation)
|
||||
}
|
||||
parsedCerts = append(parsedCerts, parsed)
|
||||
parsedCerts, err := parsePEMCerts(certs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(parsedCerts) == 0 {
|
||||
data.WarningValidation = "You have specified an empty certificate"
|
||||
return errors.Error(data.WarningValidation)
|
||||
}
|
||||
|
||||
data.ValidCert = true
|
||||
|
||||
// spew.Dump(parsedCerts)
|
||||
status.ValidCert = true
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
DNSName: serverName,
|
||||
Roots: Context.tlsRoots,
|
||||
}
|
||||
|
||||
log.Printf("number of certs - %d", len(parsedCerts))
|
||||
if len(parsedCerts) > 1 {
|
||||
// set up an intermediate
|
||||
pool := x509.NewCertPool()
|
||||
for _, cert := range parsedCerts[1:] {
|
||||
log.Printf("got an intermediate cert")
|
||||
pool.AddCert(cert)
|
||||
}
|
||||
opts.Intermediates = pool
|
||||
log.Info("tls: number of certs: %d", len(parsedCerts))
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
for _, cert := range parsedCerts[1:] {
|
||||
log.Info("tls: got an intermediate cert")
|
||||
pool.AddCert(cert)
|
||||
}
|
||||
|
||||
// TODO: save it as a warning rather than error it out -- shouldn't be a big problem
|
||||
opts.Intermediates = pool
|
||||
|
||||
mainCert := parsedCerts[0]
|
||||
_, err := mainCert.Verify(opts)
|
||||
_, err = mainCert.Verify(opts)
|
||||
if err != nil {
|
||||
// let self-signed certs through
|
||||
data.WarningValidation = fmt.Sprintf("Your certificate does not verify: %s", err)
|
||||
// Let self-signed certs through and don't return this error.
|
||||
status.WarningValidation = fmt.Sprintf("certificate does not verify: %s", err)
|
||||
} else {
|
||||
data.ValidChain = true
|
||||
status.ValidChain = true
|
||||
}
|
||||
// spew.Dump(chains)
|
||||
|
||||
// update status
|
||||
if mainCert != nil {
|
||||
notAfter := mainCert.NotAfter
|
||||
data.Subject = mainCert.Subject.String()
|
||||
data.Issuer = mainCert.Issuer.String()
|
||||
data.NotAfter = notAfter
|
||||
data.NotBefore = mainCert.NotBefore
|
||||
data.DNSNames = mainCert.DNSNames
|
||||
status.Subject = mainCert.Subject.String()
|
||||
status.Issuer = mainCert.Issuer.String()
|
||||
status.NotAfter = mainCert.NotAfter
|
||||
status.NotBefore = mainCert.NotBefore
|
||||
status.DNSNames = mainCert.DNSNames
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePkey(data *tlsConfigStatus, pkey string) error {
|
||||
// now do a more extended validation
|
||||
var key *pem.Block // PEM-encoded certificates
|
||||
// parsePEMCerts parses multiple PEM-encoded certificates.
|
||||
func parsePEMCerts(certs []*pem.Block) (parsedCerts []*x509.Certificate, err error) {
|
||||
for i, cert := range certs {
|
||||
var parsed *x509.Certificate
|
||||
parsed, err = x509.ParseCertificate(cert.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing certificate at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
// go through all pem blocks, but take first valid pem block and drop the rest
|
||||
parsedCerts = append(parsedCerts, parsed)
|
||||
}
|
||||
|
||||
if len(parsedCerts) == 0 {
|
||||
return nil, errors.Error("empty certificate")
|
||||
}
|
||||
|
||||
return parsedCerts, nil
|
||||
}
|
||||
|
||||
// validatePKey validates the private key and sets data in status. The returned
|
||||
// error is also set in status.WarningValidation.
|
||||
func validatePKey(status *tlsConfigStatus, pkey []byte) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
status.WarningValidation = err.Error()
|
||||
}
|
||||
}()
|
||||
|
||||
var key *pem.Block
|
||||
|
||||
// Go through all pem blocks, but take first valid pem block and drop the
|
||||
// rest.
|
||||
pemblock := []byte(pkey)
|
||||
for {
|
||||
var decoded *pem.Block
|
||||
@@ -544,61 +597,77 @@ func validatePkey(data *tlsConfigStatus, pkey string) error {
|
||||
}
|
||||
|
||||
if key == nil {
|
||||
data.WarningValidation = "No valid keys were found"
|
||||
|
||||
return errors.Error(data.WarningValidation)
|
||||
return errors.Error("no valid keys were found")
|
||||
}
|
||||
|
||||
// parse the decoded key
|
||||
_, keyType, err := parsePrivateKey(key.Bytes)
|
||||
if err != nil {
|
||||
data.WarningValidation = fmt.Sprintf("Failed to parse private key: %s", err)
|
||||
|
||||
return errors.Error(data.WarningValidation)
|
||||
} else if keyType == keyTypeED25519 {
|
||||
data.WarningValidation = "ED25519 keys are not supported by browsers; " +
|
||||
"did you mean to use X25519 for key exchange?"
|
||||
|
||||
return errors.Error(data.WarningValidation)
|
||||
return fmt.Errorf("parsing private key: %w", err)
|
||||
}
|
||||
|
||||
data.ValidKey = true
|
||||
data.KeyType = keyType
|
||||
if keyType == keyTypeED25519 {
|
||||
return errors.Error(
|
||||
"ED25519 keys are not supported by browsers; " +
|
||||
"did you mean to use X25519 for key exchange?",
|
||||
)
|
||||
}
|
||||
|
||||
status.ValidKey = true
|
||||
status.KeyType = keyType
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateCertificates processes certificate data and its private key. All
|
||||
// parameters are optional. On error, validateCertificates returns a partially
|
||||
// set object with field WarningValidation containing error description.
|
||||
func validateCertificates(certChain, pkey, serverName string) tlsConfigStatus {
|
||||
var data tlsConfigStatus
|
||||
|
||||
// check only public certificate separately from the key
|
||||
if certChain != "" {
|
||||
if verifyCertChain(&data, certChain, serverName) != nil {
|
||||
return data
|
||||
// parameters are optional. status must not be nil. The returned error is also
|
||||
// set in status.WarningValidation.
|
||||
func validateCertificates(
|
||||
status *tlsConfigStatus,
|
||||
certChain []byte,
|
||||
pkey []byte,
|
||||
serverName string,
|
||||
) (err error) {
|
||||
defer func() {
|
||||
// Capitalize the warning for the UI. Assume that warnings are all
|
||||
// ASCII-only.
|
||||
//
|
||||
// TODO(a.garipov): Figure out a better way to do this. Perhaps a
|
||||
// custom string or error type.
|
||||
if w := status.WarningValidation; w != "" {
|
||||
status.WarningValidation = strings.ToUpper(w[:1]) + w[1:]
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// validate private key (right now the only validation possible is just parsing it)
|
||||
if pkey != "" {
|
||||
if validatePkey(&data, pkey) != nil {
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
// if both are set, validate both in unison
|
||||
if pkey != "" && certChain != "" {
|
||||
_, err := tls.X509KeyPair([]byte(certChain), []byte(pkey))
|
||||
// Check only the public certificate separately from the key.
|
||||
if len(certChain) > 0 {
|
||||
err = validateCertChain(status, certChain, serverName)
|
||||
if err != nil {
|
||||
data.WarningValidation = fmt.Sprintf("Invalid certificate or key: %s", err)
|
||||
return data
|
||||
return err
|
||||
}
|
||||
data.ValidPair = true
|
||||
}
|
||||
|
||||
return data
|
||||
// Validate the private key by parsing it.
|
||||
if len(pkey) > 0 {
|
||||
err = validatePKey(status, pkey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If both are set, validate together.
|
||||
if len(certChain) > 0 && len(pkey) > 0 {
|
||||
_, err = tls.X509KeyPair(certChain, pkey)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("certificate-key pair: %w", err)
|
||||
status.WarningValidation = err.Error()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
status.ValidPair = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Key types.
|
||||
@@ -693,52 +762,9 @@ func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
}
|
||||
|
||||
// registerWebHandlers registers HTTP handlers for TLS configuration
|
||||
func (t *TLSMod) registerWebHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/tls/status", t.handleTLSStatus)
|
||||
httpRegister(http.MethodPost, "/control/tls/configure", t.handleTLSConfigure)
|
||||
httpRegister(http.MethodPost, "/control/tls/validate", t.handleTLSValidate)
|
||||
}
|
||||
|
||||
// LoadSystemRootCAs tries to load root certificates from the operating system.
|
||||
// It returns nil in case nothing is found so that that Go.crypto will use it's
|
||||
// default algorithm to find system root CA list.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/internal/issues/1311.
|
||||
func LoadSystemRootCAs() (roots *x509.CertPool) {
|
||||
// TODO(e.burkov): Use build tags instead.
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Directories with the system root certificates, which aren't supported
|
||||
// by Go.crypto.
|
||||
dirs := []string{
|
||||
// Entware.
|
||||
"/opt/etc/ssl/certs",
|
||||
}
|
||||
roots = x509.NewCertPool()
|
||||
for _, dir := range dirs {
|
||||
dirEnts, err := os.ReadDir(dir)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
log.Error("opening directory: %q: %s", dir, err)
|
||||
}
|
||||
|
||||
var rootsAdded bool
|
||||
for _, de := range dirEnts {
|
||||
var certData []byte
|
||||
certData, err = os.ReadFile(filepath.Join(dir, de.Name()))
|
||||
if err == nil && roots.AppendCertsFromPEM(certData) {
|
||||
rootsAdded = true
|
||||
}
|
||||
}
|
||||
|
||||
if rootsAdded {
|
||||
return roots
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
// registerWebHandlers registers HTTP handlers for TLS configuration.
|
||||
func (m *tlsManager) registerWebHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
|
||||
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
|
||||
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
|
||||
}
|
||||
|
||||
@@ -7,8 +7,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
CertificateChain = `-----BEGIN CERTIFICATE-----
|
||||
var testCertChainData = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
|
||||
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
|
||||
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
|
||||
@@ -21,8 +20,9 @@ eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8
|
||||
LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
|
||||
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
|
||||
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
|
||||
-----END CERTIFICATE-----`
|
||||
PrivateKey = `-----BEGIN PRIVATE KEY-----
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
var testPrivateKeyData = []byte(`-----BEGIN PRIVATE KEY-----
|
||||
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
|
||||
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
|
||||
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
|
||||
@@ -37,36 +37,43 @@ An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp
|
||||
O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
|
||||
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
|
||||
kXS9jgARhhiWXJrk
|
||||
-----END PRIVATE KEY-----`
|
||||
)
|
||||
-----END PRIVATE KEY-----`)
|
||||
|
||||
func TestValidateCertificates(t *testing.T) {
|
||||
t.Run("bad_certificate", func(t *testing.T) {
|
||||
data := validateCertificates("bad cert", "", "")
|
||||
assert.NotEmpty(t, data.WarningValidation)
|
||||
assert.False(t, data.ValidCert)
|
||||
assert.False(t, data.ValidChain)
|
||||
status := &tlsConfigStatus{}
|
||||
err := validateCertificates(status, []byte("bad cert"), nil, "")
|
||||
assert.Error(t, err)
|
||||
assert.NotEmpty(t, status.WarningValidation)
|
||||
assert.False(t, status.ValidCert)
|
||||
assert.False(t, status.ValidChain)
|
||||
})
|
||||
|
||||
t.Run("bad_private_key", func(t *testing.T) {
|
||||
data := validateCertificates("", "bad priv key", "")
|
||||
assert.NotEmpty(t, data.WarningValidation)
|
||||
assert.False(t, data.ValidKey)
|
||||
status := &tlsConfigStatus{}
|
||||
err := validateCertificates(status, nil, []byte("bad priv key"), "")
|
||||
assert.Error(t, err)
|
||||
assert.NotEmpty(t, status.WarningValidation)
|
||||
assert.False(t, status.ValidKey)
|
||||
})
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
data := validateCertificates(CertificateChain, PrivateKey, "")
|
||||
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z")
|
||||
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z")
|
||||
assert.NotEmpty(t, data.WarningValidation)
|
||||
assert.True(t, data.ValidCert)
|
||||
assert.False(t, data.ValidChain)
|
||||
assert.True(t, data.ValidKey)
|
||||
assert.Equal(t, "RSA", data.KeyType)
|
||||
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Subject)
|
||||
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Issuer)
|
||||
assert.Equal(t, notBefore, data.NotBefore)
|
||||
assert.Equal(t, notAfter, data.NotAfter)
|
||||
assert.True(t, data.ValidPair)
|
||||
status := &tlsConfigStatus{}
|
||||
err := validateCertificates(status, testCertChainData, testPrivateKeyData, "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
notBefore := time.Date(2019, 2, 27, 9, 24, 23, 0, time.UTC)
|
||||
notAfter := time.Date(2046, 7, 14, 9, 24, 23, 0, time.UTC)
|
||||
|
||||
assert.NotEmpty(t, status.WarningValidation)
|
||||
assert.True(t, status.ValidCert)
|
||||
assert.False(t, status.ValidChain)
|
||||
assert.True(t, status.ValidKey)
|
||||
assert.Equal(t, "RSA", status.KeyType)
|
||||
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", status.Subject)
|
||||
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", status.Issuer)
|
||||
assert.Equal(t, notBefore, status.NotBefore)
|
||||
assert.Equal(t, notAfter, status.NotAfter)
|
||||
assert.True(t, status.ValidPair)
|
||||
})
|
||||
}
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -39,7 +39,7 @@ type webConfig struct {
|
||||
clientFS fs.FS
|
||||
clientBetaFS fs.FS
|
||||
|
||||
BindHost net.IP
|
||||
BindHost netip.Addr
|
||||
BindPort int
|
||||
BetaBindPort int
|
||||
PortHTTPS int
|
||||
@@ -137,8 +137,11 @@ func newWeb(conf *webConfig) (w *Web) {
|
||||
//
|
||||
// TODO(a.garipov): Adapt for HTTP/3.
|
||||
func webCheckPortAvailable(port int) (ok bool) {
|
||||
return Context.web.httpsServer.server != nil ||
|
||||
aghnet.CheckPort("tcp", config.BindHost, port) == nil
|
||||
if Context.web.httpsServer.server != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return aghnet.CheckPort("tcp", netip.AddrPortFrom(config.BindHost, uint16(port))) == nil
|
||||
}
|
||||
|
||||
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
|
||||
|
||||
Reference in New Issue
Block a user