diff --git a/internal/home/home.go b/internal/home/home.go index 7777e6dd..f774a06a 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -664,7 +664,8 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH globalContext.auth, err = initUsers() fatalOnError(err) - tlsMgr, err := newTLSManager(config.TLS, config.DNS.ServePlainDNS) + tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager") + tlsMgr, err := newTLSManager(ctx, tlsMgrLogger, config.TLS, config.DNS.ServePlainDNS) if err != nil { log.Error("initializing tls: %s", err) onConfigModified() diff --git a/internal/home/signal.go b/internal/home/signal.go index 824e62dd..638d3632 100644 --- a/internal/home/signal.go +++ b/internal/home/signal.go @@ -116,6 +116,6 @@ func (h *signalHandler) reloadConfig(ctx context.Context) { } if h.tlsManager != nil { - h.tlsManager.reload() + h.tlsManager.reload(ctx) } } diff --git a/internal/home/tls.go b/internal/home/tls.go index 2882e8ef..778c674b 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -12,6 +12,7 @@ import ( "encoding/json" "encoding/pem" "fmt" + "log/slog" "net/http" "os" "strings" @@ -23,13 +24,17 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/c2h5oh/datasize" "github.com/google/go-cmp/cmp" ) // tlsManager contains the current configuration and state of AdGuard Home TLS // encryption. type tlsManager struct { + // logger is used for logging the operation of the TLS Manager. + logger *slog.Logger + // status is the current status of the configuration. It is never nil. status *tlsConfigStatus @@ -45,31 +50,38 @@ type tlsManager struct { // newTLSManager initializes the manager of TLS configuration. m is always // non-nil while any returned error indicates that the TLS configuration isn't -// valid. Thus TLS may be initialized later, e.g. via the web UI. -func newTLSManager(conf tlsConfigSettings, servePlainDNS bool) (m *tlsManager, err error) { +// valid. Thus TLS may be initialized later, e.g. via the web UI. logger must +// not be nil. +func newTLSManager( + ctx context.Context, + logger *slog.Logger, + conf tlsConfigSettings, + servePlainDNS bool, +) (m *tlsManager, err error) { m = &tlsManager{ + logger: logger, status: &tlsConfigStatus{}, conf: conf, servePlainDNS: servePlainDNS, } if m.conf.Enabled { - err = m.load() + err = m.load(ctx) if err != nil { m.conf.Enabled = false return m, err } - m.setCertFileTime() + m.setCertFileTime(ctx) } return m, nil } // load reloads the TLS configuration from files or data from the config file. -func (m *tlsManager) load() (err error) { - err = loadTLSConf(&m.conf, m.status) +func (m *tlsManager) load(ctx context.Context) (err error) { + err = m.loadTLSConf(ctx, &m.conf, m.status) if err != nil { return fmt.Errorf("loading config: %w", err) } @@ -84,16 +96,16 @@ func (m *tlsManager) WriteDiskConfig(conf *tlsConfigSettings) { m.confLock.Unlock() } -// setCertFileTime sets t.certLastMod from the certificate. If there are -// errors, setCertFileTime logs them. -func (m *tlsManager) setCertFileTime() { +// setCertFileTime sets [tlsManager.certLastMod] from the certificate. If there +// are errors, setCertFileTime logs them. +func (m *tlsManager) setCertFileTime(ctx context.Context) { if len(m.conf.CertificatePath) == 0 { return } fi, err := os.Stat(m.conf.CertificatePath) if err != nil { - log.Error("tls: looking up certificate path: %s", err) + m.logger.ErrorContext(ctx, "looking up certificate path", slogutil.KeyError, err) return } @@ -117,8 +129,8 @@ func (m *tlsManager) start(_ context.Context) { globalContext.web.tlsConfigChanged(context.Background(), tlsConf) } -// reload updates the configuration and restarts t. -func (m *tlsManager) reload() { +// reload updates the configuration and restarts the TLS manager. +func (m *tlsManager) reload(ctx context.Context) { m.confLock.Lock() tlsConf := m.conf m.confLock.Unlock() @@ -127,33 +139,37 @@ func (m *tlsManager) reload() { return } - fi, err := os.Stat(tlsConf.CertificatePath) + certPath := tlsConf.CertificatePath + fi, err := os.Stat(certPath) if err != nil { - log.Error("tls: %s", err) + m.logger.ErrorContext(ctx, "checking certificate file", slogutil.KeyError, err) return } if fi.ModTime().UTC().Equal(m.certLastMod) { - log.Debug("tls: certificate file isn't modified") + m.logger.InfoContext(ctx, "certificate file is not modified") return } - log.Debug("tls: certificate file is modified") + m.logger.InfoContext(ctx, "certificate file is modified") m.confLock.Lock() - err = m.load() + err = m.load(ctx) m.confLock.Unlock() if err != nil { - log.Error("tls: reloading: %s", err) + m.logger.ErrorContext(ctx, "reloading", slogutil.KeyError, err) return } m.certLastMod = fi.ModTime().UTC() - _ = m.reconfigureDNSServer() + err = m.reconfigureDNSServer() + if err != nil { + m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err) + } m.confLock.Lock() tlsConf = m.conf @@ -192,7 +208,11 @@ func (m *tlsManager) reconfigureDNSServer() (err error) { // loadTLSConf loads and validates the TLS configuration. The returned error is // also set in status.WarningValidation. -func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) { +func (m *tlsManager) loadTLSConf( + ctx context.Context, + tlsConf *tlsConfigSettings, + status *tlsConfigStatus, +) (err error) { defer func() { if err != nil { status.WarningValidation = err.Error() @@ -215,7 +235,8 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error return err } - err = validateCertificates( + err = m.validateCertificates( + ctx, status, tlsConf.CertificateChainData, tlsConf.PrivateKeyData, @@ -367,7 +388,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { // Skip the error check, since we are only interested in the value of // status.WarningValidation. status := &tlsConfigStatus{} - _ = loadTLSConf(&setts.tlsConfigSettings, status) + _ = m.loadTLSConf(r.Context(), &setts.tlsConfigSettings, status) resp := tlsConfig{ tlsConfigSettingsExt: setts, tlsConfigStatus: status, @@ -378,6 +399,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { // setConfig updates manager conf with the given one. func (m *tlsManager) setConfig( + ctx context.Context, newConf tlsConfigSettings, status *tlsConfigStatus, servePlain aghalg.NullBool, @@ -392,10 +414,10 @@ func (m *tlsManager) setConfig( newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile newConf.PortDNSCrypt = m.conf.PortDNSCrypt if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) { - log.Info("tls config has changed, restarting https server") + m.logger.InfoContext(ctx, "config has changed, restarting https server") restartHTTPS = true } else { - log.Info("tls: config has not changed") + m.logger.InfoContext(ctx, "config has not changed") } // Note: don't do just `t.conf = data` because we must preserve all other members of t.conf @@ -423,6 +445,8 @@ func (m *tlsManager) setConfig( // handleTLSConfigure is the handler for the POST /control/tls/configure HTTP // API. func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + req, err := unmarshalTLS(r) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) @@ -441,7 +465,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) } status := &tlsConfigStatus{} - err = loadTLSConf(&req.tlsConfigSettings, status) + err = m.loadTLSConf(ctx, &req.tlsConfigSettings, status) if err != nil { resp := tlsConfig{ tlsConfigSettingsExt: req, @@ -453,8 +477,8 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) return } - restartHTTPS := m.setConfig(req.tlsConfigSettings, status, req.ServePlainDNS) - m.setCertFileTime() + restartHTTPS := m.setConfig(ctx, req.tlsConfigSettings, status, req.ServePlainDNS) + m.setCertFileTime(ctx) if req.ServePlainDNS != aghalg.NBNull { func() { @@ -469,6 +493,8 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) err = m.reconfigureDNSServer() if err != nil { + m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err) + aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) return @@ -555,15 +581,27 @@ func validatePorts( // validateCertChain verifies certs using the first as the main one and others // as intermediate. srvName stands for the expected DNS name. -func validateCertChain(certs []*x509.Certificate, srvName string) (err error) { +func (m *tlsManager) validateCertChain( + ctx context.Context, + certs []*x509.Certificate, + srvName string, +) (err error) { main, others := certs[0], certs[1:] pool := x509.NewCertPool() for _, cert := range others { - log.Info("tls: got an intermediate cert") pool.AddCert(cert) } + othersLen := len(others) + if othersLen > 0 { + m.logger.InfoContext( + ctx, + "verifying certificate chain: got an intermediate cert", + "num", othersLen, + ) + } + opts := x509.VerifyOptions{ DNSName: srvName, Roots: globalContext.tlsRoots, @@ -577,15 +615,18 @@ func validateCertChain(certs []*x509.Certificate, srvName string) (err error) { return nil } -// errNoIPInCert is the error that is returned from [parseCertChain] if the leaf -// certificate doesn't contain IPs. +// errNoIPInCert is the error that is returned from [tlsManager.parseCertChain] +// if the leaf certificate doesn't contain IPs. const errNoIPInCert errors.Error = `certificates has no IP addresses; ` + `DNS-over-TLS won't be advertised via DDR` // parseCertChain parses the certificate chain from raw data, and returns it. // If ok is true, the returned error, if any, is not critical. -func parseCertChain(chain []byte) (parsedCerts []*x509.Certificate, ok bool, err error) { - log.Debug("tls: got certificate chain: %d bytes", len(chain)) +func (m *tlsManager) parseCertChain( + ctx context.Context, + chain []byte, +) (parsedCerts []*x509.Certificate, ok bool, err error) { + m.logger.DebugContext(ctx, "parsing certificate chain", "size", datasize.ByteSize(len(chain))) var certs []*pem.Block for decoded, pemblock := pem.Decode(chain); decoded != nil; { @@ -601,7 +642,7 @@ func parseCertChain(chain []byte) (parsedCerts []*x509.Certificate, ok bool, err return nil, false, err } - log.Info("tls: number of certs: %d", len(parsedCerts)) + m.logger.InfoContext(ctx, "parsing multiple pem certificates", "num", len(parsedCerts)) if !aghtls.CertificateHasIP(parsedCerts[0]) { err = errNoIPInCert @@ -668,7 +709,8 @@ func validatePKey(pkey []byte) (keyType string, err error) { // validateCertificates processes certificate data and its private key. status // must not be nil, since it's used to accumulate the validation results. Other // parameters are optional. -func validateCertificates( +func (m *tlsManager) validateCertificates( + ctx context.Context, status *tlsConfigStatus, certChain []byte, pkey []byte, @@ -677,7 +719,7 @@ func validateCertificates( // Check only the public certificate separately from the key. if len(certChain) > 0 { var certs []*x509.Certificate - certs, status.ValidCert, err = parseCertChain(certChain) + certs, status.ValidCert, err = m.parseCertChain(ctx, certChain) if !status.ValidCert { // Don't wrap the error, since it's informative enough as is. return err @@ -690,7 +732,7 @@ func validateCertificates( status.NotBefore = mainCert.NotBefore status.DNSNames = mainCert.DNSNames - if chainErr := validateCertChain(certs, serverName); chainErr != nil { + if chainErr := m.validateCertChain(ctx, certs, serverName); chainErr != nil { // Let self-signed certs through and don't return this error to set // its message into the status.WarningValidation afterwards. err = chainErr diff --git a/internal/home/tls_internal_test.go b/internal/home/tls_internal_test.go index 12d2bce4..e67393b4 100644 --- a/internal/home/tls_internal_test.go +++ b/internal/home/tls_internal_test.go @@ -63,9 +63,15 @@ kXS9jgARhhiWXJrk -----END PRIVATE KEY-----`) func TestValidateCertificates(t *testing.T) { + ctx := testutil.ContextWithTimeout(t, testTimeout) + logger := slogutil.NewDiscardLogger() + + m, err := newTLSManager(ctx, logger, tlsConfigSettings{}, false) + require.NoError(t, err) + t.Run("bad_certificate", func(t *testing.T) { status := &tlsConfigStatus{} - err := validateCertificates(status, []byte("bad cert"), nil, "") + err = m.validateCertificates(ctx, status, []byte("bad cert"), nil, "") testutil.AssertErrorMsg(t, "empty certificate", err) assert.False(t, status.ValidCert) assert.False(t, status.ValidChain) @@ -73,14 +79,14 @@ func TestValidateCertificates(t *testing.T) { t.Run("bad_private_key", func(t *testing.T) { status := &tlsConfigStatus{} - err := validateCertificates(status, nil, []byte("bad priv key"), "") + err = m.validateCertificates(ctx, status, nil, []byte("bad priv key"), "") testutil.AssertErrorMsg(t, "no valid keys were found", err) assert.False(t, status.ValidKey) }) t.Run("valid", func(t *testing.T) { status := &tlsConfigStatus{} - err := validateCertificates(status, testCertChainData, testPrivateKeyData, "") + err = m.validateCertificates(ctx, status, testCertChainData, testPrivateKeyData, "") assert.Error(t, err) notBefore := time.Date(2019, 2, 27, 9, 24, 23, 0, time.UTC) @@ -230,7 +236,7 @@ func TestTLSManager_Reload(t *testing.T) { certDER, key := newCertAndKey(t, snBefore) writeCertAndKey(t, certDER, certPath, key, keyPath) - m, err := newTLSManager(tlsConfigSettings{ + m, err := newTLSManager(ctx, logger, tlsConfigSettings{ Enabled: true, TLSConfig: dnsforward.TLSConfig{ CertificatePath: certPath, @@ -246,14 +252,20 @@ func TestTLSManager_Reload(t *testing.T) { certDER, key = newCertAndKey(t, snAfter) writeCertAndKey(t, certDER, certPath, key, keyPath) - m.reload() + m.reload(ctx) m.WriteDiskConfig(conf) assertCertSerialNumber(t, conf, snAfter) } func TestTLSManager_HandleTLSStatus(t *testing.T) { - m, err := newTLSManager(tlsConfigSettings{ + var ( + logger = slogutil.NewDiscardLogger() + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + m, err := newTLSManager(ctx, logger, tlsConfigSettings{ Enabled: true, TLSConfig: dnsforward.TLSConfig{ CertificateChain: string(testCertChainData), @@ -356,7 +368,7 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) { globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) require.NoError(t, err) - m, err := newTLSManager(tlsConfigSettings{ + m, err := newTLSManager(ctx, logger, tlsConfigSettings{ Enabled: true, TLSConfig: dnsforward.TLSConfig{ CertificateChain: string(testCertChainData), @@ -443,7 +455,7 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { writeCertAndKey(t, certDER, certPath, key, keyPath) // Initialize the TLS manager and assert its configuration. - m, err := newTLSManager(tlsConfigSettings{ + m, err := newTLSManager(ctx, logger, tlsConfigSettings{ Enabled: true, TLSConfig: dnsforward.TLSConfig{ CertificatePath: certPath,