diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 6357d681..0c3531e4 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -22,6 +22,7 @@ import ( "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/ameshkov/dnscrypt/v2" + "golang.org/x/exp/slices" ) // BlockingMode is an enum of all allowed blocking modes. @@ -145,7 +146,8 @@ type FilteringConfig struct { IpsetListFileName string `yaml:"ipset_file"` } -// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS +// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, DNS-over-TLS, +// and DNS-over-QUIC. type TLSConfig struct { cert tls.Certificate @@ -184,6 +186,11 @@ type TLSConfig struct { hasIPAddrs bool } +// CertDataClone returns a deep copy of certificate data. +func (c TLSConfig) CertDataClone() (certData, keyData []byte) { + return slices.Clone(c.CertificateChainData), slices.Clone(c.PrivateKeyData) +} + // DNSCryptConfig is the DNSCrypt server configuration struct. type DNSCryptConfig struct { ResolverCert *dnscrypt.Cert diff --git a/internal/home/config.go b/internal/home/config.go index c7198d93..7a03eef4 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -20,6 +20,7 @@ import ( "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/timeutil" "github.com/google/renameio/maybe" + "golang.org/x/exp/slices" yaml "gopkg.in/yaml.v3" ) @@ -113,8 +114,8 @@ type configuration struct { // An active session is automatically refreshed once a day. WebSessionTTLHours uint32 `yaml:"web_session_ttl"` - DNS dnsConfig `yaml:"dns"` - TLS tlsConfigSettings `yaml:"tls"` + DNS dnsConfig `yaml:"dns"` + TLS tlsConfiguration `yaml:"tls"` // Filters reflects the filters from [filtering.Config]. It's cloned to the // config used in the filtering module at the startup. Afterwards it's @@ -199,7 +200,8 @@ type dnsConfig struct { UseHTTP3Upstreams bool `yaml:"use_http3_upstreams"` } -type tlsConfigSettings struct { +// tlsConfiguration is the on-disk TLS configuration. +type tlsConfiguration struct { Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect @@ -223,6 +225,22 @@ type tlsConfigSettings struct { dnsforward.TLSConfig `yaml:",inline" json:",inline"` } +// partialClone returns a clone of c with all top-level fields of c and all +// exported and YAML-encoded fields of c.TLSConfig cloned. +// +// TODO(a.garipov): This is better than races, but still not good enough. +func (c *tlsConfiguration) partialClone() (cloned *tlsConfiguration) { + if c == nil { + return nil + } + + v := *c + cloned = &v + cloned.OverrideTLSCiphers = slices.Clone(c.OverrideTLSCiphers) + + return cloned +} + // config is the global configuration structure. // // TODO(a.garipov, e.burkov): This global is awful and must be removed. @@ -273,7 +291,7 @@ var config = &configuration{ UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout}, UsePrivateRDNS: true, }, - TLS: tlsConfigSettings{ + TLS: tlsConfiguration{ PortHTTPS: defaultPortHTTPS, PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy PortDNSOverQUIC: defaultPortQUIC, @@ -442,7 +460,7 @@ func (c *configuration) write() (err error) { } if Context.tls != nil { - tlsConf := tlsConfigSettings{} + tlsConf := tlsConfiguration{} Context.tls.WriteDiskConfig(&tlsConf) config.TLS = tlsConf } diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index ef4f0659..2232986d 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -154,7 +154,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) { return nil } - tlsConf := &tlsConfigSettings{} + tlsConf := &tlsConfiguration{} Context.tls.WriteDiskConfig(tlsConf) canUpdate := true @@ -172,7 +172,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) { // tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration // indicates that privileged ports are used. -func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) { +func tlsConfUsesPrivilegedPorts(c *tlsConfiguration) (ok bool) { return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024) } diff --git a/internal/home/dns.go b/internal/home/dns.go index 1980b252..27ae8319 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -205,7 +205,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { OnDNSRequest: onDNSRequest, } - tlsConf := tlsConfigSettings{} + tlsConf := tlsConfiguration{} Context.tls.WriteDiskConfig(&tlsConf) if tlsConf.Enabled { newConf.TLSConfig = tlsConf.TLSConfig @@ -250,7 +250,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { return newConf, nil } -func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) { +func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfiguration) (dnscc dnsforward.DNSCryptConfig, err error) { if tlsConf.DNSCryptConfigFile == "" { return dnscc, errors.Error("no dnscrypt_config_file") } @@ -288,7 +288,7 @@ type dnsEncryption struct { } func getDNSEncryption() (de dnsEncryption) { - tlsConf := tlsConfigSettings{} + tlsConf := tlsConfiguration{} Context.tls.WriteDiskConfig(&tlsConf) diff --git a/internal/home/home.go b/internal/home/home.go index 6ff698d3..1c35fa04 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -512,7 +512,7 @@ func run(opts options, clientBuildFS fs.FS) { } config.Users = nil - Context.tls, err = newTLSManager(config.TLS) + Context.tls, err = newTLSManager(&config.TLS) if err != nil { log.Fatalf("initializing tls: %s", err) } @@ -817,7 +817,7 @@ func printWebAddrs(proto, addr string, port, betaPort int) { // printHTTPAddresses prints the IP addresses which user can use to access the // admin interface. proto is either schemeHTTP or schemeHTTPS. func printHTTPAddresses(proto string) { - tlsConf := tlsConfigSettings{} + tlsConf := tlsConfiguration{} if Context.tls != nil { Context.tls.WriteDiskConfig(&tlsConf) } diff --git a/internal/home/mobileconfig_test.go b/internal/home/mobileconfig_test.go index 3587154f..40e4dd1d 100644 --- a/internal/home/mobileconfig_test.go +++ b/internal/home/mobileconfig_test.go @@ -32,7 +32,11 @@ func setupDNSIPs(t testing.TB) { }, } - Context.tls = &tlsManager{} + var err error + Context.tls, err = newTLSManager(&tlsConfiguration{ + Enabled: true, + }) + require.NoError(t, err) } func TestHandleMobileConfigDoH(t *testing.T) { @@ -65,7 +69,11 @@ func TestHandleMobileConfigDoH(t *testing.T) { oldTLSConf := Context.tls t.Cleanup(func() { Context.tls = oldTLSConf }) - Context.tls = &tlsManager{conf: tlsConfigSettings{}} + var err error + Context.tls, err = newTLSManager(&tlsConfiguration{ + Enabled: true, + }) + require.NoError(t, err) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil) require.NoError(t, err) @@ -137,7 +145,11 @@ func TestHandleMobileConfigDoT(t *testing.T) { oldTLSConf := Context.tls t.Cleanup(func() { Context.tls = oldTLSConf }) - Context.tls = &tlsManager{conf: tlsConfigSettings{}} + var err error + Context.tls, err = newTLSManager(&tlsConfiguration{ + Enabled: true, + }) + require.NoError(t, err) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil) require.NoError(t, err) diff --git a/internal/home/tls.go b/internal/home/tls.go index 7fdd64d8..1c2036cc 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -8,42 +8,39 @@ import ( "crypto/rsa" "crypto/tls" "crypto/x509" - "encoding/base64" - "encoding/json" "encoding/pem" "fmt" - "net/http" "os" "strings" "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghalg" - "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" - "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" - "github.com/google/go-cmp/cmp" ) // tlsManager contains the current configuration and state of AdGuard Home TLS // encryption. type tlsManager struct { - // status is the current status of the configuration. It is never nil. - status *tlsConfigStatus - // certLastMod is the last modification time of the certificate file. certLastMod time.Time - confLock sync.Mutex - conf tlsConfigSettings + // status is the current status of the configuration. It is never nil. + status *tlsConfigStatus + + // confMu protects conf. + confMu *sync.RWMutex + + // conf is the current TLS configuration. + conf *tlsConfiguration } // newTLSManager initializes the TLS configuration. -func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) { +func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) { m = &tlsManager{ status: &tlsConfigStatus{}, + confMu: &sync.RWMutex{}, conf: conf, } @@ -59,9 +56,19 @@ func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) { return m, nil } +// partialTLSConf returns a partial clone of the current TLS configuration. It +// is safe for concurrent use. +func (m *tlsManager) partialTLSConf() (conf *tlsConfiguration) { + m.confMu.RLock() + defer m.confMu.RUnlock() + + return m.conf.partialClone() +} + // load reloads the TLS configuration from files or data from the config file. +// load assumes that m.confLock is locked for writing. func (m *tlsManager) load() (err error) { - err = loadTLSConf(&m.conf, m.status) + err = loadTLSConf(m.conf, m.status) if err != nil { return fmt.Errorf("loading config: %w", err) } @@ -70,10 +77,8 @@ func (m *tlsManager) load() (err error) { } // WriteDiskConfig - write config -func (m *tlsManager) WriteDiskConfig(conf *tlsConfigSettings) { - m.confLock.Lock() - *conf = m.conf - m.confLock.Unlock() +func (m *tlsManager) WriteDiskConfig(conf *tlsConfiguration) { + *conf = *m.partialTLSConf() } // setCertFileTime sets t.certLastMod from the certificate. If there are @@ -97,27 +102,22 @@ func (m *tlsManager) setCertFileTime() { func (m *tlsManager) start() { m.registerWebHandlers() - m.confLock.Lock() - tlsConf := m.conf - m.confLock.Unlock() - // The background context is used because the TLSConfigChanged wraps context // with timeout on its own and shuts down the server, which handles current // request. - Context.web.TLSConfigChanged(context.Background(), tlsConf) + Context.web.TLSConfigChanged(context.Background(), m.partialTLSConf()) } -// reload updates the configuration and restarts t. +// reload updates the configuration and restarts m. func (m *tlsManager) reload() { - m.confLock.Lock() - tlsConf := m.conf - m.confLock.Unlock() + m.confMu.Lock() + defer m.confMu.Unlock() - if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 { + if !m.conf.Enabled || len(m.conf.CertificatePath) == 0 { return } - fi, err := os.Stat(tlsConf.CertificatePath) + fi, err := os.Stat(m.conf.CertificatePath) if err != nil { log.Error("tls: %s", err) @@ -132,9 +132,7 @@ func (m *tlsManager) reload() { log.Debug("tls: certificate file is modified") - m.confLock.Lock() err = m.load() - m.confLock.Unlock() if err != nil { log.Error("tls: reloading: %s", err) @@ -145,19 +143,15 @@ func (m *tlsManager) reload() { _ = reconfigureDNSServer() - m.confLock.Lock() - tlsConf = m.conf - m.confLock.Unlock() - // The background context is used because the TLSConfigChanged wraps context // with timeout on its own and shuts down the server, which handles current // request. - Context.web.TLSConfigChanged(context.Background(), tlsConf) + Context.web.TLSConfigChanged(context.Background(), m.conf) } // loadTLSConf loads and validates the TLS configuration. The returned error is // also set in status.WarningValidation. -func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) { +func loadTLSConf(tlsConf *tlsConfiguration, status *tlsConfigStatus) (err error) { defer func() { if err != nil { status.WarningValidation = err.Error() @@ -172,13 +166,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey) if tlsConf.CertificatePath != "" { - if tlsConf.CertificateChain != "" { - return errors.Error("certificate data and file can't be set together") - } - - tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath) + err = loadCert(tlsConf) if err != nil { - return fmt.Errorf("reading cert file: %w", err) + // Don't wrap the error, since it's informative enough as is. + return err } // Set status.ValidCert to true to signal the frontend that the @@ -187,13 +178,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error } if tlsConf.PrivateKeyPath != "" { - if tlsConf.PrivateKey != "" { - return errors.Error("private key data and file can't be set together") - } - - tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath) + err = loadPKey(tlsConf) if err != nil { - return fmt.Errorf("reading key file: %w", err) + // Don't wrap the error, since it's informative enough as is. + return err } status.ValidKey = true @@ -212,278 +200,29 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error return nil } -// tlsConfigStatus contains the status of a certificate chain and key pair. -type tlsConfigStatus struct { - // Subject is the subject of the first certificate in the chain. - Subject string `json:"subject,omitempty"` - - // Issuer is the issuer of the first certificate in the chain. - Issuer string `json:"issuer,omitempty"` - - // KeyType is the type of the private key. - KeyType string `json:"key_type,omitempty"` - - // NotBefore is the NotBefore field of the first certificate in the chain. - NotBefore time.Time `json:"not_before,omitempty"` - - // NotAfter is the NotAfter field of the first certificate in the chain. - NotAfter time.Time `json:"not_after,omitempty"` - - // WarningValidation is a validation warning message with the issue - // description. - WarningValidation string `json:"warning_validation,omitempty"` - - // DNSNames is the value of SubjectAltNames field of the first certificate - // in the chain. - DNSNames []string `json:"dns_names"` - - // ValidCert is true if the specified certificate chain is a valid chain of - // X509 certificates. - ValidCert bool `json:"valid_cert"` - - // ValidChain is true if the specified certificate chain is verified and - // issued by a known CA. - ValidChain bool `json:"valid_chain"` - - // ValidKey is true if the key is a valid private key. - ValidKey bool `json:"valid_key"` - - // ValidPair is true if both certificate and private key are correct for - // each other. - ValidPair bool `json:"valid_pair"` -} - -// tlsConfig is the TLS configuration and status response. -type tlsConfig struct { - *tlsConfigStatus `json:",inline"` - tlsConfigSettingsExt `json:",inline"` -} - -// tlsConfigSettingsExt is used to (un)marshal the PrivateKeySaved field to -// ensure that clients don't send and receive previously saved private keys. -type tlsConfigSettingsExt struct { - tlsConfigSettings `json:",inline"` - - // PrivateKeySaved is true if the private key is saved as a string and omit - // key from answer. - PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"` -} - -func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) { - m.confLock.Lock() - data := tlsConfig{ - tlsConfigSettingsExt: tlsConfigSettingsExt{ - tlsConfigSettings: m.conf, - }, - tlsConfigStatus: m.status, +// loadCert loads the certificate from file, if necessary. +func loadCert(tlsConf *tlsConfiguration) (err error) { + if tlsConf.CertificateChain != "" { + return errors.Error("certificate data and file can't be set together") } - m.confLock.Unlock() - marshalTLS(w, r, data) -} - -func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { - setts, err := unmarshalTLS(r) + tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath) if err != nil { - aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) - - return + return fmt.Errorf("reading cert file: %w", err) } - if setts.PrivateKeySaved { - setts.PrivateKey = m.conf.PrivateKey - } - - if setts.Enabled { - err = validatePorts( - tcpPort(config.BindPort), - tcpPort(config.BetaBindPort), - tcpPort(setts.PortHTTPS), - tcpPort(setts.PortDNSOverTLS), - tcpPort(setts.PortDNSCrypt), - udpPort(config.DNS.Port), - udpPort(setts.PortDNSOverQUIC), - ) - if err != nil { - aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) - - return - } - } - - if !webCheckPortAvailable(setts.PortHTTPS) { - aghhttp.Error( - r, - w, - http.StatusBadRequest, - "port %d is not available, cannot enable HTTPS on it", - setts.PortHTTPS, - ) - - return - } - - // Skip the error check, since we are only interested in the value of - // status.WarningValidation. - status := &tlsConfigStatus{} - _ = loadTLSConf(&setts.tlsConfigSettings, status) - resp := tlsConfig{ - tlsConfigSettingsExt: setts, - tlsConfigStatus: status, - } - - marshalTLS(w, r, resp) + return nil } -func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatus) (restartHTTPS bool) { - m.confLock.Lock() - defer m.confLock.Unlock() - - // Reset the DNSCrypt data before comparing, since we currently do not - // accept these from the frontend. - // - // TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig. - newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile - newConf.PortDNSCrypt = m.conf.PortDNSCrypt - if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) { - log.Info("tls config has changed, restarting https server") - restartHTTPS = true - } else { - log.Info("tls: config has not changed") +// loadPKey loads the private key from file, if necessary. +func loadPKey(tlsConf *tlsConfiguration) (err error) { + if tlsConf.PrivateKey != "" { + return errors.Error("private key data and file cannot be set together") } - // Note: don't do just `t.conf = data` because we must preserve all other members of t.conf - m.conf.Enabled = newConf.Enabled - m.conf.ServerName = newConf.ServerName - m.conf.ForceHTTPS = newConf.ForceHTTPS - m.conf.PortHTTPS = newConf.PortHTTPS - m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS - m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC - m.conf.CertificateChain = newConf.CertificateChain - m.conf.CertificatePath = newConf.CertificatePath - m.conf.CertificateChainData = newConf.CertificateChainData - m.conf.PrivateKey = newConf.PrivateKey - m.conf.PrivateKeyPath = newConf.PrivateKeyPath - m.conf.PrivateKeyData = newConf.PrivateKeyData - m.status = status - - return restartHTTPS -} - -func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { - req, err := unmarshalTLS(r) + tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath) if err != nil { - aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) - - return - } - - if req.PrivateKeySaved { - req.PrivateKey = m.conf.PrivateKey - } - - if req.Enabled { - err = validatePorts( - tcpPort(config.BindPort), - tcpPort(config.BetaBindPort), - tcpPort(req.PortHTTPS), - tcpPort(req.PortDNSOverTLS), - tcpPort(req.PortDNSCrypt), - udpPort(config.DNS.Port), - udpPort(req.PortDNSOverQUIC), - ) - if err != nil { - aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) - - return - } - } - - // TODO(e.burkov): Investigate and perhaps check other ports. - if !webCheckPortAvailable(req.PortHTTPS) { - aghhttp.Error( - r, - w, - http.StatusBadRequest, - "port %d is not available, cannot enable https on it", - req.PortHTTPS, - ) - - return - } - - status := &tlsConfigStatus{} - err = loadTLSConf(&req.tlsConfigSettings, status) - if err != nil { - resp := tlsConfig{ - tlsConfigSettingsExt: req, - tlsConfigStatus: status, - } - - marshalTLS(w, r, resp) - - return - } - - restartHTTPS := m.setConfig(req.tlsConfigSettings, status) - m.setCertFileTime() - onConfigModified() - - err = reconfigureDNSServer() - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) - - return - } - - resp := tlsConfig{ - tlsConfigSettingsExt: req, - tlsConfigStatus: m.status, - } - - marshalTLS(w, r, resp) - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - - // The background context is used because the TLSConfigChanged wraps context - // with timeout on its own and shuts down the server, which handles current - // request. It is also should be done in a separate goroutine due to the - // same reason. - if restartHTTPS { - go func() { - Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings) - }() - } -} - -// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home -// DNS protocols. -func validatePorts( - bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort, - dnsPort, doqPort udpPort, -) (err error) { - tcpPorts := aghalg.UniqChecker[tcpPort]{} - addPorts( - tcpPorts, - tcpPort(bindPort), - tcpPort(betaBindPort), - tcpPort(dohPort), - tcpPort(dotPort), - tcpPort(dnscryptTCPPort), - ) - - err = tcpPorts.Validate() - if err != nil { - return fmt.Errorf("validating tcp ports: %w", err) - } - - udpPorts := aghalg.UniqChecker[udpPort]{} - addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort)) - - err = udpPorts.Validate() - if err != nil { - return fmt.Errorf("validating udp ports: %w", err) + return fmt.Errorf("reading key file: %w", err) } return nil @@ -696,61 +435,3 @@ func parsePrivateKey(der []byte) (key crypto.PrivateKey, typ string, err error) return nil, "", errors.Error("tls: failed to parse private key") } - -// unmarshalTLS handles base64-encoded certificates transparently -func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) { - data := tlsConfigSettingsExt{} - err := json.NewDecoder(r.Body).Decode(&data) - if err != nil { - return data, fmt.Errorf("failed to parse new TLS config json: %w", err) - } - - if data.CertificateChain != "" { - var cert []byte - cert, err = base64.StdEncoding.DecodeString(data.CertificateChain) - if err != nil { - return data, fmt.Errorf("failed to base64-decode certificate chain: %w", err) - } - - data.CertificateChain = string(cert) - if data.CertificatePath != "" { - return data, fmt.Errorf("certificate data and file can't be set together") - } - } - - if data.PrivateKey != "" { - var key []byte - key, err = base64.StdEncoding.DecodeString(data.PrivateKey) - if err != nil { - return data, fmt.Errorf("failed to base64-decode private key: %w", err) - } - - data.PrivateKey = string(key) - if data.PrivateKeyPath != "" { - return data, fmt.Errorf("private key data and file can't be set together") - } - } - - return data, nil -} - -func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) { - if data.CertificateChain != "" { - encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain)) - data.CertificateChain = encoded - } - - if data.PrivateKey != "" { - data.PrivateKeySaved = true - data.PrivateKey = "" - } - - _ = aghhttp.WriteJSONResponse(w, r, data) -} - -// registerWebHandlers registers HTTP handlers for TLS configuration. -func (m *tlsManager) registerWebHandlers() { - httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus) - httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure) - httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate) -} diff --git a/internal/home/tlshttp.go b/internal/home/tlshttp.go new file mode 100644 index 00000000..de4c84ef --- /dev/null +++ b/internal/home/tlshttp.go @@ -0,0 +1,350 @@ +package home + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" + "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" + "github.com/AdguardTeam/golibs/log" + "github.com/google/go-cmp/cmp" +) + +// Encryption Settings HTTP API + +// tlsConfigStatus contains the status of a certificate chain and key pair. +type tlsConfigStatus struct { + // Subject is the subject of the first certificate in the chain. + Subject string `json:"subject,omitempty"` + + // Issuer is the issuer of the first certificate in the chain. + Issuer string `json:"issuer,omitempty"` + + // KeyType is the type of the private key. + KeyType string `json:"key_type,omitempty"` + + // NotBefore is the NotBefore field of the first certificate in the chain. + NotBefore time.Time `json:"not_before,omitempty"` + + // NotAfter is the NotAfter field of the first certificate in the chain. + NotAfter time.Time `json:"not_after,omitempty"` + + // WarningValidation is a validation warning message with the issue + // description. + WarningValidation string `json:"warning_validation,omitempty"` + + // DNSNames is the value of SubjectAltNames field of the first certificate + // in the chain. + DNSNames []string `json:"dns_names"` + + // ValidCert is true if the specified certificate chain is a valid chain of + // X509 certificates. + ValidCert bool `json:"valid_cert"` + + // ValidChain is true if the specified certificate chain is verified and + // issued by a known CA. + ValidChain bool `json:"valid_chain"` + + // ValidKey is true if the key is a valid private key. + ValidKey bool `json:"valid_key"` + + // ValidPair is true if both certificate and private key are correct for + // each other. + ValidPair bool `json:"valid_pair"` +} + +// tlsConfigResp is the TLS configuration and status response. +type tlsConfigResp struct { + *tlsConfigStatus + *tlsConfiguration + + // PrivateKeySaved is true if the private key is saved as a string and omit + // key from answer. + PrivateKeySaved bool `yaml:"-" json:"private_key_saved"` +} + +// tlsConfigReq is the TLS configuration request. +type tlsConfigReq struct { + tlsConfiguration + + // PrivateKeySaved is true if the private key is saved as a string and omit + // key from answer. + PrivateKeySaved bool `yaml:"-" json:"private_key_saved"` +} + +func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) { + resp := &tlsConfigResp{ + tlsConfigStatus: m.status, + tlsConfiguration: m.partialTLSConf(), + } + + marshalTLS(w, r, resp) +} + +func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { + req, err := unmarshalTLS(r) + if err != nil { + aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) + + return + } + + if req.PrivateKeySaved { + req.PrivateKey = m.conf.PrivateKey + } + + if req.Enabled { + err = validatePorts( + tcpPort(config.BindPort), + tcpPort(config.BetaBindPort), + tcpPort(req.PortHTTPS), + tcpPort(req.PortDNSOverTLS), + tcpPort(req.PortDNSCrypt), + udpPort(config.DNS.Port), + udpPort(req.PortDNSOverQUIC), + ) + if err != nil { + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + + return + } + } + + if !webCheckPortAvailable(req.PortHTTPS) { + aghhttp.Error( + r, + w, + http.StatusBadRequest, + "port %d is not available, cannot enable HTTPS on it", + req.PortHTTPS, + ) + + return + } + + // Skip the error check, since we are only interested in the value of + // status.WarningValidation. + resp := &tlsConfigResp{ + tlsConfigStatus: &tlsConfigStatus{}, + tlsConfiguration: &req.tlsConfiguration, + } + _ = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus) + + marshalTLS(w, r, resp) +} + +// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home +// DNS protocols. +func validatePorts( + bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort, + dnsPort, doqPort udpPort, +) (err error) { + tcpPorts := aghalg.UniqChecker[tcpPort]{} + addPorts( + tcpPorts, + tcpPort(bindPort), + tcpPort(betaBindPort), + tcpPort(dohPort), + tcpPort(dotPort), + tcpPort(dnscryptTCPPort), + ) + + err = tcpPorts.Validate() + if err != nil { + return fmt.Errorf("validating tcp ports: %w", err) + } + + udpPorts := aghalg.UniqChecker[udpPort]{} + addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort)) + + err = udpPorts.Validate() + if err != nil { + return fmt.Errorf("validating udp ports: %w", err) + } + + return nil +} + +func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { + req, err := unmarshalTLS(r) + if err != nil { + aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) + + return + } + + if req.PrivateKeySaved { + req.PrivateKey = m.partialTLSConf().PrivateKey + } + + if req.Enabled { + err = validatePorts( + tcpPort(config.BindPort), + tcpPort(config.BetaBindPort), + tcpPort(req.PortHTTPS), + tcpPort(req.PortDNSOverTLS), + tcpPort(req.PortDNSCrypt), + udpPort(config.DNS.Port), + udpPort(req.PortDNSOverQUIC), + ) + if err != nil { + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + + return + } + } + + // TODO(e.burkov): Investigate and perhaps check other ports. + if !webCheckPortAvailable(req.PortHTTPS) { + aghhttp.Error( + r, + w, + http.StatusBadRequest, + "port %d is not available, cannot enable https on it", + req.PortHTTPS, + ) + + return + } + + resp := &tlsConfigResp{ + tlsConfigStatus: &tlsConfigStatus{}, + tlsConfiguration: &req.tlsConfiguration, + } + err = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus) + if err != nil { + marshalTLS(w, r, resp) + + return + } + + restartRequired := m.setConf(resp) + m.setCertFileTime() + onConfigModified() + + err = reconfigureDNSServer() + if err != nil { + aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) + + return + } + + resp.tlsConfiguration = m.partialTLSConf() + marshalTLS(w, r, resp) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + // The background context is used because the TLSConfigChanged wraps context + // with timeout on its own and shuts down the server, which handles current + // request. It is also should be done in a separate goroutine due to the + // same reason. + if restartRequired { + go func() { + Context.web.TLSConfigChanged(context.Background(), resp.tlsConfiguration) + }() + } +} + +// setConf sets the necessary values from the new configuration. +func (m *tlsManager) setConf(newConf *tlsConfigResp) (restartRequired bool) { + m.confMu.Lock() + defer m.confMu.Unlock() + + // Reset the DNSCrypt data before comparing, since we currently do not + // accept these from the frontend. + // + // TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig. + newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile + newConf.PortDNSCrypt = m.conf.PortDNSCrypt + if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) { + log.Info("tls: config has changed, restarting https server") + restartRequired = true + } else { + log.Info("tls: config has not changed") + } + + // Do not just write "m.conf = *newConf.tlsConfiguration", because all other + // members of m.conf must be preserved. + m.conf.Enabled = newConf.Enabled + m.conf.ServerName = newConf.ServerName + m.conf.ForceHTTPS = newConf.ForceHTTPS + m.conf.PortHTTPS = newConf.PortHTTPS + m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS + m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC + + m.conf.CertificateChain = newConf.CertificateChain + m.conf.CertificatePath = newConf.CertificatePath + m.conf.CertificateChainData = newConf.CertificateChainData + m.conf.PrivateKey = newConf.PrivateKey + m.conf.PrivateKeyPath = newConf.PrivateKeyPath + m.conf.PrivateKeyData = newConf.PrivateKeyData + + m.status = newConf.tlsConfigStatus + + return restartRequired +} + +// marshalTLS handles Base64-encoded certificates transparently. +func marshalTLS(w http.ResponseWriter, r *http.Request, conf *tlsConfigResp) { + if conf.CertificateChain != "" { + encoded := base64.StdEncoding.EncodeToString([]byte(conf.CertificateChain)) + conf.CertificateChain = encoded + } + + if conf.PrivateKey != "" { + conf.PrivateKeySaved = true + conf.PrivateKey = "" + } + + _ = aghhttp.WriteJSONResponse(w, r, conf) +} + +// unmarshalTLS handles Base64-encoded certificates transparently. +func unmarshalTLS(r *http.Request) (req *tlsConfigReq, err error) { + req = &tlsConfigReq{} + err = json.NewDecoder(r.Body).Decode(req) + if err != nil { + return nil, fmt.Errorf("parsing tls config: %w", err) + } + + if req.CertificateChain != "" { + var cert []byte + cert, err = base64.StdEncoding.DecodeString(req.CertificateChain) + if err != nil { + return nil, fmt.Errorf("failed to base64-decode certificate chain: %w", err) + } + + req.CertificateChain = string(cert) + if req.CertificatePath != "" { + return nil, fmt.Errorf("certificate data and file can't be set together") + } + } + + if req.PrivateKey != "" { + var key []byte + key, err = base64.StdEncoding.DecodeString(req.PrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to base64-decode private key: %w", err) + } + + req.PrivateKey = string(key) + if req.PrivateKeyPath != "" { + return nil, fmt.Errorf("private key data and file can't be set together") + } + } + + return req, nil +} + +// registerWebHandlers registers HTTP handlers for TLS configuration. +func (m *tlsManager) registerWebHandlers() { + httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus) + httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure) + httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate) +} diff --git a/internal/home/web.go b/internal/home/web.go index 7836355f..a393b5aa 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -143,7 +143,7 @@ func webCheckPortAvailable(port int) (ok bool) { // TLSConfigChanged updates the TLS configuration and restarts the HTTPS server // if necessary. -func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) { +func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf *tlsConfiguration) { log.Debug("web: applying new tls configuration") web.conf.PortHTTPS = tlsConf.PortHTTPS web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)