Compare commits

...

3 Commits

Author SHA1 Message Date
Ainar Garipov
0c7d56dca3 Merge branch 'master' into 4927-refactor-tls 2022-11-22 17:10:40 +03:00
Ainar Garipov
f36efa26a4 home: refactor more 2022-11-21 19:45:18 +03:00
Ainar Garipov
a8850059db home: refactor tls 2022-11-21 19:05:49 +03:00
9 changed files with 467 additions and 386 deletions

View File

@@ -145,7 +145,8 @@ type FilteringConfig struct {
IpsetListFileName string `yaml:"ipset_file"` IpsetListFileName string `yaml:"ipset_file"`
} }
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS // TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, DNS-over-TLS,
// and DNS-over-QUIC.
type TLSConfig struct { type TLSConfig struct {
cert tls.Certificate cert tls.Certificate

View File

@@ -20,6 +20,7 @@ import (
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/maybe" "github.com/google/renameio/maybe"
"golang.org/x/exp/slices"
yaml "gopkg.in/yaml.v3" yaml "gopkg.in/yaml.v3"
) )
@@ -113,8 +114,8 @@ type configuration struct {
// An active session is automatically refreshed once a day. // An active session is automatically refreshed once a day.
WebSessionTTLHours uint32 `yaml:"web_session_ttl"` WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
DNS dnsConfig `yaml:"dns"` DNS dnsConfig `yaml:"dns"`
TLS tlsConfigSettings `yaml:"tls"` TLS tlsConfiguration `yaml:"tls"`
// Filters reflects the filters from [filtering.Config]. It's cloned to the // Filters reflects the filters from [filtering.Config]. It's cloned to the
// config used in the filtering module at the startup. Afterwards it's // config used in the filtering module at the startup. Afterwards it's
@@ -199,7 +200,8 @@ type dnsConfig struct {
UseHTTP3Upstreams bool `yaml:"use_http3_upstreams"` UseHTTP3Upstreams bool `yaml:"use_http3_upstreams"`
} }
type tlsConfigSettings struct { // tlsConfiguration is the on-disk TLS configuration.
type tlsConfiguration struct {
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
@@ -223,6 +225,29 @@ type tlsConfigSettings struct {
dnsforward.TLSConfig `yaml:",inline" json:",inline"` dnsforward.TLSConfig `yaml:",inline" json:",inline"`
} }
// cloneForEncoding returns a clone of c with all top-level fields of c and all
// exported and YAML-encoded fields of c.TLSConfig cloned.
//
// TODO(a.garipov): This is better than races, but still not good enough.
func (c *tlsConfiguration) cloneForEncoding() (cloned *tlsConfiguration) {
if c == nil {
return nil
}
v := *c
cloned = &v
cloned.TLSConfig = dnsforward.TLSConfig{
CertificateChain: c.CertificateChain,
PrivateKey: c.PrivateKey,
CertificatePath: c.CertificatePath,
PrivateKeyPath: c.PrivateKeyPath,
OverrideTLSCiphers: slices.Clone(c.OverrideTLSCiphers),
StrictSNICheck: c.StrictSNICheck,
}
return cloned
}
// config is the global configuration structure. // config is the global configuration structure.
// //
// TODO(a.garipov, e.burkov): This global is awful and must be removed. // TODO(a.garipov, e.burkov): This global is awful and must be removed.
@@ -273,7 +298,7 @@ var config = &configuration{
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout}, UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true, UsePrivateRDNS: true,
}, },
TLS: tlsConfigSettings{ TLS: tlsConfiguration{
PortHTTPS: defaultPortHTTPS, PortHTTPS: defaultPortHTTPS,
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
PortDNSOverQUIC: defaultPortQUIC, PortDNSOverQUIC: defaultPortQUIC,
@@ -442,7 +467,7 @@ func (c *configuration) write() (err error) {
} }
if Context.tls != nil { if Context.tls != nil {
tlsConf := tlsConfigSettings{} tlsConf := tlsConfiguration{}
Context.tls.WriteDiskConfig(&tlsConf) Context.tls.WriteDiskConfig(&tlsConf)
config.TLS = tlsConf config.TLS = tlsConf
} }

View File

@@ -154,7 +154,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
return nil return nil
} }
tlsConf := &tlsConfigSettings{} tlsConf := &tlsConfiguration{}
Context.tls.WriteDiskConfig(tlsConf) Context.tls.WriteDiskConfig(tlsConf)
canUpdate := true canUpdate := true
@@ -172,7 +172,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration // tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
// indicates that privileged ports are used. // indicates that privileged ports are used.
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) { func tlsConfUsesPrivilegedPorts(c *tlsConfiguration) (ok bool) {
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024) return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
} }

