Compare commits
3 Commits
master
...
4927-refac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c7d56dca3 | ||
|
|
f36efa26a4 | ||
|
|
a8850059db |
@@ -145,7 +145,8 @@ type FilteringConfig struct {
|
|||||||
IpsetListFileName string `yaml:"ipset_file"`
|
IpsetListFileName string `yaml:"ipset_file"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, DNS-over-TLS,
|
||||||
|
// and DNS-over-QUIC.
|
||||||
type TLSConfig struct {
|
type TLSConfig struct {
|
||||||
cert tls.Certificate
|
cert tls.Certificate
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/timeutil"
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
"github.com/google/renameio/maybe"
|
"github.com/google/renameio/maybe"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
yaml "gopkg.in/yaml.v3"
|
yaml "gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -113,8 +114,8 @@ type configuration struct {
|
|||||||
// An active session is automatically refreshed once a day.
|
// An active session is automatically refreshed once a day.
|
||||||
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
|
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
|
||||||
|
|
||||||
DNS dnsConfig `yaml:"dns"`
|
DNS dnsConfig `yaml:"dns"`
|
||||||
TLS tlsConfigSettings `yaml:"tls"`
|
TLS tlsConfiguration `yaml:"tls"`
|
||||||
|
|
||||||
// Filters reflects the filters from [filtering.Config]. It's cloned to the
|
// Filters reflects the filters from [filtering.Config]. It's cloned to the
|
||||||
// config used in the filtering module at the startup. Afterwards it's
|
// config used in the filtering module at the startup. Afterwards it's
|
||||||
@@ -199,7 +200,8 @@ type dnsConfig struct {
|
|||||||
UseHTTP3Upstreams bool `yaml:"use_http3_upstreams"`
|
UseHTTP3Upstreams bool `yaml:"use_http3_upstreams"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type tlsConfigSettings struct {
|
// tlsConfiguration is the on-disk TLS configuration.
|
||||||
|
type tlsConfiguration struct {
|
||||||
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status
|
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status
|
||||||
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
|
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
|
||||||
ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
|
ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
|
||||||
@@ -223,6 +225,29 @@ type tlsConfigSettings struct {
|
|||||||
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
|
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cloneForEncoding returns a clone of c with all top-level fields of c and all
|
||||||
|
// exported and YAML-encoded fields of c.TLSConfig cloned.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): This is better than races, but still not good enough.
|
||||||
|
func (c *tlsConfiguration) cloneForEncoding() (cloned *tlsConfiguration) {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
v := *c
|
||||||
|
cloned = &v
|
||||||
|
cloned.TLSConfig = dnsforward.TLSConfig{
|
||||||
|
CertificateChain: c.CertificateChain,
|
||||||
|
PrivateKey: c.PrivateKey,
|
||||||
|
CertificatePath: c.CertificatePath,
|
||||||
|
PrivateKeyPath: c.PrivateKeyPath,
|
||||||
|
OverrideTLSCiphers: slices.Clone(c.OverrideTLSCiphers),
|
||||||
|
StrictSNICheck: c.StrictSNICheck,
|
||||||
|
}
|
||||||
|
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
// config is the global configuration structure.
|
// config is the global configuration structure.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
|
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
|
||||||
@@ -273,7 +298,7 @@ var config = &configuration{
|
|||||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||||
UsePrivateRDNS: true,
|
UsePrivateRDNS: true,
|
||||||
},
|
},
|
||||||
TLS: tlsConfigSettings{
|
TLS: tlsConfiguration{
|
||||||
PortHTTPS: defaultPortHTTPS,
|
PortHTTPS: defaultPortHTTPS,
|
||||||
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
|
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
|
||||||
PortDNSOverQUIC: defaultPortQUIC,
|
PortDNSOverQUIC: defaultPortQUIC,
|
||||||
@@ -442,7 +467,7 @@ func (c *configuration) write() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if Context.tls != nil {
|
if Context.tls != nil {
|
||||||
tlsConf := tlsConfigSettings{}
|
tlsConf := tlsConfiguration{}
|
||||||
Context.tls.WriteDiskConfig(&tlsConf)
|
Context.tls.WriteDiskConfig(&tlsConf)
|
||||||
config.TLS = tlsConf
|
config.TLS = tlsConf
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConf := &tlsConfigSettings{}
|
tlsConf := &tlsConfiguration{}
|
||||||
Context.tls.WriteDiskConfig(tlsConf)
|
Context.tls.WriteDiskConfig(tlsConf)
|
||||||
|
|
||||||
canUpdate := true
|
canUpdate := true
|
||||||
@@ -172,7 +172,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
|||||||
|
|
||||||
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
|
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
|
||||||
// indicates that privileged ports are used.
|
// indicates that privileged ports are used.
|
||||||
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
|
func tlsConfUsesPrivilegedPorts(c *tlsConfiguration) (ok bool) {
|
||||||
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
|
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
|||||||
OnDNSRequest: onDNSRequest,
|
OnDNSRequest: onDNSRequest,
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConf := tlsConfigSettings{}
|
tlsConf := tlsConfiguration{}
|
||||||
Context.tls.WriteDiskConfig(&tlsConf)
|
Context.tls.WriteDiskConfig(&tlsConf)
|
||||||
if tlsConf.Enabled {
|
if tlsConf.Enabled {
|
||||||
newConf.TLSConfig = tlsConf.TLSConfig
|
newConf.TLSConfig = tlsConf.TLSConfig
|
||||||
@@ -250,7 +250,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
|||||||
return newConf, nil
|
return newConf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) {
|
func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfiguration) (dnscc dnsforward.DNSCryptConfig, err error) {
|
||||||
if tlsConf.DNSCryptConfigFile == "" {
|
if tlsConf.DNSCryptConfigFile == "" {
|
||||||
return dnscc, errors.Error("no dnscrypt_config_file")
|
return dnscc, errors.Error("no dnscrypt_config_file")
|
||||||
}
|
}
|
||||||
@@ -288,7 +288,7 @@ type dnsEncryption struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getDNSEncryption() (de dnsEncryption) {
|
func getDNSEncryption() (de dnsEncryption) {
|
||||||
tlsConf := tlsConfigSettings{}
|
tlsConf := tlsConfiguration{}
|
||||||
|
|
||||||
Context.tls.WriteDiskConfig(&tlsConf)
|
Context.tls.WriteDiskConfig(&tlsConf)
|
||||||
|
|
||||||
|
|||||||
@@ -512,7 +512,7 @@ func run(opts options, clientBuildFS fs.FS) {
|
|||||||
}
|
}
|
||||||
config.Users = nil
|
config.Users = nil
|
||||||
|
|
||||||
Context.tls, err = newTLSManager(config.TLS)
|
Context.tls, err = newTLSManager(&config.TLS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("initializing tls: %s", err)
|
log.Fatalf("initializing tls: %s", err)
|
||||||
}
|
}
|
||||||
@@ -817,7 +817,7 @@ func printWebAddrs(proto, addr string, port, betaPort int) {
|
|||||||
// printHTTPAddresses prints the IP addresses which user can use to access the
|
// printHTTPAddresses prints the IP addresses which user can use to access the
|
||||||
// admin interface. proto is either schemeHTTP or schemeHTTPS.
|
// admin interface. proto is either schemeHTTP or schemeHTTPS.
|
||||||
func printHTTPAddresses(proto string) {
|
func printHTTPAddresses(proto string) {
|
||||||
tlsConf := tlsConfigSettings{}
|
tlsConf := tlsConfiguration{}
|
||||||
if Context.tls != nil {
|
if Context.tls != nil {
|
||||||
Context.tls.WriteDiskConfig(&tlsConf)
|
Context.tls.WriteDiskConfig(&tlsConf)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,7 +32,11 @@ func setupDNSIPs(t testing.TB) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.tls = &tlsManager{}
|
var err error
|
||||||
|
Context.tls, err = newTLSManager(&tlsConfiguration{
|
||||||
|
Enabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleMobileConfigDoH(t *testing.T) {
|
func TestHandleMobileConfigDoH(t *testing.T) {
|
||||||
@@ -65,7 +69,11 @@ func TestHandleMobileConfigDoH(t *testing.T) {
|
|||||||
oldTLSConf := Context.tls
|
oldTLSConf := Context.tls
|
||||||
t.Cleanup(func() { Context.tls = oldTLSConf })
|
t.Cleanup(func() { Context.tls = oldTLSConf })
|
||||||
|
|
||||||
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
|
var err error
|
||||||
|
Context.tls, err = newTLSManager(&tlsConfiguration{
|
||||||
|
Enabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -137,7 +145,11 @@ func TestHandleMobileConfigDoT(t *testing.T) {
|
|||||||
oldTLSConf := Context.tls
|
oldTLSConf := Context.tls
|
||||||
t.Cleanup(func() { Context.tls = oldTLSConf })
|
t.Cleanup(func() { Context.tls = oldTLSConf })
|
||||||
|
|
||||||
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
|
var err error
|
||||||
|
Context.tls, err = newTLSManager(&tlsConfiguration{
|
||||||
|
Enabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -8,42 +8,39 @@ import (
|
|||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// tlsManager contains the current configuration and state of AdGuard Home TLS
|
// tlsManager contains the current configuration and state of AdGuard Home TLS
|
||||||
// encryption.
|
// encryption.
|
||||||
type tlsManager struct {
|
type tlsManager struct {
|
||||||
// status is the current status of the configuration. It is never nil.
|
// mu protects all fields.
|
||||||
status *tlsConfigStatus
|
mu *sync.RWMutex
|
||||||
|
|
||||||
// certLastMod is the last modification time of the certificate file.
|
// certLastMod is the last modification time of the certificate file.
|
||||||
certLastMod time.Time
|
certLastMod time.Time
|
||||||
|
|
||||||
confLock sync.Mutex
|
// status is the current status of the configuration. It is never nil.
|
||||||
conf tlsConfigSettings
|
status *tlsConfigStatus
|
||||||
|
|
||||||
|
// conf is the current TLS configuration.
|
||||||
|
conf *tlsConfiguration
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTLSManager initializes the TLS configuration.
|
// newTLSManager initializes the TLS configuration.
|
||||||
func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
|
func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) {
|
||||||
m = &tlsManager{
|
m = &tlsManager{
|
||||||
status: &tlsConfigStatus{},
|
status: &tlsConfigStatus{},
|
||||||
|
mu: &sync.RWMutex{},
|
||||||
conf: conf,
|
conf: conf,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,9 +56,19 @@ func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// confForEncoding returns a partial clone of the current TLS configuration. It
|
||||||
|
// is safe for concurrent use.
|
||||||
|
func (m *tlsManager) confForEncoding() (conf *tlsConfiguration) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
return m.conf.cloneForEncoding()
|
||||||
|
}
|
||||||
|
|
||||||
// load reloads the TLS configuration from files or data from the config file.
|
// load reloads the TLS configuration from files or data from the config file.
|
||||||
|
// m.mu is expected to be locked for writing.
|
||||||
func (m *tlsManager) load() (err error) {
|
func (m *tlsManager) load() (err error) {
|
||||||
err = loadTLSConf(&m.conf, m.status)
|
err = loadTLSConf(m.conf, m.status)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("loading config: %w", err)
|
return fmt.Errorf("loading config: %w", err)
|
||||||
}
|
}
|
||||||
@@ -70,14 +77,12 @@ func (m *tlsManager) load() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WriteDiskConfig - write config
|
// WriteDiskConfig - write config
|
||||||
func (m *tlsManager) WriteDiskConfig(conf *tlsConfigSettings) {
|
func (m *tlsManager) WriteDiskConfig(conf *tlsConfiguration) {
|
||||||
m.confLock.Lock()
|
*conf = *m.confForEncoding()
|
||||||
*conf = m.conf
|
|
||||||
m.confLock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// setCertFileTime sets t.certLastMod from the certificate. If there are
|
// setCertFileTime sets t.certLastMod from the certificate. If there are
|
||||||
// errors, setCertFileTime logs them.
|
// errors, setCertFileTime logs them. mu is expected to be locked for writing.
|
||||||
func (m *tlsManager) setCertFileTime() {
|
func (m *tlsManager) setCertFileTime() {
|
||||||
if len(m.conf.CertificatePath) == 0 {
|
if len(m.conf.CertificatePath) == 0 {
|
||||||
return
|
return
|
||||||
@@ -97,27 +102,22 @@ func (m *tlsManager) setCertFileTime() {
|
|||||||
func (m *tlsManager) start() {
|
func (m *tlsManager) start() {
|
||||||
m.registerWebHandlers()
|
m.registerWebHandlers()
|
||||||
|
|
||||||
m.confLock.Lock()
|
|
||||||
tlsConf := m.conf
|
|
||||||
m.confLock.Unlock()
|
|
||||||
|
|
||||||
// The background context is used because the TLSConfigChanged wraps context
|
// The background context is used because the TLSConfigChanged wraps context
|
||||||
// with timeout on its own and shuts down the server, which handles current
|
// with timeout on its own and shuts down the server, which handles current
|
||||||
// request.
|
// request.
|
||||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
Context.web.TLSConfigChanged(context.Background(), m.confForEncoding())
|
||||||
}
|
}
|
||||||
|
|
||||||
// reload updates the configuration and restarts t.
|
// reload updates the configuration and restarts m.
|
||||||
func (m *tlsManager) reload() {
|
func (m *tlsManager) reload() {
|
||||||
m.confLock.Lock()
|
m.mu.Lock()
|
||||||
tlsConf := m.conf
|
defer m.mu.Unlock()
|
||||||
m.confLock.Unlock()
|
|
||||||
|
|
||||||
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
|
if !m.conf.Enabled || len(m.conf.CertificatePath) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fi, err := os.Stat(tlsConf.CertificatePath)
|
fi, err := os.Stat(m.conf.CertificatePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("tls: %s", err)
|
log.Error("tls: %s", err)
|
||||||
|
|
||||||
@@ -132,9 +132,7 @@ func (m *tlsManager) reload() {
|
|||||||
|
|
||||||
log.Debug("tls: certificate file is modified")
|
log.Debug("tls: certificate file is modified")
|
||||||
|
|
||||||
m.confLock.Lock()
|
|
||||||
err = m.load()
|
err = m.load()
|
||||||
m.confLock.Unlock()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("tls: reloading: %s", err)
|
log.Error("tls: reloading: %s", err)
|
||||||
|
|
||||||
@@ -145,19 +143,15 @@ func (m *tlsManager) reload() {
|
|||||||
|
|
||||||
_ = reconfigureDNSServer()
|
_ = reconfigureDNSServer()
|
||||||
|
|
||||||
m.confLock.Lock()
|
|
||||||
tlsConf = m.conf
|
|
||||||
m.confLock.Unlock()
|
|
||||||
|
|
||||||
// The background context is used because the TLSConfigChanged wraps context
|
// The background context is used because the TLSConfigChanged wraps context
|
||||||
// with timeout on its own and shuts down the server, which handles current
|
// with timeout on its own and shuts down the server, which handles current
|
||||||
// request.
|
// request.
|
||||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
Context.web.TLSConfigChanged(context.Background(), m.conf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadTLSConf loads and validates the TLS configuration. The returned error is
|
// loadTLSConf loads and validates the TLS configuration. The returned error is
|
||||||
// also set in status.WarningValidation.
|
// also set in status.WarningValidation.
|
||||||
func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) {
|
func loadTLSConf(tlsConf *tlsConfiguration, status *tlsConfigStatus) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
status.WarningValidation = err.Error()
|
status.WarningValidation = err.Error()
|
||||||
@@ -172,13 +166,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
|
|||||||
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
|
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
|
||||||
|
|
||||||
if tlsConf.CertificatePath != "" {
|
if tlsConf.CertificatePath != "" {
|
||||||
if tlsConf.CertificateChain != "" {
|
err = loadCert(tlsConf)
|
||||||
return errors.Error("certificate data and file can't be set together")
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reading cert file: %w", err)
|
// Don't wrap the error, since it's informative enough as is.
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set status.ValidCert to true to signal the frontend that the
|
// Set status.ValidCert to true to signal the frontend that the
|
||||||
@@ -187,13 +178,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tlsConf.PrivateKeyPath != "" {
|
if tlsConf.PrivateKeyPath != "" {
|
||||||
if tlsConf.PrivateKey != "" {
|
err = loadPKey(tlsConf)
|
||||||
return errors.Error("private key data and file can't be set together")
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reading key file: %w", err)
|
// Don't wrap the error, since it's informative enough as is.
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
status.ValidKey = true
|
status.ValidKey = true
|
||||||
@@ -212,278 +200,29 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// tlsConfigStatus contains the status of a certificate chain and key pair.
|
// loadCert loads the certificate from file, if necessary.
|
||||||
type tlsConfigStatus struct {
|
func loadCert(tlsConf *tlsConfiguration) (err error) {
|
||||||
// Subject is the subject of the first certificate in the chain.
|
if tlsConf.CertificateChain != "" {
|
||||||
Subject string `json:"subject,omitempty"`
|
return errors.Error("certificate data and file can't be set together")
|
||||||
|
|
||||||
// Issuer is the issuer of the first certificate in the chain.
|
|
||||||
Issuer string `json:"issuer,omitempty"`
|
|
||||||
|
|
||||||
// KeyType is the type of the private key.
|
|
||||||
KeyType string `json:"key_type,omitempty"`
|
|
||||||
|
|
||||||
// NotBefore is the NotBefore field of the first certificate in the chain.
|
|
||||||
NotBefore time.Time `json:"not_before,omitempty"`
|
|
||||||
|
|
||||||
// NotAfter is the NotAfter field of the first certificate in the chain.
|
|
||||||
NotAfter time.Time `json:"not_after,omitempty"`
|
|
||||||
|
|
||||||
// WarningValidation is a validation warning message with the issue
|
|
||||||
// description.
|
|
||||||
WarningValidation string `json:"warning_validation,omitempty"`
|
|
||||||
|
|
||||||
// DNSNames is the value of SubjectAltNames field of the first certificate
|
|
||||||
// in the chain.
|
|
||||||
DNSNames []string `json:"dns_names"`
|
|
||||||
|
|
||||||
// ValidCert is true if the specified certificate chain is a valid chain of
|
|
||||||
// X509 certificates.
|
|
||||||
ValidCert bool `json:"valid_cert"`
|
|
||||||
|
|
||||||
// ValidChain is true if the specified certificate chain is verified and
|
|
||||||
// issued by a known CA.
|
|
||||||
ValidChain bool `json:"valid_chain"`
|
|
||||||
|
|
||||||
// ValidKey is true if the key is a valid private key.
|
|
||||||
ValidKey bool `json:"valid_key"`
|
|
||||||
|
|
||||||
// ValidPair is true if both certificate and private key are correct for
|
|
||||||
// each other.
|
|
||||||
ValidPair bool `json:"valid_pair"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// tlsConfig is the TLS configuration and status response.
|
|
||||||
type tlsConfig struct {
|
|
||||||
*tlsConfigStatus `json:",inline"`
|
|
||||||
tlsConfigSettingsExt `json:",inline"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// tlsConfigSettingsExt is used to (un)marshal the PrivateKeySaved field to
|
|
||||||
// ensure that clients don't send and receive previously saved private keys.
|
|
||||||
type tlsConfigSettingsExt struct {
|
|
||||||
tlsConfigSettings `json:",inline"`
|
|
||||||
|
|
||||||
// PrivateKeySaved is true if the private key is saved as a string and omit
|
|
||||||
// key from answer.
|
|
||||||
PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
|
||||||
m.confLock.Lock()
|
|
||||||
data := tlsConfig{
|
|
||||||
tlsConfigSettingsExt: tlsConfigSettingsExt{
|
|
||||||
tlsConfigSettings: m.conf,
|
|
||||||
},
|
|
||||||
tlsConfigStatus: m.status,
|
|
||||||
}
|
}
|
||||||
m.confLock.Unlock()
|
|
||||||
|
|
||||||
marshalTLS(w, r, data)
|
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
|
||||||
}
|
|
||||||
|
|
||||||
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
|
||||||
setts, err := unmarshalTLS(r)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
return fmt.Errorf("reading cert file: %w", err)
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if setts.PrivateKeySaved {
|
return nil
|
||||||
setts.PrivateKey = m.conf.PrivateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
if setts.Enabled {
|
|
||||||
err = validatePorts(
|
|
||||||
tcpPort(config.BindPort),
|
|
||||||
tcpPort(config.BetaBindPort),
|
|
||||||
tcpPort(setts.PortHTTPS),
|
|
||||||
tcpPort(setts.PortDNSOverTLS),
|
|
||||||
tcpPort(setts.PortDNSCrypt),
|
|
||||||
udpPort(config.DNS.Port),
|
|
||||||
udpPort(setts.PortDNSOverQUIC),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !webCheckPortAvailable(setts.PortHTTPS) {
|
|
||||||
aghhttp.Error(
|
|
||||||
r,
|
|
||||||
w,
|
|
||||||
http.StatusBadRequest,
|
|
||||||
"port %d is not available, cannot enable HTTPS on it",
|
|
||||||
setts.PortHTTPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip the error check, since we are only interested in the value of
|
|
||||||
// status.WarningValidation.
|
|
||||||
status := &tlsConfigStatus{}
|
|
||||||
_ = loadTLSConf(&setts.tlsConfigSettings, status)
|
|
||||||
resp := tlsConfig{
|
|
||||||
tlsConfigSettingsExt: setts,
|
|
||||||
tlsConfigStatus: status,
|
|
||||||
}
|
|
||||||
|
|
||||||
marshalTLS(w, r, resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatus) (restartHTTPS bool) {
|
// loadPKey loads the private key from file, if necessary.
|
||||||
m.confLock.Lock()
|
func loadPKey(tlsConf *tlsConfiguration) (err error) {
|
||||||
defer m.confLock.Unlock()
|
if tlsConf.PrivateKey != "" {
|
||||||
|
return errors.Error("private key data and file cannot be set together")
|
||||||
// Reset the DNSCrypt data before comparing, since we currently do not
|
|
||||||
// accept these from the frontend.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
|
|
||||||
newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
|
|
||||||
newConf.PortDNSCrypt = m.conf.PortDNSCrypt
|
|
||||||
if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
|
|
||||||
log.Info("tls config has changed, restarting https server")
|
|
||||||
restartHTTPS = true
|
|
||||||
} else {
|
|
||||||
log.Info("tls: config has not changed")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: don't do just `t.conf = data` because we must preserve all other members of t.conf
|
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
|
||||||
m.conf.Enabled = newConf.Enabled
|
|
||||||
m.conf.ServerName = newConf.ServerName
|
|
||||||
m.conf.ForceHTTPS = newConf.ForceHTTPS
|
|
||||||
m.conf.PortHTTPS = newConf.PortHTTPS
|
|
||||||
m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
|
|
||||||
m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
|
|
||||||
m.conf.CertificateChain = newConf.CertificateChain
|
|
||||||
m.conf.CertificatePath = newConf.CertificatePath
|
|
||||||
m.conf.CertificateChainData = newConf.CertificateChainData
|
|
||||||
m.conf.PrivateKey = newConf.PrivateKey
|
|
||||||
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
|
|
||||||
m.conf.PrivateKeyData = newConf.PrivateKeyData
|
|
||||||
m.status = status
|
|
||||||
|
|
||||||
return restartHTTPS
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
|
||||||
req, err := unmarshalTLS(r)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
return fmt.Errorf("reading key file: %w", err)
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.PrivateKeySaved {
|
|
||||||
req.PrivateKey = m.conf.PrivateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Enabled {
|
|
||||||
err = validatePorts(
|
|
||||||
tcpPort(config.BindPort),
|
|
||||||
tcpPort(config.BetaBindPort),
|
|
||||||
tcpPort(req.PortHTTPS),
|
|
||||||
tcpPort(req.PortDNSOverTLS),
|
|
||||||
tcpPort(req.PortDNSCrypt),
|
|
||||||
udpPort(config.DNS.Port),
|
|
||||||
udpPort(req.PortDNSOverQUIC),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(e.burkov): Investigate and perhaps check other ports.
|
|
||||||
if !webCheckPortAvailable(req.PortHTTPS) {
|
|
||||||
aghhttp.Error(
|
|
||||||
r,
|
|
||||||
w,
|
|
||||||
http.StatusBadRequest,
|
|
||||||
"port %d is not available, cannot enable https on it",
|
|
||||||
req.PortHTTPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
status := &tlsConfigStatus{}
|
|
||||||
err = loadTLSConf(&req.tlsConfigSettings, status)
|
|
||||||
if err != nil {
|
|
||||||
resp := tlsConfig{
|
|
||||||
tlsConfigSettingsExt: req,
|
|
||||||
tlsConfigStatus: status,
|
|
||||||
}
|
|
||||||
|
|
||||||
marshalTLS(w, r, resp)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
restartHTTPS := m.setConfig(req.tlsConfigSettings, status)
|
|
||||||
m.setCertFileTime()
|
|
||||||
onConfigModified()
|
|
||||||
|
|
||||||
err = reconfigureDNSServer()
|
|
||||||
if err != nil {
|
|
||||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := tlsConfig{
|
|
||||||
tlsConfigSettingsExt: req,
|
|
||||||
tlsConfigStatus: m.status,
|
|
||||||
}
|
|
||||||
|
|
||||||
marshalTLS(w, r, resp)
|
|
||||||
if f, ok := w.(http.Flusher); ok {
|
|
||||||
f.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
// The background context is used because the TLSConfigChanged wraps context
|
|
||||||
// with timeout on its own and shuts down the server, which handles current
|
|
||||||
// request. It is also should be done in a separate goroutine due to the
|
|
||||||
// same reason.
|
|
||||||
if restartHTTPS {
|
|
||||||
go func() {
|
|
||||||
Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
|
|
||||||
// DNS protocols.
|
|
||||||
func validatePorts(
|
|
||||||
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
|
|
||||||
dnsPort, doqPort udpPort,
|
|
||||||
) (err error) {
|
|
||||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
|
||||||
addPorts(
|
|
||||||
tcpPorts,
|
|
||||||
tcpPort(bindPort),
|
|
||||||
tcpPort(betaBindPort),
|
|
||||||
tcpPort(dohPort),
|
|
||||||
tcpPort(dotPort),
|
|
||||||
tcpPort(dnscryptTCPPort),
|
|
||||||
)
|
|
||||||
|
|
||||||
err = tcpPorts.Validate()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("validating tcp ports: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
|
||||||
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
|
|
||||||
|
|
||||||
err = udpPorts.Validate()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("validating udp ports: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -700,61 +439,3 @@ func parsePrivateKey(der []byte) (key crypto.PrivateKey, typ string, err error)
|
|||||||
|
|
||||||
return nil, "", errors.Error("tls: failed to parse private key")
|
return nil, "", errors.Error("tls: failed to parse private key")
|
||||||
}
|
}
|
||||||
|
|
||||||
// unmarshalTLS handles base64-encoded certificates transparently
|
|
||||||
func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) {
|
|
||||||
data := tlsConfigSettingsExt{}
|
|
||||||
err := json.NewDecoder(r.Body).Decode(&data)
|
|
||||||
if err != nil {
|
|
||||||
return data, fmt.Errorf("failed to parse new TLS config json: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if data.CertificateChain != "" {
|
|
||||||
var cert []byte
|
|
||||||
cert, err = base64.StdEncoding.DecodeString(data.CertificateChain)
|
|
||||||
if err != nil {
|
|
||||||
return data, fmt.Errorf("failed to base64-decode certificate chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data.CertificateChain = string(cert)
|
|
||||||
if data.CertificatePath != "" {
|
|
||||||
return data, fmt.Errorf("certificate data and file can't be set together")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if data.PrivateKey != "" {
|
|
||||||
var key []byte
|
|
||||||
key, err = base64.StdEncoding.DecodeString(data.PrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
return data, fmt.Errorf("failed to base64-decode private key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data.PrivateKey = string(key)
|
|
||||||
if data.PrivateKeyPath != "" {
|
|
||||||
return data, fmt.Errorf("private key data and file can't be set together")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
|
|
||||||
if data.CertificateChain != "" {
|
|
||||||
encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain))
|
|
||||||
data.CertificateChain = encoded
|
|
||||||
}
|
|
||||||
|
|
||||||
if data.PrivateKey != "" {
|
|
||||||
data.PrivateKeySaved = true
|
|
||||||
data.PrivateKey = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// registerWebHandlers registers HTTP handlers for TLS configuration.
|
|
||||||
func (m *tlsManager) registerWebHandlers() {
|
|
||||||
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
|
|
||||||
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
|
|
||||||
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
|
|
||||||
}
|
|
||||||
|
|||||||
362
internal/home/tlshttp.go
Normal file
362
internal/home/tlshttp.go
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
package home
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Encryption Settings HTTP API
|
||||||
|
|
||||||
|
// tlsConfigStatus contains the status of a certificate chain and key pair.
|
||||||
|
type tlsConfigStatus struct {
|
||||||
|
// Subject is the subject of the first certificate in the chain.
|
||||||
|
Subject string `json:"subject,omitempty"`
|
||||||
|
|
||||||
|
// Issuer is the issuer of the first certificate in the chain.
|
||||||
|
Issuer string `json:"issuer,omitempty"`
|
||||||
|
|
||||||
|
// KeyType is the type of the private key.
|
||||||
|
KeyType string `json:"key_type,omitempty"`
|
||||||
|
|
||||||
|
// NotBefore is the NotBefore field of the first certificate in the chain.
|
||||||
|
NotBefore time.Time `json:"not_before,omitempty"`
|
||||||
|
|
||||||
|
// NotAfter is the NotAfter field of the first certificate in the chain.
|
||||||
|
NotAfter time.Time `json:"not_after,omitempty"`
|
||||||
|
|
||||||
|
// WarningValidation is a validation warning message with the issue
|
||||||
|
// description.
|
||||||
|
WarningValidation string `json:"warning_validation,omitempty"`
|
||||||
|
|
||||||
|
// DNSNames is the value of SubjectAltNames field of the first certificate
|
||||||
|
// in the chain.
|
||||||
|
DNSNames []string `json:"dns_names"`
|
||||||
|
|
||||||
|
// ValidCert is true if the specified certificate chain is a valid chain of
|
||||||
|
// X509 certificates.
|
||||||
|
ValidCert bool `json:"valid_cert"`
|
||||||
|
|
||||||
|
// ValidChain is true if the specified certificate chain is verified and
|
||||||
|
// issued by a known CA.
|
||||||
|
ValidChain bool `json:"valid_chain"`
|
||||||
|
|
||||||
|
// ValidKey is true if the key is a valid private key.
|
||||||
|
ValidKey bool `json:"valid_key"`
|
||||||
|
|
||||||
|
// ValidPair is true if both certificate and private key are correct for
|
||||||
|
// each other.
|
||||||
|
ValidPair bool `json:"valid_pair"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// tlsConfigResp is the TLS configuration and status response.
|
||||||
|
type tlsConfigResp struct {
|
||||||
|
*tlsConfigStatus
|
||||||
|
*tlsConfiguration
|
||||||
|
|
||||||
|
// PrivateKeySaved is true if the private key is saved as a string and omit
|
||||||
|
// key from answer.
|
||||||
|
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// tlsConfigReq is the TLS configuration request.
|
||||||
|
type tlsConfigReq struct {
|
||||||
|
tlsConfiguration
|
||||||
|
|
||||||
|
// PrivateKeySaved is true if the private key is saved as a string and omit
|
||||||
|
// key from answer.
|
||||||
|
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleTLSStatus is the handler for the GET /control/tls/status HTTP API.
|
||||||
|
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var resp *tlsConfigResp
|
||||||
|
func() {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
resp = &tlsConfigResp{
|
||||||
|
tlsConfigStatus: m.status,
|
||||||
|
tlsConfiguration: m.conf.cloneForEncoding(),
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
marshalTLS(w, r, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
|
||||||
|
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||||
|
req, err := unmarshalTLS(r)
|
||||||
|
if err != nil {
|
||||||
|
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.PrivateKeySaved {
|
||||||
|
req.PrivateKey = m.confForEncoding().PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Enabled {
|
||||||
|
err = validatePorts(
|
||||||
|
tcpPort(config.BindPort),
|
||||||
|
tcpPort(config.BetaBindPort),
|
||||||
|
tcpPort(req.PortHTTPS),
|
||||||
|
tcpPort(req.PortDNSOverTLS),
|
||||||
|
tcpPort(req.PortDNSCrypt),
|
||||||
|
udpPort(config.DNS.Port),
|
||||||
|
udpPort(req.PortDNSOverQUIC),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !webCheckPortAvailable(req.PortHTTPS) {
|
||||||
|
aghhttp.Error(
|
||||||
|
r,
|
||||||
|
w,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"port %d is not available, cannot enable HTTPS on it",
|
||||||
|
req.PortHTTPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &tlsConfigResp{
|
||||||
|
tlsConfigStatus: &tlsConfigStatus{},
|
||||||
|
tlsConfiguration: &req.tlsConfiguration,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip the error check, since we are only interested in the value of
|
||||||
|
// resl.tlsConfigStatus.WarningValidation.
|
||||||
|
_ = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus)
|
||||||
|
|
||||||
|
marshalTLS(w, r, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
|
||||||
|
// DNS protocols.
|
||||||
|
func validatePorts(
|
||||||
|
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
|
||||||
|
dnsPort, doqPort udpPort,
|
||||||
|
) (err error) {
|
||||||
|
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||||
|
addPorts(
|
||||||
|
tcpPorts,
|
||||||
|
tcpPort(bindPort),
|
||||||
|
tcpPort(betaBindPort),
|
||||||
|
tcpPort(dohPort),
|
||||||
|
tcpPort(dotPort),
|
||||||
|
tcpPort(dnscryptTCPPort),
|
||||||
|
)
|
||||||
|
|
||||||
|
err = tcpPorts.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("validating tcp ports: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||||
|
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
|
||||||
|
|
||||||
|
err = udpPorts.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("validating udp ports: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleTLSConfigure is the handler for the POST /control/tls/configure HTTP
|
||||||
|
// API.
|
||||||
|
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||||
|
req, err := unmarshalTLS(r)
|
||||||
|
if err != nil {
|
||||||
|
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.PrivateKeySaved {
|
||||||
|
req.PrivateKey = m.confForEncoding().PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Enabled {
|
||||||
|
err = validatePorts(
|
||||||
|
tcpPort(config.BindPort),
|
||||||
|
tcpPort(config.BetaBindPort),
|
||||||
|
tcpPort(req.PortHTTPS),
|
||||||
|
tcpPort(req.PortDNSOverTLS),
|
||||||
|
tcpPort(req.PortDNSCrypt),
|
||||||
|
udpPort(config.DNS.Port),
|
||||||
|
udpPort(req.PortDNSOverQUIC),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(e.burkov): Investigate and perhaps check other ports.
|
||||||
|
if !webCheckPortAvailable(req.PortHTTPS) {
|
||||||
|
aghhttp.Error(
|
||||||
|
r,
|
||||||
|
w,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"port %d is not available, cannot enable https on it",
|
||||||
|
req.PortHTTPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &tlsConfigResp{
|
||||||
|
tlsConfigStatus: &tlsConfigStatus{},
|
||||||
|
tlsConfiguration: &req.tlsConfiguration,
|
||||||
|
}
|
||||||
|
err = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus)
|
||||||
|
if err != nil {
|
||||||
|
marshalTLS(w, r, resp)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
restartRequired := m.setConf(resp)
|
||||||
|
onConfigModified()
|
||||||
|
|
||||||
|
err = reconfigureDNSServer()
|
||||||
|
if err != nil {
|
||||||
|
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.tlsConfiguration = m.confForEncoding()
|
||||||
|
marshalTLS(w, r, resp)
|
||||||
|
if f, ok := w.(http.Flusher); ok {
|
||||||
|
f.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// The background context is used because the TLSConfigChanged wraps context
|
||||||
|
// with timeout on its own and shuts down the server, which handles current
|
||||||
|
// request. It is also should be done in a separate goroutine due to the
|
||||||
|
// same reason.
|
||||||
|
if restartRequired {
|
||||||
|
go func() {
|
||||||
|
Context.web.TLSConfigChanged(context.Background(), resp.tlsConfiguration)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setConf sets the necessary values from the new configuration.
|
||||||
|
func (m *tlsManager) setConf(newConf *tlsConfigResp) (restartRequired bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Reset the DNSCrypt data before comparing, since we currently do not
|
||||||
|
// accept these from the frontend.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
|
||||||
|
newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
|
||||||
|
newConf.PortDNSCrypt = m.conf.PortDNSCrypt
|
||||||
|
if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
|
||||||
|
log.Info("tls: config has changed, restarting https server")
|
||||||
|
restartRequired = true
|
||||||
|
} else {
|
||||||
|
log.Info("tls: config has not changed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do not just write "m.conf = *newConf.tlsConfiguration", because all other
|
||||||
|
// members of m.conf must be preserved.
|
||||||
|
m.conf.Enabled = newConf.Enabled
|
||||||
|
m.conf.ServerName = newConf.ServerName
|
||||||
|
m.conf.ForceHTTPS = newConf.ForceHTTPS
|
||||||
|
m.conf.PortHTTPS = newConf.PortHTTPS
|
||||||
|
m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
|
||||||
|
m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
|
||||||
|
|
||||||
|
m.conf.CertificateChain = newConf.CertificateChain
|
||||||
|
m.conf.CertificatePath = newConf.CertificatePath
|
||||||
|
m.conf.CertificateChainData = newConf.CertificateChainData
|
||||||
|
m.conf.PrivateKey = newConf.PrivateKey
|
||||||
|
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
|
||||||
|
m.conf.PrivateKeyData = newConf.PrivateKeyData
|
||||||
|
|
||||||
|
m.setCertFileTime()
|
||||||
|
|
||||||
|
m.status = newConf.tlsConfigStatus
|
||||||
|
|
||||||
|
return restartRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
// marshalTLS handles Base64-encoded certificates transparently.
|
||||||
|
func marshalTLS(w http.ResponseWriter, r *http.Request, conf *tlsConfigResp) {
|
||||||
|
if conf.CertificateChain != "" {
|
||||||
|
encoded := base64.StdEncoding.EncodeToString([]byte(conf.CertificateChain))
|
||||||
|
conf.CertificateChain = encoded
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.PrivateKey != "" {
|
||||||
|
conf.PrivateKeySaved = true
|
||||||
|
conf.PrivateKey = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = aghhttp.WriteJSONResponse(w, r, conf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unmarshalTLS handles Base64-encoded certificates transparently.
|
||||||
|
func unmarshalTLS(r *http.Request) (req *tlsConfigReq, err error) {
|
||||||
|
req = &tlsConfigReq{}
|
||||||
|
err = json.NewDecoder(r.Body).Decode(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing tls config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.CertificateChain != "" {
|
||||||
|
var cert []byte
|
||||||
|
cert, err = base64.StdEncoding.DecodeString(req.CertificateChain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to base64-decode certificate chain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.CertificateChain = string(cert)
|
||||||
|
if req.CertificatePath != "" {
|
||||||
|
return nil, fmt.Errorf("certificate data and file can't be set together")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.PrivateKey != "" {
|
||||||
|
var key []byte
|
||||||
|
key, err = base64.StdEncoding.DecodeString(req.PrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to base64-decode private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.PrivateKey = string(key)
|
||||||
|
if req.PrivateKeyPath != "" {
|
||||||
|
return nil, fmt.Errorf("private key data and file can't be set together")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerWebHandlers registers HTTP handlers for TLS configuration.
|
||||||
|
func (m *tlsManager) registerWebHandlers() {
|
||||||
|
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
|
||||||
|
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
|
||||||
|
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
|
||||||
|
}
|
||||||
@@ -143,7 +143,7 @@ func webCheckPortAvailable(port int) (ok bool) {
|
|||||||
|
|
||||||
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
|
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
|
||||||
// if necessary.
|
// if necessary.
|
||||||
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
|
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf *tlsConfiguration) {
|
||||||
log.Debug("web: applying new tls configuration")
|
log.Debug("web: applying new tls configuration")
|
||||||
web.conf.PortHTTPS = tlsConf.PortHTTPS
|
web.conf.PortHTTPS = tlsConf.PortHTTPS
|
||||||
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
|
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user