diff --git a/CHANGELOG.md b/CHANGELOG.md index cae2cdfc..a8383a3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,8 @@ See also the [v0.107.58 GitHub milestone][ms-v0.107.58]. ### Fixed +- Validation process for the HTTPS port on the *Encryption Settings* page. + - Clearing the DNS cache on the *DNS settings* page now includes both global cache and custom client cache. - Invalid ICMPv6 Router Advertisement messages ([#7547]). diff --git a/internal/home/config.go b/internal/home/config.go index 098ea0a9..23cdd7fe 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -568,7 +568,7 @@ func parseConfig() (err error) { } // Do not wrap the error because it's informative enough as is. - return setContextTLSCipherIDs() + return validateTLSCipherIDs(config.TLS.OverrideTLSCiphers) } // validateConfig returns error if the configuration is invalid. @@ -721,21 +721,15 @@ func (c *configuration) write(tlsMgr *tlsManager) (err error) { return nil } -// setContextTLSCipherIDs sets the TLS cipher suite IDs to use. -func setContextTLSCipherIDs() (err error) { - if len(config.TLS.OverrideTLSCiphers) == 0 { - log.Info("tls: using default ciphers") - - globalContext.tlsCipherIDs = aghtls.SaferCipherSuites() - +// validateTLSCipherIDs validates the custom TLS cipher suite IDs. +func validateTLSCipherIDs(cipherIDs []string) (err error) { + if len(cipherIDs) == 0 { return nil } - log.Info("tls: overriding ciphers: %s", config.TLS.OverrideTLSCiphers) - - globalContext.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers) + _, err = aghtls.ParseCiphers(cipherIDs) if err != nil { - return fmt.Errorf("parsing override ciphers: %w", err) + return fmt.Errorf("override_tls_ciphers: %w", err) } return nil diff --git a/internal/home/dns.go b/internal/home/dns.go index 4cfc63f8..766136fc 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -38,6 +38,8 @@ const ( ) // Called by other modules when configuration is changed +// +// TODO(s.chzhen): Remove this after refactoring. func onConfigModified() { err := config.write(globalContext.tls) if err != nil { @@ -120,14 +122,15 @@ func initDNS( anonymizer, httpRegister, tlsConf, + tlsMgr, baseLogger, ) } // initDNSServer initializes the [context.dnsServer]. To only use the internal -// proxy, none of the arguments are required, but tlsConf and l still must not -// be nil, in other cases all the arguments also must not be nil. It also must -// not be called unless [config] and [globalContext] are initialized. +// proxy, none of the arguments are required, but tlsConf, tlsMgr and l still +// must not be nil, in other cases all the arguments also must not be nil. It +// also must not be called unless [config] and [globalContext] are initialized. // // TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter. func initDNSServer( @@ -138,6 +141,7 @@ func initDNSServer( anonymizer *aghnet.IPMut, httpReg aghhttp.RegisterFunc, tlsConf *tlsConfigSettings, + tlsMgr *tlsManager, l *slog.Logger, ) (err error) { globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ @@ -166,6 +170,7 @@ func initDNSServer( &config.DNS, config.Clients.Sources, tlsConf, + tlsMgr, httpReg, globalContext.clients.storage, ) @@ -236,11 +241,12 @@ func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) { } // newServerConfig converts values from the configuration file into the internal -// DNS server configuration. All arguments must not be nil. +// DNS server configuration. All arguments must not be nil, except for httpReg. func newServerConfig( dnsConf *dnsConfig, clientSrcConf *clientSourcesConfig, tlsConf *tlsConfigSettings, + tlsMgr *tlsManager, httpReg aghhttp.RegisterFunc, clientsContainer dnsforward.ClientsContainer, ) (newConf *dnsforward.ServerConfig, err error) { @@ -256,7 +262,7 @@ func newServerConfig( TLSConfig: newDNSTLSConfig(tlsConf, hosts), TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH, UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout), - TLSv12Roots: globalContext.tlsRoots, + TLSv12Roots: tlsMgr.rootCerts, ConfigModified: onConfigModified, HTTPRegister: httpReg, LocalPTRResolvers: dnsConf.PrivateRDNSResolvers, diff --git a/internal/home/home.go b/internal/home/home.go index f774a06a..0eef1826 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -3,7 +3,6 @@ package home import ( "context" - "crypto/x509" "fmt" "io/fs" "log/slog" @@ -22,7 +21,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" @@ -81,10 +79,6 @@ type homeContext struct { workDir string // Location of our directory, used to protect against CWD being somewhere else pidFileName string // PID file name. Empty if no PID file was created. controlLock sync.Mutex - tlsRoots *x509.CertPool // list of root CAs for TLSv1.2 - - // tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use. - tlsCipherIDs []uint16 // firstRun, if true, tells AdGuard Home to only start the web interface // service, and only serve the first-run APIs. @@ -142,7 +136,6 @@ func Main(clientBuildFS fs.FS) { func setupContext(opts options) (err error) { globalContext.firstRun = detectFirstRun() - globalContext.tlsRoots = aghtls.SystemRootCAs() globalContext.mux = http.NewServeMux() if !opts.noEtcHosts { @@ -274,18 +267,13 @@ func setupOpts(opts options) (err error) { return nil } -// initContextClients initializes Context clients and related fields. +// initContextClients initializes Context clients and related fields. All +// arguments must not be nil. func initContextClients( ctx context.Context, logger *slog.Logger, sigHdlr *signalHandler, ) (err error) { - err = setupDNSFilteringConf(ctx, logger, config.Filtering) - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return err - } - //lint:ignore SA1019 Migration is not over. config.DHCP.WorkDir = globalContext.workDir config.DHCP.DataDir = globalContext.getDataDir() @@ -358,11 +346,13 @@ func setupBindOpts(opts options) (err error) { return nil } -// setupDNSFilteringConf sets up DNS filtering configuration settings. +// setupDNSFilteringConf sets up DNS filtering configuration settings. All +// arguments must not be nil. func setupDNSFilteringConf( ctx context.Context, baseLogger *slog.Logger, conf *filtering.Config, + tlsMgr *tlsManager, ) (err error) { const ( dnsTimeout = 3 * time.Second @@ -388,7 +378,7 @@ func setupDNSFilteringConf( conf.Filters = slices.Clone(config.Filters) conf.WhitelistFilters = slices.Clone(config.WhitelistFilters) conf.UserRules = slices.Clone(config.UserRules) - conf.HTTPClient = httpClient() + conf.HTTPClient = httpClient(tlsMgr) cacheTime := time.Duration(conf.CacheTime) * time.Minute @@ -630,6 +620,23 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH err = initContextClients(ctx, slogLogger, sigHdlr) fatalOnError(err) + tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager") + tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: tlsMgrLogger, + configModified: onConfigModified, + tlsSettings: config.TLS, + servePlainDNS: config.DNS.ServePlainDNS, + }) + if err != nil { + tlsMgrLogger.ErrorContext(ctx, "initializing", slogutil.KeyError, err) + onConfigModified() + } + + globalContext.tls = tlsMgr + + err = setupDNSFilteringConf(ctx, slogLogger, config.Filtering, tlsMgr) + fatalOnError(err) + err = setupOpts(opts) fatalOnError(err) @@ -642,7 +649,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH // TODO(e.burkov): This could be made earlier, probably as the option's // effect. - cmdlineUpdate(ctx, slogLogger, opts, upd) + cmdlineUpdate(ctx, slogLogger, opts, upd, tlsMgr) if !globalContext.firstRun { // Save the updated config. @@ -664,19 +671,14 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH globalContext.auth, err = initUsers() fatalOnError(err) - tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager") - tlsMgr, err := newTLSManager(ctx, tlsMgrLogger, config.TLS, config.DNS.ServePlainDNS) - if err != nil { - log.Error("initializing tls: %s", err) - onConfigModified() - } - - globalContext.tls = tlsMgr - sigHdlr.addTLSManager(tlsMgr) - - globalContext.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL) + web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL) fatalOnError(err) + globalContext.web = web + + tlsMgr.setWebAPI(web) + sigHdlr.addTLSManager(tlsMgr) + statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config) fatalOnError(err) @@ -706,7 +708,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH checkPermissions(ctx, slogLogger, globalContext.workDir, confPath, dataDir, statsDir, querylogDir) } - globalContext.web.start(ctx) + web.start(ctx) // Wait for other goroutines to complete their job. <-done @@ -1058,8 +1060,15 @@ type jsonError struct { Message string `json:"message"` } -// cmdlineUpdate updates current application and exits. l must not be nil. -func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updater.Updater) { +// cmdlineUpdate updates current application and exits. l and tlsMgr must not +// be nil. +func cmdlineUpdate( + ctx context.Context, + l *slog.Logger, + opts options, + upd *updater.Updater, + tlsMgr *tlsManager, +) { if !opts.performUpdate { return } @@ -1069,7 +1078,7 @@ func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updat // // TODO(e.burkov): We could probably initialize the internal resolver // separately. - err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}, l) + err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}, tlsMgr, l) fatalOnError(err) l.InfoContext(ctx, "performing update via cli") diff --git a/internal/home/httpclient.go b/internal/home/httpclient.go index 5e39c864..7c6f2ae1 100644 --- a/internal/home/httpclient.go +++ b/internal/home/httpclient.go @@ -10,10 +10,10 @@ import ( // httpClient returns a new HTTP client that uses the AdGuard Home's own DNS // server for resolving hostnames. The resulting client should not be used -// until [Context.dnsServer] is initialized. +// until [Context.dnsServer] is initialized. tlsMgr must not be nil. // // TODO(a.garipov, e.burkov): This is rather messy. Refactor. -func httpClient() (c *http.Client) { +func httpClient(tlsMgr *tlsManager) (c *http.Client) { // Do not use Context.dnsServer.DialContext directly in the struct literal // below, since Context.dnsServer may be nil when this function is called. dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) { @@ -27,8 +27,8 @@ func httpClient() (c *http.Client) { DialContext: dialContext, Proxy: httpProxy, TLSClientConfig: &tls.Config{ - RootCAs: globalContext.tlsRoots, - CipherSuites: globalContext.tlsCipherIDs, + RootCAs: tlsMgr.rootCerts, + CipherSuites: tlsMgr.customCipherIDs, MinVersion: tls.VersionTLS12, }, }, diff --git a/internal/home/tls.go b/internal/home/tls.go index 778c674b..e012a309 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -14,6 +14,7 @@ import ( "fmt" "log/slog" "net/http" + "net/netip" "os" "strings" "sync" @@ -21,6 +22,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/errors" @@ -41,6 +43,22 @@ type tlsManager struct { // certLastMod is the last modification time of the certificate file. certLastMod time.Time + // rootCerts is a pool of root CAs for TLSv1.2. + rootCerts *x509.CertPool + + // web is the web UI and API server. It must not be nil. + // + // TODO(s.chzhen): Temporary cyclic dependency due to ongoing refactoring. + // Resolve it. + web *webAPI + + // configModified is called when the TLS configuration is changed via an + // HTTP request. + configModified func() + + // customCipherIDs are the ID of the cipher suites that AdGuard Home must use. + customCipherIDs []uint16 + confLock sync.Mutex conf tlsConfigSettings @@ -48,21 +66,50 @@ type tlsManager struct { servePlainDNS bool } +// tlsManagerConfig contains the settings for initializing the TLS manager. +type tlsManagerConfig struct { + // logger is used for logging the operation of the TLS Manager. It must not + // be nil. + logger *slog.Logger + + // configModified is called when the TLS configuration is changed via an + // HTTP request. It must not be nil. + configModified func() + + // tlsSettings contains the TLS configuration settings. + tlsSettings tlsConfigSettings + + // servePlainDNS defines if plain DNS is allowed for incoming requests. + servePlainDNS bool +} + // newTLSManager initializes the manager of TLS configuration. m is always // non-nil while any returned error indicates that the TLS configuration isn't -// valid. Thus TLS may be initialized later, e.g. via the web UI. logger must -// not be nil. -func newTLSManager( - ctx context.Context, - logger *slog.Logger, - conf tlsConfigSettings, - servePlainDNS bool, -) (m *tlsManager, err error) { +// valid. Thus TLS may be initialized later, e.g. via the web UI. conf must +// not be nil. Note that [tlsManager.web] must be initialized later on by using +// [tlsManager.setWebAPI]. +func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager, err error) { m = &tlsManager{ - logger: logger, - status: &tlsConfigStatus{}, - conf: conf, - servePlainDNS: servePlainDNS, + logger: conf.logger, + configModified: conf.configModified, + status: &tlsConfigStatus{}, + conf: conf.tlsSettings, + servePlainDNS: conf.servePlainDNS, + } + + m.rootCerts = aghtls.SystemRootCAs() + + if len(conf.tlsSettings.OverrideTLSCiphers) > 0 { + m.customCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers) + if err != nil { + // Should not happen because upstreams are already validated. See + // [validateTLSCipherIDs]. + panic(err) + } + + m.logger.InfoContext(ctx, "overriding ciphers", "ciphers", config.TLS.OverrideTLSCiphers) + } else { + m.logger.InfoContext(ctx, "using default ciphers") } if m.conf.Enabled { @@ -79,6 +126,15 @@ func newTLSManager( return m, nil } +// setWebAPI stores the provided web API. It must be called before +// [tlsManager.start], [tlsManager.reload], [tlsManager.handleTLSConfigure], or +// [tlsManager.validateTLSSettings]. +// +// TODO(s.chzhen): Remove it once cyclic dependency is resolved. +func (m *tlsManager) setWebAPI(webAPI *webAPI) { + m.web = webAPI +} + // load reloads the TLS configuration from files or data from the config file. func (m *tlsManager) load(ctx context.Context) (err error) { err = m.loadTLSConf(ctx, &m.conf, m.status) @@ -126,7 +182,7 @@ func (m *tlsManager) start(_ context.Context) { // The background context is used because the TLSConfigChanged wraps context // with timeout on its own and shuts down the server, which handles current // request. - globalContext.web.tlsConfigChanged(context.Background(), tlsConf) + m.web.tlsConfigChanged(context.Background(), tlsConf) } // reload updates the configuration and restarts the TLS manager. @@ -178,7 +234,7 @@ func (m *tlsManager) reload(ctx context.Context) { // The background context is used because the TLSConfigChanged wraps context // with timeout on its own and shuts down the server, which handles current // request. - globalContext.web.tlsConfigChanged(context.Background(), tlsConf) + m.web.tlsConfigChanged(context.Background(), tlsConf) } // reconfigureDNSServer updates the DNS server configuration using the stored @@ -191,6 +247,7 @@ func (m *tlsManager) reconfigureDNSServer() (err error) { &config.DNS, config.Clients.Sources, tlsConf, + m, httpRegister, globalContext.clients.storage, ) @@ -368,6 +425,8 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) { // handleTLSValidate is the handler for the POST /control/tls/validate HTTP API. func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + setts, err := unmarshalTLS(r) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) @@ -379,7 +438,9 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { setts.PrivateKey = m.conf.PrivateKey } - if err = validateTLSSettings(setts); err != nil { + if err = m.validateTLSSettings(setts); err != nil { + m.logger.InfoContext(ctx, "validating tls settings", slogutil.KeyError, err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return @@ -388,7 +449,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { // Skip the error check, since we are only interested in the value of // status.WarningValidation. status := &tlsConfigStatus{} - _ = m.loadTLSConf(r.Context(), &setts.tlsConfigSettings, status) + _ = m.loadTLSConf(ctx, &setts.tlsConfigSettings, status) resp := tlsConfig{ tlsConfigSettingsExt: setts, tlsConfigStatus: status, @@ -458,7 +519,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) req.PrivateKey = m.conf.PrivateKey } - if err = validateTLSSettings(req); err != nil { + if err = m.validateTLSSettings(req); err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return @@ -489,7 +550,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) }() } - onConfigModified() + m.configModified() err = m.reconfigureDNSServer() if err != nil { @@ -516,36 +577,54 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) // same reason. if restartHTTPS { go func() { - globalContext.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings) + m.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings) }() } } // validateTLSSettings returns error if the setts are not valid. -func validateTLSSettings(setts tlsConfigSettingsExt) (err error) { - if setts.Enabled { - err = validatePorts( - tcpPort(config.HTTPConfig.Address.Port()), - tcpPort(setts.PortHTTPS), - tcpPort(setts.PortDNSOverTLS), - tcpPort(setts.PortDNSCrypt), - udpPort(config.DNS.Port), - udpPort(setts.PortDNSOverQUIC), - ) - if err != nil { - // Don't wrap the error since it's informative enough as is. - return err +func (m *tlsManager) validateTLSSettings(setts tlsConfigSettingsExt) (err error) { + if !setts.Enabled { + if setts.ServePlainDNS == aghalg.NBFalse { + // TODO(a.garipov): Support full disabling of all DNS. + return errors.Error("plain DNS is required in case encryption protocols are disabled") } - } else if setts.ServePlainDNS == aghalg.NBFalse { - // TODO(a.garipov): Support full disabling of all DNS. - return errors.Error("plain DNS is required in case encryption protocols are disabled") + + return nil } - if !webCheckPortAvailable(setts.PortHTTPS) { - return fmt.Errorf("port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS) + var ( + tlsConf tlsConfigSettings + webAPIAddr netip.Addr + webAPIPort uint16 + plainDNSPort uint16 + ) + + func() { + config.Lock() + defer config.Unlock() + + tlsConf = config.TLS + webAPIAddr = config.HTTPConfig.Address.Addr() + webAPIPort = config.HTTPConfig.Address.Port() + plainDNSPort = config.DNS.Port + }() + + err = validatePorts( + tcpPort(webAPIPort), + tcpPort(setts.PortHTTPS), + tcpPort(setts.PortDNSOverTLS), + tcpPort(setts.PortDNSCrypt), + udpPort(plainDNSPort), + udpPort(setts.PortDNSOverQUIC), + ) + if err != nil { + // Don't wrap the error because it's informative enough as is. + return err } - return nil + // Don't wrap the error because it's informative enough as is. + return m.checkPortAvailability(tlsConf, setts.tlsConfigSettings, webAPIAddr) } // validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home @@ -557,10 +636,11 @@ func validatePorts( tcpPorts := aghalg.UniqChecker[tcpPort]{} addPorts( tcpPorts, - tcpPort(bindPort), - tcpPort(dohPort), - tcpPort(dotPort), - tcpPort(dnscryptTCPPort), + bindPort, + dohPort, + dotPort, + dnscryptTCPPort, + tcpPort(dnsPort), ) err = tcpPorts.Validate() @@ -569,7 +649,7 @@ func validatePorts( } udpPorts := aghalg.UniqChecker[udpPort]{} - addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort)) + addPorts(udpPorts, dnsPort, doqPort) err = udpPorts.Validate() if err != nil { @@ -604,7 +684,7 @@ func (m *tlsManager) validateCertChain( opts := x509.VerifyOptions{ DNSName: srvName, - Roots: globalContext.tlsRoots, + Roots: m.rootCerts, Intermediates: pool, } _, err = main.Verify(opts) @@ -615,6 +695,67 @@ func (m *tlsManager) validateCertChain( return nil } +// checkPortAvailability checks [tlsConfigSettings.PortHTTPS], +// [tlsConfigSettings.PortDNSOverTLS], and [tlsConfigSettings.PortDNSOverQUIC] +// are available for use. It checks the current configuration and, if needed, +// attempts to bind to the port. The function returns human-readable error +// messages for the frontend. This is best-effort check to prevent an "address +// already in use" error. +// +// TODO(a.garipov): Adapt for HTTP/3. +func (m *tlsManager) checkPortAvailability( + currConf tlsConfigSettings, + newConf tlsConfigSettings, + addr netip.Addr, +) (err error) { + const ( + networkTCP = "tcp" + networkUDP = "udp" + + protoHTTPS = "HTTPS" + protoDoT = "DNS-over-TLS" + protoDoQ = "DNS-over-QUIC" + ) + + needBindingCheck := []struct { + network string + proto string + currPort uint16 + newPort uint16 + }{{ + network: networkTCP, + proto: protoHTTPS, + currPort: currConf.PortHTTPS, + newPort: newConf.PortHTTPS, + }, { + network: networkTCP, + proto: protoDoT, + currPort: currConf.PortDNSOverTLS, + newPort: newConf.PortDNSOverTLS, + }, { + network: networkUDP, + proto: protoDoQ, + currPort: currConf.PortDNSOverQUIC, + newPort: newConf.PortDNSOverQUIC, + }} + + var errs []error + for _, v := range needBindingCheck { + port := v.newPort + if v.currPort == port { + continue + } + + addrPort := netip.AddrPortFrom(addr, port) + err = aghnet.CheckPort(v.network, addrPort) + if err != nil { + errs = append(errs, fmt.Errorf("port %d for %s is not available", port, v.proto)) + } + } + + return errors.Join(errs...) +} + // errNoIPInCert is the error that is returned from [tlsManager.parseCertChain] // if the leaf certificate doesn't contain IPs. const errNoIPInCert errors.Error = `certificates has no IP addresses; ` + @@ -718,27 +859,12 @@ func (m *tlsManager) validateCertificates( ) (err error) { // Check only the public certificate separately from the key. if len(certChain) > 0 { - var certs []*x509.Certificate - certs, status.ValidCert, err = m.parseCertChain(ctx, certChain) - if !status.ValidCert { + var ok bool + ok, err = m.validateCertificate(ctx, status, certChain, serverName) + if !ok { // Don't wrap the error, since it's informative enough as is. return err } - - mainCert := certs[0] - status.Subject = mainCert.Subject.String() - status.Issuer = mainCert.Issuer.String() - status.NotAfter = mainCert.NotAfter - status.NotBefore = mainCert.NotBefore - status.DNSNames = mainCert.DNSNames - - if chainErr := m.validateCertChain(ctx, certs, serverName); chainErr != nil { - // Let self-signed certs through and don't return this error to set - // its message into the status.WarningValidation afterwards. - err = chainErr - } else { - status.ValidChain = true - } } // Validate the private key by parsing it. @@ -766,6 +892,41 @@ func (m *tlsManager) validateCertificates( return err } +// validateCertificate processes certificate data. status must not be nil, as +// it is used to accumulate the validation results. Other parameters are +// optional. If ok is true, the returned error, if any, is not critical. +func (m *tlsManager) validateCertificate( + ctx context.Context, + status *tlsConfigStatus, + certChain []byte, + serverName string, +) (ok bool, err error) { + var certs []*x509.Certificate + certs, status.ValidCert, err = m.parseCertChain(ctx, certChain) + if !status.ValidCert { + // Don't wrap the error, since it's informative enough as is. + return false, err + } + + mainCert := certs[0] + status.Subject = mainCert.Subject.String() + status.Issuer = mainCert.Issuer.String() + status.NotAfter = mainCert.NotAfter + status.NotBefore = mainCert.NotBefore + status.DNSNames = mainCert.DNSNames + + err = m.validateCertChain(ctx, certs, serverName) + if err != nil { + // Let self-signed certs through and don't return this error to set + // its message into the status.WarningValidation afterwards. + return true, err + } + + status.ValidChain = true + + return true, nil +} + // Key types. const ( keyTypeECDSA = "ECDSA" @@ -828,17 +989,18 @@ func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) { } } - if data.PrivateKey != "" { - var key []byte - key, err = base64.StdEncoding.DecodeString(data.PrivateKey) - if err != nil { - return data, fmt.Errorf("failed to base64-decode private key: %w", err) - } + if data.PrivateKey == "" { + return data, nil + } - data.PrivateKey = string(key) - if data.PrivateKeyPath != "" { - return data, fmt.Errorf("private key data and file can't be set together") - } + key, err := base64.StdEncoding.DecodeString(data.PrivateKey) + if err != nil { + return data, fmt.Errorf("failed to base64-decode private key: %w", err) + } + + data.PrivateKey = string(key) + if data.PrivateKeyPath != "" { + return data, fmt.Errorf("private key data and file can't be set together") } return data, nil diff --git a/internal/home/tls_internal_test.go b/internal/home/tls_internal_test.go index e67393b4..e7e539d6 100644 --- a/internal/home/tls_internal_test.go +++ b/internal/home/tls_internal_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/require" ) +// TODO(s.chzhen): Consider moving to testdata. var testCertChainData = []byte(`-----BEGIN CERTIFICATE----- MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3 @@ -66,7 +67,11 @@ func TestValidateCertificates(t *testing.T) { ctx := testutil.ContextWithTimeout(t, testTimeout) logger := slogutil.NewDiscardLogger() - m, err := newTLSManager(ctx, logger, tlsConfigSettings{}, false) + m, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: logger, + configModified: func() {}, + servePlainDNS: false, + }) require.NoError(t, err) t.Run("bad_certificate", func(t *testing.T) { @@ -112,7 +117,6 @@ func TestValidateCertificates(t *testing.T) { // - [homeContext.clients.storage] // - [homeContext.dnsServer] // - [homeContext.mux] -// - [homeContext.web] // // TODO(s.chzhen): Remove this once the TLS manager no longer accesses global // variables. Make tests that use this helper concurrent. @@ -123,14 +127,12 @@ func storeGlobals(tb testing.TB) { storage := globalContext.clients.storage dnsServer := globalContext.dnsServer mux := globalContext.mux - web := globalContext.web tb.Cleanup(func() { config = prevConfig globalContext.clients.storage = storage globalContext.dnsServer = dnsServer globalContext.mux = mux - globalContext.web = web }) } @@ -221,9 +223,6 @@ func TestTLSManager_Reload(t *testing.T) { globalContext.mux = http.NewServeMux() - globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) - require.NoError(t, err) - const ( snBefore int64 = 1 snAfter int64 = 2 @@ -236,15 +235,25 @@ func TestTLSManager_Reload(t *testing.T) { certDER, key := newCertAndKey(t, snBefore) writeCertAndKey(t, certDER, certPath, key, keyPath) - m, err := newTLSManager(ctx, logger, tlsConfigSettings{ - Enabled: true, - TLSConfig: dnsforward.TLSConfig{ - CertificatePath: certPath, - PrivateKeyPath: keyPath, + m, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: logger, + configModified: func() {}, + tlsSettings: tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificatePath: certPath, + PrivateKeyPath: keyPath, + }, }, - }, false) + servePlainDNS: false, + }) require.NoError(t, err) + web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false) + require.NoError(t, err) + + m.setWebAPI(web) + conf := &tlsConfigSettings{} m.WriteDiskConfig(conf) assertCertSerialNumber(t, conf, snBefore) @@ -265,13 +274,18 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) { err error ) - m, err := newTLSManager(ctx, logger, tlsConfigSettings{ - Enabled: true, - TLSConfig: dnsforward.TLSConfig{ - CertificateChain: string(testCertChainData), - PrivateKey: string(testPrivateKeyData), + m, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: logger, + configModified: func() {}, + tlsSettings: tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificateChain: string(testCertChainData), + PrivateKey: string(testPrivateKeyData), + }, }, - }, false) + servePlainDNS: false, + }) require.NoError(t, err) w := httptest.NewRecorder() @@ -291,26 +305,42 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) { func TestValidateTLSSettings(t *testing.T) { storeGlobals(t) + globalContext.mux = http.NewServeMux() + var ( logger = slogutil.NewDiscardLogger() ctx = testutil.ContextWithTimeout(t, testTimeout) err error ) - ln, err := net.Listen("tcp", ":0") + m, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: logger, + configModified: func() {}, + servePlainDNS: false, + }) require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, ln.Close) - - addr := testutil.RequireTypeAssert[*net.TCPAddr](t, ln.Addr()) - - busyPort := addr.Port - - globalContext.mux = http.NewServeMux() - - globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) + web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false) require.NoError(t, err) + m.setWebAPI(web) + + tcpLn, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + testutil.CleanupAndRequireSuccess(t, tcpLn.Close) + + tcpAddr := testutil.RequireTypeAssert[*net.TCPAddr](t, tcpLn.Addr()) + busyTCPPort := tcpAddr.Port + + udpLn, err := net.ListenPacket("udp", ":0") + require.NoError(t, err) + + testutil.CleanupAndRequireSuccess(t, udpLn.Close) + + udpAddr := testutil.RequireTypeAssert[*net.UDPAddr](t, udpLn.LocalAddr()) + busyUDPPort := udpAddr.Port + testCases := []struct { setts tlsConfigSettingsExt name string @@ -329,11 +359,29 @@ func TestValidateTLSSettings(t *testing.T) { setts: tlsConfigSettingsExt{ tlsConfigSettings: tlsConfigSettings{ Enabled: true, - PortHTTPS: uint16(busyPort), + PortHTTPS: uint16(busyTCPPort), }, }, - name: "busy_port", - wantErr: fmt.Sprintf("port %d is not available, cannot enable HTTPS on it", busyPort), + name: "busy_https_port", + wantErr: fmt.Sprintf("port %d for HTTPS is not available", busyTCPPort), + }, { + setts: tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortDNSOverTLS: uint16(busyTCPPort), + }, + }, + name: "busy_dot_port", + wantErr: fmt.Sprintf("port %d for DNS-over-TLS is not available", busyTCPPort), + }, { + setts: tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortDNSOverQUIC: uint16(busyUDPPort), + }, + }, + name: "busy_doq_port", + wantErr: fmt.Sprintf("port %d for DNS-over-QUIC is not available", busyUDPPort), }, { setts: tlsConfigSettingsExt{ tlsConfigSettings: tlsConfigSettings{ @@ -348,7 +396,7 @@ func TestValidateTLSSettings(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err = validateTLSSettings(tc.setts) + err = m.validateTLSSettings(tc.setts) testutil.AssertErrorMsg(t, tc.wantErr, err) }) } @@ -357,26 +405,33 @@ func TestValidateTLSSettings(t *testing.T) { func TestTLSManager_HandleTLSValidate(t *testing.T) { storeGlobals(t) + globalContext.mux = http.NewServeMux() + var ( logger = slogutil.NewDiscardLogger() ctx = testutil.ContextWithTimeout(t, testTimeout) err error ) - globalContext.mux = http.NewServeMux() - - globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) - require.NoError(t, err) - - m, err := newTLSManager(ctx, logger, tlsConfigSettings{ - Enabled: true, - TLSConfig: dnsforward.TLSConfig{ - CertificateChain: string(testCertChainData), - PrivateKey: string(testPrivateKeyData), + m, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: logger, + configModified: func() {}, + tlsSettings: tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificateChain: string(testCertChainData), + PrivateKey: string(testPrivateKeyData), + }, }, - }, false) + servePlainDNS: false, + }) require.NoError(t, err) + web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false) + require.NoError(t, err) + + m.setWebAPI(web) + setts := &tlsConfigSettingsExt{ tlsConfigSettings: tlsConfigSettings{ Enabled: true, @@ -438,9 +493,6 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { globalContext.mux = http.NewServeMux() - globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) - require.NoError(t, err) - config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")} config.DNS.Port = 0 @@ -455,15 +507,25 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { writeCertAndKey(t, certDER, certPath, key, keyPath) // Initialize the TLS manager and assert its configuration. - m, err := newTLSManager(ctx, logger, tlsConfigSettings{ - Enabled: true, - TLSConfig: dnsforward.TLSConfig{ - CertificatePath: certPath, - PrivateKeyPath: keyPath, + m, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: logger, + configModified: func() {}, + tlsSettings: tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificatePath: certPath, + PrivateKeyPath: keyPath, + }, }, - }, true) + servePlainDNS: true, + }) require.NoError(t, err) + web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false) + require.NoError(t, err) + + m.setWebAPI(web) + conf := &tlsConfigSettings{} m.WriteDiskConfig(conf) assertCertSerialNumber(t, conf, wantSerialNumber) @@ -509,10 +571,10 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { // // TODO(s.chzhen): Remove when [httpsServer.cond] is removed. assert.Eventually(t, func() bool { - globalContext.web.httpsServer.condLock.Lock() - defer globalContext.web.httpsServer.condLock.Unlock() + web.httpsServer.condLock.Lock() + defer web.httpsServer.condLock.Unlock() - cert = globalContext.web.httpsServer.cert + cert = web.httpsServer.cert if cert.Leaf == nil { return false } diff --git a/internal/home/web.go b/internal/home/web.go index e9fe6dab..9be52850 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -12,10 +12,8 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil/httputil" @@ -158,27 +156,6 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) { return w } -// webCheckPortAvailable checks if port, which is considered an HTTPS port, is -// available, unless the HTTPS server isn't active. -// -// TODO(a.garipov): Adapt for HTTP/3. -func webCheckPortAvailable(port uint16) (ok bool) { - if globalContext.web.httpsServer.server != nil { - return true - } - - addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), port) - - err := aghnet.CheckPort("tcp", addrPort) - if err != nil { - log.Info("web: warning: checking https port: %s", err) - - return false - } - - return true -} - // tlsConfigChanged updates the TLS configuration and restarts the HTTPS server // if necessary. func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) { @@ -329,8 +306,8 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) { Handler: hdlr, TLSConfig: &tls.Config{ Certificates: []tls.Certificate{web.httpsServer.cert}, - RootCAs: globalContext.tlsRoots, - CipherSuites: globalContext.tlsCipherIDs, + RootCAs: web.tlsManager.rootCerts, + CipherSuites: web.tlsManager.customCipherIDs, MinVersion: tls.VersionTLS12, }, ReadTimeout: web.conf.ReadTimeout, @@ -363,8 +340,8 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) { Addr: address, TLSConfig: &tls.Config{ Certificates: []tls.Certificate{web.httpsServer.cert}, - RootCAs: globalContext.tlsRoots, - CipherSuites: globalContext.tlsCipherIDs, + RootCAs: web.tlsManager.rootCerts, + CipherSuites: web.tlsManager.customCipherIDs, MinVersion: tls.VersionTLS12, }, Handler: withMiddlewares(globalContext.mux, limitRequestBody),