View File

@@ -205,7 +205,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
OnDNSRequest: onDNSRequest, OnDNSRequest: onDNSRequest,
} }
tlsConf := tlsConfigSettings{} tlsConf := tlsConfiguration{}
Context.tls.WriteDiskConfig(&tlsConf) Context.tls.WriteDiskConfig(&tlsConf)
if tlsConf.Enabled { if tlsConf.Enabled {
newConf.TLSConfig = tlsConf.TLSConfig newConf.TLSConfig = tlsConf.TLSConfig
@@ -250,7 +250,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
return newConf, nil return newConf, nil
} }
func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) { func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfiguration) (dnscc dnsforward.DNSCryptConfig, err error) {
if tlsConf.DNSCryptConfigFile == "" { if tlsConf.DNSCryptConfigFile == "" {
return dnscc, errors.Error("no dnscrypt_config_file") return dnscc, errors.Error("no dnscrypt_config_file")
} }
@@ -288,7 +288,7 @@ type dnsEncryption struct {
} }
func getDNSEncryption() (de dnsEncryption) { func getDNSEncryption() (de dnsEncryption) {
tlsConf := tlsConfigSettings{} tlsConf := tlsConfiguration{}
Context.tls.WriteDiskConfig(&tlsConf) Context.tls.WriteDiskConfig(&tlsConf)

View File

@@ -512,7 +512,7 @@ func run(opts options, clientBuildFS fs.FS) {
} }
config.Users = nil config.Users = nil
Context.tls, err = newTLSManager(config.TLS) Context.tls, err = newTLSManager(&config.TLS)
if err != nil { if err != nil {
log.Fatalf("initializing tls: %s", err) log.Fatalf("initializing tls: %s", err)
} }
@@ -817,7 +817,7 @@ func printWebAddrs(proto, addr string, port, betaPort int) {
// printHTTPAddresses prints the IP addresses which user can use to access the // printHTTPAddresses prints the IP addresses which user can use to access the
// admin interface. proto is either schemeHTTP or schemeHTTPS. // admin interface. proto is either schemeHTTP or schemeHTTPS.
func printHTTPAddresses(proto string) { func printHTTPAddresses(proto string) {
tlsConf := tlsConfigSettings{} tlsConf := tlsConfiguration{}
if Context.tls != nil { if Context.tls != nil {
Context.tls.WriteDiskConfig(&tlsConf) Context.tls.WriteDiskConfig(&tlsConf)
} }

View File

@@ -32,7 +32,11 @@ func setupDNSIPs(t testing.TB) {
}, },
} }
Context.tls = &tlsManager{} var err error
Context.tls, err = newTLSManager(&tlsConfiguration{
Enabled: true,
})
require.NoError(t, err)
} }
func TestHandleMobileConfigDoH(t *testing.T) { func TestHandleMobileConfigDoH(t *testing.T) {
@@ -65,7 +69,11 @@ func TestHandleMobileConfigDoH(t *testing.T) {
oldTLSConf := Context.tls oldTLSConf := Context.tls
t.Cleanup(func() { Context.tls = oldTLSConf }) t.Cleanup(func() { Context.tls = oldTLSConf })
Context.tls = &tlsManager{conf: tlsConfigSettings{}} var err error
Context.tls, err = newTLSManager(&tlsConfiguration{
Enabled: true,
})
require.NoError(t, err)
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
require.NoError(t, err) require.NoError(t, err)
@@ -137,7 +145,11 @@ func TestHandleMobileConfigDoT(t *testing.T) {
oldTLSConf := Context.tls oldTLSConf := Context.tls
t.Cleanup(func() { Context.tls = oldTLSConf }) t.Cleanup(func() { Context.tls = oldTLSConf })
Context.tls = &tlsManager{conf: tlsConfigSettings{}} var err error
Context.tls, err = newTLSManager(&tlsConfiguration{
Enabled: true,
})
require.NoError(t, err)
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -8,42 +8,39 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"net/http"
"os" "os"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/google/go-cmp/cmp"
) )
// tlsManager contains the current configuration and state of AdGuard Home TLS // tlsManager contains the current configuration and state of AdGuard Home TLS
// encryption. // encryption.
type tlsManager struct { type tlsManager struct {
// status is the current status of the configuration. It is never nil. // mu protects all fields.
status *tlsConfigStatus mu *sync.RWMutex
// certLastMod is the last modification time of the certificate file. // certLastMod is the last modification time of the certificate file.
certLastMod time.Time certLastMod time.Time
confLock sync.Mutex // status is the current status of the configuration. It is never nil.
conf tlsConfigSettings status *tlsConfigStatus
// conf is the current TLS configuration.
conf *tlsConfiguration
} }
// newTLSManager initializes the TLS configuration. // newTLSManager initializes the TLS configuration.
func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) { func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) {
m = &tlsManager{ m = &tlsManager{
status: &tlsConfigStatus{}, status: &tlsConfigStatus{},
mu: &sync.RWMutex{},
conf: conf, conf: conf,
} }
@@ -59,9 +56,19 @@ func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
return m, nil return m, nil
} }
// confForEncoding returns a partial clone of the current TLS configuration. It
// is safe for concurrent use.
func (m *tlsManager) confForEncoding() (conf *tlsConfiguration) {
m.mu.RLock()
defer m.mu.RUnlock()
return m.conf.cloneForEncoding()
}
// load reloads the TLS configuration from files or data from the config file. // load reloads the TLS configuration from files or data from the config file.
// m.mu is expected to be locked for writing.
func (m *tlsManager) load() (err error) { func (m *tlsManager) load() (err error) {
err = loadTLSConf(&m.conf, m.status) err = loadTLSConf(m.conf, m.status)
if err != nil { if err != nil {
return fmt.Errorf("loading config: %w", err) return fmt.Errorf("loading config: %w", err)
} }
@@ -70,14 +77,12 @@ func (m *tlsManager) load() (err error) {
} }
// WriteDiskConfig - write config // WriteDiskConfig - write config
func (m *tlsManager) WriteDiskConfig(conf *tlsConfigSettings) { func (m *tlsManager) WriteDiskConfig(conf *tlsConfiguration) {
m.confLock.Lock() *conf = *m.confForEncoding()
*conf = m.conf
m.confLock.Unlock()
} }
// setCertFileTime sets t.certLastMod from the certificate. If there are // setCertFileTime sets t.certLastMod from the certificate. If there are
// errors, setCertFileTime logs them. // errors, setCertFileTime logs them. mu is expected to be locked for writing.
func (m *tlsManager) setCertFileTime() { func (m *tlsManager) setCertFileTime() {
if len(m.conf.CertificatePath) == 0 { if len(m.conf.CertificatePath) == 0 {
return return
@@ -97,27 +102,22 @@ func (m *tlsManager) setCertFileTime() {
func (m *tlsManager) start() { func (m *tlsManager) start() {
m.registerWebHandlers() m.registerWebHandlers()
m.confLock.Lock()
tlsConf := m.conf
m.confLock.Unlock()
// The background context is used because the TLSConfigChanged wraps context // The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current // with timeout on its own and shuts down the server, which handles current
// request. // request.
Context.web.TLSConfigChanged(context.Background(), tlsConf) Context.web.TLSConfigChanged(context.Background(), m.confForEncoding())
} }
// reload updates the configuration and restarts t. // reload updates the configuration and restarts m.
func (m *tlsManager) reload() { func (m *tlsManager) reload() {
m.confLock.Lock() m.mu.Lock()
tlsConf := m.conf defer m.mu.Unlock()
m.confLock.Unlock()
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 { if !m.conf.Enabled || len(m.conf.CertificatePath) == 0 {
return return
} }
fi, err := os.Stat(tlsConf.CertificatePath) fi, err := os.Stat(m.conf.CertificatePath)
if err != nil { if err != nil {
log.Error("tls: %s", err) log.Error("tls: %s", err)
@@ -132,9 +132,7 @@ func (m *tlsManager) reload() {
log.Debug("tls: certificate file is modified") log.Debug("tls: certificate file is modified")
m.confLock.Lock()
err = m.load() err = m.load()
m.confLock.Unlock()
if err != nil { if err != nil {
log.Error("tls: reloading: %s", err) log.Error("tls: reloading: %s", err)
@@ -145,19 +143,15 @@ func (m *tlsManager) reload() {
_ = reconfigureDNSServer() _ = reconfigureDNSServer()
m.confLock.Lock()
tlsConf = m.conf
m.confLock.Unlock()
// The background context is used because the TLSConfigChanged wraps context // The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current // with timeout on its own and shuts down the server, which handles current
// request. // request.
Context.web.TLSConfigChanged(context.Background(), tlsConf) Context.web.TLSConfigChanged(context.Background(), m.conf)
} }
// loadTLSConf loads and validates the TLS configuration. The returned error is // loadTLSConf loads and validates the TLS configuration. The returned error is
// also set in status.WarningValidation. // also set in status.WarningValidation.
func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) { func loadTLSConf(tlsConf *tlsConfiguration, status *tlsConfigStatus) (err error) {
defer func() { defer func() {
if err != nil { if err != nil {
status.WarningValidation = err.Error() status.WarningValidation = err.Error()
@@ -172,13 +166,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey) tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
if tlsConf.CertificatePath != "" { if tlsConf.CertificatePath != "" {
if tlsConf.CertificateChain != "" { err = loadCert(tlsConf)
return errors.Error("certificate data and file can't be set together")
}
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
if err != nil { if err != nil {
return fmt.Errorf("reading cert file: %w", err) // Don't wrap the error, since it's informative enough as is.
return err
} }
// Set status.ValidCert to true to signal the frontend that the // Set status.ValidCert to true to signal the frontend that the
@@ -187,13 +178,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
} }
if tlsConf.PrivateKeyPath != "" { if tlsConf.PrivateKeyPath != "" {
if tlsConf.PrivateKey != "" { err = loadPKey(tlsConf)
return errors.Error("private key data and file can't be set together")
}
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
if err != nil { if err != nil {
return fmt.Errorf("reading key file: %w", err) // Don't wrap the error, since it's informative enough as is.
return err
} }
status.ValidKey = true status.ValidKey = true
@@ -212,278 +200,29 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
return nil return nil
} }
// tlsConfigStatus contains the status of a certificate chain and key pair. // loadCert loads the certificate from file, if necessary.
type tlsConfigStatus struct { func loadCert(tlsConf *tlsConfiguration) (err error) {
// Subject is the subject of the first certificate in the chain. if tlsConf.CertificateChain != "" {
Subject string `json:"subject,omitempty"` return errors.Error("certificate data and file can't be set together")
// Issuer is the issuer of the first certificate in the chain.
Issuer string `json:"issuer,omitempty"`
// KeyType is the type of the private key.
KeyType string `json:"key_type,omitempty"`
// NotBefore is the NotBefore field of the first certificate in the chain.
NotBefore time.Time `json:"not_before,omitempty"`
// NotAfter is the NotAfter field of the first certificate in the chain.
NotAfter time.Time `json:"not_after,omitempty"`
// WarningValidation is a validation warning message with the issue
// description.
WarningValidation string `json:"warning_validation,omitempty"`
// DNSNames is the value of SubjectAltNames field of the first certificate
// in the chain.
DNSNames []string `json:"dns_names"`
// ValidCert is true if the specified certificate chain is a valid chain of
// X509 certificates.
ValidCert bool `json:"valid_cert"`
// ValidChain is true if the specified certificate chain is verified and
// issued by a known CA.
ValidChain bool `json:"valid_chain"`
// ValidKey is true if the key is a valid private key.
ValidKey bool `json:"valid_key"`
// ValidPair is true if both certificate and private key are correct for
// each other.
ValidPair bool `json:"valid_pair"`
}
// tlsConfig is the TLS configuration and status response.
type tlsConfig struct {
*tlsConfigStatus `json:",inline"`
tlsConfigSettingsExt `json:",inline"`
}
// tlsConfigSettingsExt is used to (un)marshal the PrivateKeySaved field to
// ensure that clients don't send and receive previously saved private keys.
type tlsConfigSettingsExt struct {
tlsConfigSettings `json:",inline"`
// PrivateKeySaved is true if the private key is saved as a string and omit
// key from answer.
PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"`
}
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
m.confLock.Lock()
data := tlsConfig{
tlsConfigSettingsExt: tlsConfigSettingsExt{
tlsConfigSettings: m.conf,
},
tlsConfigStatus: m.status,
} }
m.confLock.Unlock()
marshalTLS(w, r, data) tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
}
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
setts, err := unmarshalTLS(r)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) return fmt.Errorf("reading cert file: %w", err)
return
} }
if setts.PrivateKeySaved { return nil
setts.PrivateKey = m.conf.PrivateKey
}
if setts.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.BetaBindPort),
tcpPort(setts.PortHTTPS),
tcpPort(setts.PortDNSOverTLS),
tcpPort(setts.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(setts.PortDNSOverQUIC),
)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
if !webCheckPortAvailable(setts.PortHTTPS) {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable HTTPS on it",
setts.PortHTTPS,
)
return
}
// Skip the error check, since we are only interested in the value of
// status.WarningValidation.
status := &tlsConfigStatus{}
_ = loadTLSConf(&setts.tlsConfigSettings, status)
resp := tlsConfig{
tlsConfigSettingsExt: setts,
tlsConfigStatus: status,
}
marshalTLS(w, r, resp)
} }
func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatus) (restartHTTPS bool) { // loadPKey loads the private key from file, if necessary.
m.confLock.Lock() func loadPKey(tlsConf *tlsConfiguration) (err error) {
defer m.confLock.Unlock() if tlsConf.PrivateKey != "" {
return errors.Error("private key data and file cannot be set together")
// Reset the DNSCrypt data before comparing, since we currently do not
// accept these from the frontend.
//
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
newConf.PortDNSCrypt = m.conf.PortDNSCrypt
if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
log.Info("tls config has changed, restarting https server")
restartHTTPS = true
} else {
log.Info("tls: config has not changed")
} }
// Note: don't do just `t.conf = data` because we must preserve all other members of t.conf tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
m.conf.Enabled = newConf.Enabled
m.conf.ServerName = newConf.ServerName
m.conf.ForceHTTPS = newConf.ForceHTTPS
m.conf.PortHTTPS = newConf.PortHTTPS
m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
m.conf.CertificateChain = newConf.CertificateChain
m.conf.CertificatePath = newConf.CertificatePath
m.conf.CertificateChainData = newConf.CertificateChainData
m.conf.PrivateKey = newConf.PrivateKey
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
m.conf.PrivateKeyData = newConf.PrivateKeyData
m.status = status
return restartHTTPS
}
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
req, err := unmarshalTLS(r)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) return fmt.Errorf("reading key file: %w", err)
return
}
if req.PrivateKeySaved {
req.PrivateKey = m.conf.PrivateKey
}
if req.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.BetaBindPort),
tcpPort(req.PortHTTPS),
tcpPort(req.PortDNSOverTLS),
tcpPort(req.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(req.PortDNSOverQUIC),
)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
// TODO(e.burkov): Investigate and perhaps check other ports.
if !webCheckPortAvailable(req.PortHTTPS) {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable https on it",
req.PortHTTPS,
)
return
}
status := &tlsConfigStatus{}
err = loadTLSConf(&req.tlsConfigSettings, status)
if err != nil {
resp := tlsConfig{
tlsConfigSettingsExt: req,
tlsConfigStatus: status,
}
marshalTLS(w, r, resp)
return
}
restartHTTPS := m.setConfig(req.tlsConfigSettings, status)
m.setCertFileTime()
onConfigModified()
err = reconfigureDNSServer()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
resp := tlsConfig{
tlsConfigSettingsExt: req,
tlsConfigStatus: m.status,
}
marshalTLS(w, r, resp)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request. It is also should be done in a separate goroutine due to the
// same reason.
if restartHTTPS {
go func() {
Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings)
}()
}
}
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
// DNS protocols.
func validatePorts(
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
dnsPort, doqPort udpPort,
) (err error) {
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(
tcpPorts,
tcpPort(bindPort),
tcpPort(betaBindPort),
tcpPort(dohPort),
tcpPort(dotPort),
tcpPort(dnscryptTCPPort),
)
err = tcpPorts.Validate()
if err != nil {
return fmt.Errorf("validating tcp ports: %w", err)
}
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
err = udpPorts.Validate()
if err != nil {
return fmt.Errorf("validating udp ports: %w", err)
} }
return nil return nil
@@ -700,61 +439,3 @@ func parsePrivateKey(der []byte) (key crypto.PrivateKey, typ string, err error)
return nil, "", errors.Error("tls: failed to parse private key") return nil, "", errors.Error("tls: failed to parse private key")
} }
// unmarshalTLS handles base64-encoded certificates transparently
func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) {
data := tlsConfigSettingsExt{}
err := json.NewDecoder(r.Body).Decode(&data)
if err != nil {
return data, fmt.Errorf("failed to parse new TLS config json: %w", err)
}
if data.CertificateChain != "" {
var cert []byte
cert, err = base64.StdEncoding.DecodeString(data.CertificateChain)
if err != nil {
return data, fmt.Errorf("failed to base64-decode certificate chain: %w", err)
}
data.CertificateChain = string(cert)
if data.CertificatePath != "" {
return data, fmt.Errorf("certificate data and file can't be set together")
}
}
if data.PrivateKey != "" {
var key []byte
key, err = base64.StdEncoding.DecodeString(data.PrivateKey)
if err != nil {
return data, fmt.Errorf("failed to base64-decode private key: %w", err)
}
data.PrivateKey = string(key)
if data.PrivateKeyPath != "" {
return data, fmt.Errorf("private key data and file can't be set together")
}
}
return data, nil
}
func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
if data.CertificateChain != "" {
encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain))
data.CertificateChain = encoded
}
if data.PrivateKey != "" {
data.PrivateKeySaved = true
data.PrivateKey = ""
}
_ = aghhttp.WriteJSONResponse(w, r, data)
}
// registerWebHandlers registers HTTP handlers for TLS configuration.
func (m *tlsManager) registerWebHandlers() {
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
}

362
internal/home/tlshttp.go Normal file
View File

@@ -0,0 +1,362 @@
package home
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/log"
"github.com/google/go-cmp/cmp"
)
// Encryption Settings HTTP API
// tlsConfigStatus contains the status of a certificate chain and key pair.
type tlsConfigStatus struct {
// Subject is the subject of the first certificate in the chain.
Subject string `json:"subject,omitempty"`
// Issuer is the issuer of the first certificate in the chain.
Issuer string `json:"issuer,omitempty"`
// KeyType is the type of the private key.
KeyType string `json:"key_type,omitempty"`
// NotBefore is the NotBefore field of the first certificate in the chain.
NotBefore time.Time `json:"not_before,omitempty"`
// NotAfter is the NotAfter field of the first certificate in the chain.
NotAfter time.Time `json:"not_after,omitempty"`
// WarningValidation is a validation warning message with the issue
// description.
WarningValidation string `json:"warning_validation,omitempty"`
// DNSNames is the value of SubjectAltNames field of the first certificate
// in the chain.
DNSNames []string `json:"dns_names"`
// ValidCert is true if the specified certificate chain is a valid chain of
// X509 certificates.
ValidCert bool `json:"valid_cert"`
// ValidChain is true if the specified certificate chain is verified and
// issued by a known CA.
ValidChain bool `json:"valid_chain"`
// ValidKey is true if the key is a valid private key.
ValidKey bool `json:"valid_key"`
// ValidPair is true if both certificate and private key are correct for
// each other.
ValidPair bool `json:"valid_pair"`
}
// tlsConfigResp is the TLS configuration and status response.
type tlsConfigResp struct {
*tlsConfigStatus
*tlsConfiguration
// PrivateKeySaved is true if the private key is saved as a string and omit
// key from answer.
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
}
// tlsConfigReq is the TLS configuration request.
type tlsConfigReq struct {
tlsConfiguration
// PrivateKeySaved is true if the private key is saved as a string and omit
// key from answer.
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
}
// handleTLSStatus is the handler for the GET /control/tls/status HTTP API.
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
var resp *tlsConfigResp
func() {
m.mu.RLock()
defer m.mu.RUnlock()
resp = &tlsConfigResp{
tlsConfigStatus: m.status,
tlsConfiguration: m.conf.cloneForEncoding(),
}
}()
marshalTLS(w, r, resp)
}
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
req, err := unmarshalTLS(r)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return
}
if req.PrivateKeySaved {
req.PrivateKey = m.confForEncoding().PrivateKey
}
if req.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.BetaBindPort),
tcpPort(req.PortHTTPS),
tcpPort(req.PortDNSOverTLS),
tcpPort(req.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(req.PortDNSOverQUIC),
)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
if !webCheckPortAvailable(req.PortHTTPS) {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable HTTPS on it",
req.PortHTTPS,
)
return
}
resp := &tlsConfigResp{
tlsConfigStatus: &tlsConfigStatus{},
tlsConfiguration: &req.tlsConfiguration,
}
// Skip the error check, since we are only interested in the value of
// resl.tlsConfigStatus.WarningValidation.
_ = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus)
marshalTLS(w, r, resp)
}
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
// DNS protocols.
func validatePorts(
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
dnsPort, doqPort udpPort,
) (err error) {
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(
tcpPorts,
tcpPort(bindPort),
tcpPort(betaBindPort),
tcpPort(dohPort),
tcpPort(dotPort),
tcpPort(dnscryptTCPPort),
)
err = tcpPorts.Validate()
if err != nil {
return fmt.Errorf("validating tcp ports: %w", err)
}
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
err = udpPorts.Validate()
if err != nil {
return fmt.Errorf("validating udp ports: %w", err)
}
return nil
}
// handleTLSConfigure is the handler for the POST /control/tls/configure HTTP
// API.
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
req, err := unmarshalTLS(r)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return
}
if req.PrivateKeySaved {
req.PrivateKey = m.confForEncoding().PrivateKey
}
if req.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.BetaBindPort),
tcpPort(req.PortHTTPS),
tcpPort(req.PortDNSOverTLS),
tcpPort(req.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(req.PortDNSOverQUIC),
)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
// TODO(e.burkov): Investigate and perhaps check other ports.
if !webCheckPortAvailable(req.PortHTTPS) {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable https on it",
req.PortHTTPS,
)
return
}
resp := &tlsConfigResp{
tlsConfigStatus: &tlsConfigStatus{},
tlsConfiguration: &req.tlsConfiguration,
}
err = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus)
if err != nil {
marshalTLS(w, r, resp)
return
}
restartRequired := m.setConf(resp)
onConfigModified()
err = reconfigureDNSServer()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
resp.tlsConfiguration = m.confForEncoding()
marshalTLS(w, r, resp)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request. It is also should be done in a separate goroutine due to the
// same reason.
if restartRequired {
go func() {
Context.web.TLSConfigChanged(context.Background(), resp.tlsConfiguration)
}()
}
}
// setConf sets the necessary values from the new configuration.
func (m *tlsManager) setConf(newConf *tlsConfigResp) (restartRequired bool) {
m.mu.Lock()
defer m.mu.Unlock()
// Reset the DNSCrypt data before comparing, since we currently do not
// accept these from the frontend.
//
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
newConf.PortDNSCrypt = m.conf.PortDNSCrypt
if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
log.Info("tls: config has changed, restarting https server")
restartRequired = true
} else {
log.Info("tls: config has not changed")
}
// Do not just write "m.conf = *newConf.tlsConfiguration", because all other
// members of m.conf must be preserved.
m.conf.Enabled = newConf.Enabled
m.conf.ServerName = newConf.ServerName
m.conf.ForceHTTPS = newConf.ForceHTTPS
m.conf.PortHTTPS = newConf.PortHTTPS
m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
m.conf.CertificateChain = newConf.CertificateChain
m.conf.CertificatePath = newConf.CertificatePath
m.conf.CertificateChainData = newConf.CertificateChainData
m.conf.PrivateKey = newConf.PrivateKey
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
m.conf.PrivateKeyData = newConf.PrivateKeyData
m.setCertFileTime()
m.status = newConf.tlsConfigStatus
return restartRequired
}
// marshalTLS handles Base64-encoded certificates transparently.
func marshalTLS(w http.ResponseWriter, r *http.Request, conf *tlsConfigResp) {
if conf.CertificateChain != "" {
encoded := base64.StdEncoding.EncodeToString([]byte(conf.CertificateChain))
conf.CertificateChain = encoded
}
if conf.PrivateKey != "" {
conf.PrivateKeySaved = true
conf.PrivateKey = ""
}
_ = aghhttp.WriteJSONResponse(w, r, conf)
}
// unmarshalTLS handles Base64-encoded certificates transparently.
func unmarshalTLS(r *http.Request) (req *tlsConfigReq, err error) {
req = &tlsConfigReq{}
err = json.NewDecoder(r.Body).Decode(req)
if err != nil {
return nil, fmt.Errorf("parsing tls config: %w", err)
}
if req.CertificateChain != "" {
var cert []byte
cert, err = base64.StdEncoding.DecodeString(req.CertificateChain)
if err != nil {
return nil, fmt.Errorf("failed to base64-decode certificate chain: %w", err)
}
req.CertificateChain = string(cert)
if req.CertificatePath != "" {
return nil, fmt.Errorf("certificate data and file can't be set together")
}
}
if req.PrivateKey != "" {
var key []byte
key, err = base64.StdEncoding.DecodeString(req.PrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to base64-decode private key: %w", err)
}
req.PrivateKey = string(key)
if req.PrivateKeyPath != "" {
return nil, fmt.Errorf("private key data and file can't be set together")
}
}
return req, nil
}
// registerWebHandlers registers HTTP handlers for TLS configuration.
func (m *tlsManager) registerWebHandlers() {
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
}

View File

@@ -143,7 +143,7 @@ func webCheckPortAvailable(port int) (ok bool) {
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server // TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
// if necessary. // if necessary.
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) { func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf *tlsConfiguration) {
log.Debug("web: applying new tls configuration") log.Debug("web: applying new tls configuration")
web.conf.PortHTTPS = tlsConf.PortHTTPS web.conf.PortHTTPS = tlsConf.PortHTTPS
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0) web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)