all: sync with master
This commit is contained in:
58
internal/aghuser/aghuser.go
Normal file
58
internal/aghuser/aghuser.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package aghuser
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Login is the type for web user logins.
|
||||
type Login string
|
||||
|
||||
// NewLogin returns a web user login.
|
||||
//
|
||||
// TODO(s.chzhen): Add more constraints as needed.
|
||||
func NewLogin(s string) (l Login, err error) {
|
||||
if s == "" {
|
||||
return "", errors.ErrEmptyValue
|
||||
}
|
||||
|
||||
return Login(s), nil
|
||||
}
|
||||
|
||||
// Password is an interface that defines methods for handling web user
|
||||
// passwords.
|
||||
type Password interface {
|
||||
// Authenticate returns true if the provided password is allowed.
|
||||
Authenticate(ctx context.Context, password string) (ok bool)
|
||||
|
||||
// Hash returns a hashed representation of the web user password.
|
||||
Hash() (b []byte)
|
||||
}
|
||||
|
||||
// DefaultPassword is the default bcrypt implementation of the [Password]
|
||||
// interface.
|
||||
type DefaultPassword struct {
|
||||
hash []byte
|
||||
}
|
||||
|
||||
// NewDefaultPassword returns the new properly initialized *DefaultPassword.
|
||||
func NewDefaultPassword(hash string) (p *DefaultPassword) {
|
||||
return &DefaultPassword{
|
||||
hash: []byte(hash),
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ Password = (*DefaultPassword)(nil)
|
||||
|
||||
// Authenticate implements the [Password] interface for *DefaultPassword.
|
||||
func (p *DefaultPassword) Authenticate(ctx context.Context, passwd string) (ok bool) {
|
||||
return bcrypt.CompareHashAndPassword([]byte(p.hash), []byte(passwd)) == nil
|
||||
}
|
||||
|
||||
// Hash implements the [Password] interface for *DefaultPassword.
|
||||
func (p *DefaultPassword) Hash() (b []byte) {
|
||||
return p.hash
|
||||
}
|
||||
6
internal/aghuser/aghuser_test.go
Normal file
6
internal/aghuser/aghuser_test.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package aghuser_test
|
||||
|
||||
import "time"
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
149
internal/aghuser/db.go
Normal file
149
internal/aghuser/db.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package aghuser
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// DB is an interface that defines methods for interacting with user
|
||||
// information. All methods must be safe for concurrent use.
|
||||
//
|
||||
// TODO(s.chzhen): Use this.
|
||||
//
|
||||
// TODO(s.chzhen): Consider updating methods to return a clone.
|
||||
type DB interface {
|
||||
// All retrieves all users from the database, sorted by login.
|
||||
//
|
||||
// TODO(s.chzhen): Consider function signature change to reflect the
|
||||
// in-memory implementation, as it currently always returns nil for error.
|
||||
All(ctx context.Context) (users []*User, err error)
|
||||
|
||||
// ByLogin retrieves a user by their login. u must not be modified.
|
||||
//
|
||||
// TODO(s.chzhen): Remove this once user sessions support [UserID].
|
||||
ByLogin(ctx context.Context, login Login) (u *User, err error)
|
||||
|
||||
// ByUUID retrieves a user by their unique identifier. u must not be
|
||||
// modified.
|
||||
//
|
||||
// TODO(s.chzhen): Use this.
|
||||
ByUUID(ctx context.Context, id UserID) (u *User, err error)
|
||||
|
||||
// Create adds a new user to the database. If the credentials already
|
||||
// exist, it returns the [errors.ErrDuplicated] error. It also can return
|
||||
// an error from the cryptographic randomness reader. u must not be
|
||||
// modified.
|
||||
Create(ctx context.Context, u *User) (err error)
|
||||
}
|
||||
|
||||
// DefaultDB is the default in-memory implementation of the [DB] interface.
|
||||
type DefaultDB struct {
|
||||
// mu protects all properties below.
|
||||
mu *sync.Mutex
|
||||
|
||||
// loginToUserID maps a web user login to their UserID. The values must not
|
||||
// be empty.
|
||||
//
|
||||
// TODO(s.chzhen): Remove this once user sessions support [UserID].
|
||||
loginToUserID map[Login]UserID
|
||||
|
||||
// userIDToUser maps a UserID to a web user. The values must not be nil.
|
||||
// It must be synchronized with loginToUserID, meaning all UserIDs stored in
|
||||
// loginToUserID must also be stored in this map.
|
||||
userIDToUser map[UserID]*User
|
||||
}
|
||||
|
||||
// NewDefaultDB returns the new properly initialized *DefaultDB.
|
||||
func NewDefaultDB() (db *DefaultDB) {
|
||||
return &DefaultDB{
|
||||
mu: &sync.Mutex{},
|
||||
loginToUserID: map[Login]UserID{},
|
||||
userIDToUser: map[UserID]*User{},
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ DB = (*DefaultDB)(nil)
|
||||
|
||||
// All implements the [DB] interface for *DefaultDB.
|
||||
func (db *DefaultDB) All(ctx context.Context) (users []*User, err error) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
if len(db.userIDToUser) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
users = slices.SortedStableFunc(
|
||||
maps.Values(db.userIDToUser),
|
||||
func(a, b *User) (res int) {
|
||||
// TODO(s.chzhen): Consider adding a custom comparer.
|
||||
return cmp.Compare(a.Login, b.Login)
|
||||
},
|
||||
)
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// ByLogin implements the [DB] interface for *DefaultDB.
|
||||
func (db *DefaultDB) ByLogin(ctx context.Context, login Login) (u *User, err error) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
id, ok := db.loginToUserID[login]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
u, ok = db.userIDToUser[id]
|
||||
if !ok {
|
||||
// Should not happen.
|
||||
panic(fmt.Errorf("no web user present with login %q", login))
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// ByUUID implements the [DB] interface for *DefaultDB.
|
||||
func (db *DefaultDB) ByUUID(ctx context.Context, id UserID) (u *User, err error) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
u, ok := db.userIDToUser[id]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Create implements the [DB] interface for *DefaultDB.
|
||||
func (db *DefaultDB) Create(ctx context.Context, u *User) (err error) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
if u.ID == (UserID{}) {
|
||||
return fmt.Errorf("userid: %w", errors.ErrEmptyValue)
|
||||
}
|
||||
|
||||
_, ok := db.userIDToUser[u.ID]
|
||||
if ok {
|
||||
return fmt.Errorf("userid: %w", errors.ErrDuplicated)
|
||||
}
|
||||
|
||||
_, ok = db.loginToUserID[u.Login]
|
||||
if ok {
|
||||
return fmt.Errorf("login: %w", errors.ErrDuplicated)
|
||||
}
|
||||
|
||||
db.userIDToUser[u.ID] = u
|
||||
db.loginToUserID[u.Login] = u.ID
|
||||
|
||||
return nil
|
||||
}
|
||||
83
internal/aghuser/db_test.go
Normal file
83
internal/aghuser/db_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package aghuser_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghuser"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func TestDB(t *testing.T) {
|
||||
db := aghuser.NewDefaultDB()
|
||||
|
||||
const (
|
||||
userWithIDPassRaw = "user_with_id_password"
|
||||
userSecondPassRaw = "user_second_password"
|
||||
)
|
||||
|
||||
userWithIDPassHash, err := bcrypt.GenerateFromPassword(
|
||||
[]byte(userWithIDPassRaw),
|
||||
bcrypt.DefaultCost,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
userSecondPassHash, err := bcrypt.GenerateFromPassword(
|
||||
[]byte(userSecondPassRaw),
|
||||
bcrypt.DefaultCost,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
userWithIDPass := aghuser.NewDefaultPassword(string(userWithIDPassHash))
|
||||
userSecondPass := aghuser.NewDefaultPassword(string(userSecondPassHash))
|
||||
|
||||
var (
|
||||
userWithID = &aghuser.User{
|
||||
ID: aghuser.MustNewUserID(),
|
||||
Login: "user_with_id",
|
||||
Password: userWithIDPass,
|
||||
}
|
||||
userSecond = &aghuser.User{
|
||||
ID: aghuser.MustNewUserID(),
|
||||
Login: "user_second",
|
||||
Password: userSecondPass,
|
||||
}
|
||||
userDuplicateLogin = &aghuser.User{
|
||||
ID: aghuser.MustNewUserID(),
|
||||
Login: userWithID.Login,
|
||||
Password: userWithIDPass,
|
||||
}
|
||||
)
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
err = db.Create(ctx, userWithID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Create(ctx, userSecond)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Create(ctx, userDuplicateLogin)
|
||||
assert.ErrorIs(t, err, errors.ErrDuplicated)
|
||||
|
||||
got, err := db.ByUUID(ctx, userWithID.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userWithID, got)
|
||||
assert.True(t, got.Password.Authenticate(ctx, userWithIDPassRaw))
|
||||
|
||||
got, err = db.ByLogin(ctx, userSecond.Login)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, userSecond, got)
|
||||
assert.True(t, got.Password.Authenticate(ctx, userSecondPassRaw))
|
||||
|
||||
users, err := db.All(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, users, 2)
|
||||
assert.Equal(t, []*aghuser.User{userSecond, userWithID}, users)
|
||||
}
|
||||
44
internal/aghuser/user.go
Normal file
44
internal/aghuser/user.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// Package aghuser contains types and logic for dealing with AdGuard Home's web
|
||||
// users.
|
||||
package aghuser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UserID is the type for the unique IDs of web users.
|
||||
type UserID uuid.UUID
|
||||
|
||||
// NewUserID returns a new web user unique identifier. Any error returned is an
|
||||
// error from the cryptographic randomness reader.
|
||||
func NewUserID() (uid UserID, err error) {
|
||||
uuidv7, err := uuid.NewV7()
|
||||
|
||||
return UserID(uuidv7), err
|
||||
}
|
||||
|
||||
// MustNewUserID is a wrapper around [NewUserID] that panics if there is an
|
||||
// error. It is currently only used in tests.
|
||||
func MustNewUserID() (uid UserID) {
|
||||
uid, err := NewUserID()
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("unexpected uuidv7 error: %w", err))
|
||||
}
|
||||
|
||||
return uid
|
||||
}
|
||||
|
||||
// User represents a web user.
|
||||
type User struct {
|
||||
// ID is the unique identifier for the web user. It must not be empty.
|
||||
ID UserID
|
||||
|
||||
// Login is the login name of the web user. It must not be empty.
|
||||
Login Login
|
||||
|
||||
// Password stores the password information for the web user. It must not
|
||||
// be nil.
|
||||
Password Password
|
||||
}
|
||||
@@ -496,6 +496,11 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
||||
return p.ShallowClone(), ok
|
||||
}
|
||||
|
||||
foundMAC := s.dhcp.MACByIP(ip)
|
||||
if foundMAC != nil {
|
||||
return s.FindByMAC(foundMAC)
|
||||
}
|
||||
|
||||
p = s.index.findByIPWithoutZone(ip)
|
||||
if p != nil {
|
||||
return p.ShallowClone(), true
|
||||
@@ -682,6 +687,13 @@ func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filter
|
||||
c, ok = s.index.findByIP(addr)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
foundMAC := s.dhcp.MACByIP(addr)
|
||||
if foundMAC != nil {
|
||||
c, ok = s.FindByMAC(foundMAC)
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
s.logger.Debug("no client filtering settings found", "clientid", id, "addr", addr)
|
||||
|
||||
|
||||
@@ -78,11 +78,11 @@ func TestHostnameToHashes(t *testing.T) {
|
||||
wantLen: 2,
|
||||
}, {
|
||||
name: "private_domain_v2",
|
||||
host: "foo.blogspot.co.uk",
|
||||
wantLen: 4,
|
||||
host: "foo.dyndns.org",
|
||||
wantLen: 3,
|
||||
}, {
|
||||
name: "sub_private_domain_v2",
|
||||
host: "bar.foo.blogspot.co.uk",
|
||||
host: "bar.foo.dyndns.org",
|
||||
wantLen: 4,
|
||||
}}
|
||||
|
||||
|
||||
@@ -568,7 +568,7 @@ func parseConfig() (err error) {
|
||||
}
|
||||
|
||||
// Do not wrap the error because it's informative enough as is.
|
||||
return setContextTLSCipherIDs()
|
||||
return validateTLSCipherIDs(config.TLS.OverrideTLSCiphers)
|
||||
}
|
||||
|
||||
// validateConfig returns error if the configuration is invalid.
|
||||
@@ -721,21 +721,15 @@ func (c *configuration) write(tlsMgr *tlsManager) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setContextTLSCipherIDs sets the TLS cipher suite IDs to use.
|
||||
func setContextTLSCipherIDs() (err error) {
|
||||
if len(config.TLS.OverrideTLSCiphers) == 0 {
|
||||
log.Info("tls: using default ciphers")
|
||||
|
||||
globalContext.tlsCipherIDs = aghtls.SaferCipherSuites()
|
||||
|
||||
// validateTLSCipherIDs validates the custom TLS cipher suite IDs.
|
||||
func validateTLSCipherIDs(cipherIDs []string) (err error) {
|
||||
if len(cipherIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info("tls: overriding ciphers: %s", config.TLS.OverrideTLSCiphers)
|
||||
|
||||
globalContext.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
|
||||
_, err = aghtls.ParseCiphers(cipherIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing override ciphers: %w", err)
|
||||
return fmt.Errorf("override_tls_ciphers: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -38,6 +38,8 @@ const (
|
||||
)
|
||||
|
||||
// Called by other modules when configuration is changed
|
||||
//
|
||||
// TODO(s.chzhen): Remove this after refactoring.
|
||||
func onConfigModified() {
|
||||
err := config.write(globalContext.tls)
|
||||
if err != nil {
|
||||
@@ -120,14 +122,15 @@ func initDNS(
|
||||
anonymizer,
|
||||
httpRegister,
|
||||
tlsConf,
|
||||
tlsMgr,
|
||||
baseLogger,
|
||||
)
|
||||
}
|
||||
|
||||
// initDNSServer initializes the [context.dnsServer]. To only use the internal
|
||||
// proxy, none of the arguments are required, but tlsConf and l still must not
|
||||
// be nil, in other cases all the arguments also must not be nil. It also must
|
||||
// not be called unless [config] and [globalContext] are initialized.
|
||||
// proxy, none of the arguments are required, but tlsConf, tlsMgr and l still
|
||||
// must not be nil, in other cases all the arguments also must not be nil. It
|
||||
// also must not be called unless [config] and [globalContext] are initialized.
|
||||
//
|
||||
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
|
||||
func initDNSServer(
|
||||
@@ -138,6 +141,7 @@ func initDNSServer(
|
||||
anonymizer *aghnet.IPMut,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
tlsConf *tlsConfigSettings,
|
||||
tlsMgr *tlsManager,
|
||||
l *slog.Logger,
|
||||
) (err error) {
|
||||
globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
|
||||
@@ -166,6 +170,7 @@ func initDNSServer(
|
||||
&config.DNS,
|
||||
config.Clients.Sources,
|
||||
tlsConf,
|
||||
tlsMgr,
|
||||
httpReg,
|
||||
globalContext.clients.storage,
|
||||
)
|
||||
@@ -236,11 +241,12 @@ func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) {
|
||||
}
|
||||
|
||||
// newServerConfig converts values from the configuration file into the internal
|
||||
// DNS server configuration. All arguments must not be nil.
|
||||
// DNS server configuration. All arguments must not be nil, except for httpReg.
|
||||
func newServerConfig(
|
||||
dnsConf *dnsConfig,
|
||||
clientSrcConf *clientSourcesConfig,
|
||||
tlsConf *tlsConfigSettings,
|
||||
tlsMgr *tlsManager,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
clientsContainer dnsforward.ClientsContainer,
|
||||
) (newConf *dnsforward.ServerConfig, err error) {
|
||||
@@ -256,7 +262,7 @@ func newServerConfig(
|
||||
TLSConfig: newDNSTLSConfig(tlsConf, hosts),
|
||||
TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH,
|
||||
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
|
||||
TLSv12Roots: globalContext.tlsRoots,
|
||||
TLSv12Roots: tlsMgr.rootCerts,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpReg,
|
||||
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
|
||||
|
||||
@@ -3,7 +3,6 @@ package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
@@ -22,7 +21,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
@@ -81,10 +79,6 @@ type homeContext struct {
|
||||
workDir string // Location of our directory, used to protect against CWD being somewhere else
|
||||
pidFileName string // PID file name. Empty if no PID file was created.
|
||||
controlLock sync.Mutex
|
||||
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
|
||||
|
||||
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
|
||||
tlsCipherIDs []uint16
|
||||
|
||||
// firstRun, if true, tells AdGuard Home to only start the web interface
|
||||
// service, and only serve the first-run APIs.
|
||||
@@ -142,7 +136,6 @@ func Main(clientBuildFS fs.FS) {
|
||||
func setupContext(opts options) (err error) {
|
||||
globalContext.firstRun = detectFirstRun()
|
||||
|
||||
globalContext.tlsRoots = aghtls.SystemRootCAs()
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
if !opts.noEtcHosts {
|
||||
@@ -274,18 +267,13 @@ func setupOpts(opts options) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// initContextClients initializes Context clients and related fields.
|
||||
// initContextClients initializes Context clients and related fields. All
|
||||
// arguments must not be nil.
|
||||
func initContextClients(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
sigHdlr *signalHandler,
|
||||
) (err error) {
|
||||
err = setupDNSFilteringConf(ctx, logger, config.Filtering)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
//lint:ignore SA1019 Migration is not over.
|
||||
config.DHCP.WorkDir = globalContext.workDir
|
||||
config.DHCP.DataDir = globalContext.getDataDir()
|
||||
@@ -358,11 +346,13 @@ func setupBindOpts(opts options) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDNSFilteringConf sets up DNS filtering configuration settings.
|
||||
// setupDNSFilteringConf sets up DNS filtering configuration settings. All
|
||||
// arguments must not be nil.
|
||||
func setupDNSFilteringConf(
|
||||
ctx context.Context,
|
||||
baseLogger *slog.Logger,
|
||||
conf *filtering.Config,
|
||||
tlsMgr *tlsManager,
|
||||
) (err error) {
|
||||
const (
|
||||
dnsTimeout = 3 * time.Second
|
||||
@@ -388,7 +378,7 @@ func setupDNSFilteringConf(
|
||||
conf.Filters = slices.Clone(config.Filters)
|
||||
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
|
||||
conf.UserRules = slices.Clone(config.UserRules)
|
||||
conf.HTTPClient = httpClient()
|
||||
conf.HTTPClient = httpClient(tlsMgr)
|
||||
|
||||
cacheTime := time.Duration(conf.CacheTime) * time.Minute
|
||||
|
||||
@@ -630,6 +620,23 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
|
||||
err = initContextClients(ctx, slogLogger, sigHdlr)
|
||||
fatalOnError(err)
|
||||
|
||||
tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager")
|
||||
tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: tlsMgrLogger,
|
||||
configModified: onConfigModified,
|
||||
tlsSettings: config.TLS,
|
||||
servePlainDNS: config.DNS.ServePlainDNS,
|
||||
})
|
||||
if err != nil {
|
||||
tlsMgrLogger.ErrorContext(ctx, "initializing", slogutil.KeyError, err)
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
globalContext.tls = tlsMgr
|
||||
|
||||
err = setupDNSFilteringConf(ctx, slogLogger, config.Filtering, tlsMgr)
|
||||
fatalOnError(err)
|
||||
|
||||
err = setupOpts(opts)
|
||||
fatalOnError(err)
|
||||
|
||||
@@ -642,7 +649,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
|
||||
|
||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||
// effect.
|
||||
cmdlineUpdate(ctx, slogLogger, opts, upd)
|
||||
cmdlineUpdate(ctx, slogLogger, opts, upd, tlsMgr)
|
||||
|
||||
if !globalContext.firstRun {
|
||||
// Save the updated config.
|
||||
@@ -664,19 +671,14 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
|
||||
globalContext.auth, err = initUsers()
|
||||
fatalOnError(err)
|
||||
|
||||
tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager")
|
||||
tlsMgr, err := newTLSManager(ctx, tlsMgrLogger, config.TLS, config.DNS.ServePlainDNS)
|
||||
if err != nil {
|
||||
log.Error("initializing tls: %s", err)
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
globalContext.tls = tlsMgr
|
||||
sigHdlr.addTLSManager(tlsMgr)
|
||||
|
||||
globalContext.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
|
||||
web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
|
||||
fatalOnError(err)
|
||||
|
||||
globalContext.web = web
|
||||
|
||||
tlsMgr.setWebAPI(web)
|
||||
sigHdlr.addTLSManager(tlsMgr)
|
||||
|
||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
|
||||
fatalOnError(err)
|
||||
|
||||
@@ -706,7 +708,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
|
||||
checkPermissions(ctx, slogLogger, globalContext.workDir, confPath, dataDir, statsDir, querylogDir)
|
||||
}
|
||||
|
||||
globalContext.web.start(ctx)
|
||||
web.start(ctx)
|
||||
|
||||
// Wait for other goroutines to complete their job.
|
||||
<-done
|
||||
@@ -1058,8 +1060,15 @@ type jsonError struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// cmdlineUpdate updates current application and exits. l must not be nil.
|
||||
func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updater.Updater) {
|
||||
// cmdlineUpdate updates current application and exits. l and tlsMgr must not
|
||||
// be nil.
|
||||
func cmdlineUpdate(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
opts options,
|
||||
upd *updater.Updater,
|
||||
tlsMgr *tlsManager,
|
||||
) {
|
||||
if !opts.performUpdate {
|
||||
return
|
||||
}
|
||||
@@ -1069,7 +1078,7 @@ func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updat
|
||||
//
|
||||
// TODO(e.burkov): We could probably initialize the internal resolver
|
||||
// separately.
|
||||
err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}, l)
|
||||
err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}, tlsMgr, l)
|
||||
fatalOnError(err)
|
||||
|
||||
l.InfoContext(ctx, "performing update via cli")
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
|
||||
// httpClient returns a new HTTP client that uses the AdGuard Home's own DNS
|
||||
// server for resolving hostnames. The resulting client should not be used
|
||||
// until [Context.dnsServer] is initialized.
|
||||
// until [Context.dnsServer] is initialized. tlsMgr must not be nil.
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): This is rather messy. Refactor.
|
||||
func httpClient() (c *http.Client) {
|
||||
func httpClient(tlsMgr *tlsManager) (c *http.Client) {
|
||||
// Do not use Context.dnsServer.DialContext directly in the struct literal
|
||||
// below, since Context.dnsServer may be nil when this function is called.
|
||||
dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
@@ -27,8 +27,8 @@ func httpClient() (c *http.Client) {
|
||||
DialContext: dialContext,
|
||||
Proxy: httpProxy,
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: globalContext.tlsRoots,
|
||||
CipherSuites: globalContext.tlsCipherIDs,
|
||||
RootCAs: tlsMgr.rootCerts,
|
||||
CipherSuites: tlsMgr.customCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -41,6 +43,22 @@ type tlsManager struct {
|
||||
// certLastMod is the last modification time of the certificate file.
|
||||
certLastMod time.Time
|
||||
|
||||
// rootCerts is a pool of root CAs for TLSv1.2.
|
||||
rootCerts *x509.CertPool
|
||||
|
||||
// web is the web UI and API server. It must not be nil.
|
||||
//
|
||||
// TODO(s.chzhen): Temporary cyclic dependency due to ongoing refactoring.
|
||||
// Resolve it.
|
||||
web *webAPI
|
||||
|
||||
// configModified is called when the TLS configuration is changed via an
|
||||
// HTTP request.
|
||||
configModified func()
|
||||
|
||||
// customCipherIDs are the ID of the cipher suites that AdGuard Home must use.
|
||||
customCipherIDs []uint16
|
||||
|
||||
confLock sync.Mutex
|
||||
conf tlsConfigSettings
|
||||
|
||||
@@ -48,21 +66,50 @@ type tlsManager struct {
|
||||
servePlainDNS bool
|
||||
}
|
||||
|
||||
// tlsManagerConfig contains the settings for initializing the TLS manager.
|
||||
type tlsManagerConfig struct {
|
||||
// logger is used for logging the operation of the TLS Manager. It must not
|
||||
// be nil.
|
||||
logger *slog.Logger
|
||||
|
||||
// configModified is called when the TLS configuration is changed via an
|
||||
// HTTP request. It must not be nil.
|
||||
configModified func()
|
||||
|
||||
// tlsSettings contains the TLS configuration settings.
|
||||
tlsSettings tlsConfigSettings
|
||||
|
||||
// servePlainDNS defines if plain DNS is allowed for incoming requests.
|
||||
servePlainDNS bool
|
||||
}
|
||||
|
||||
// newTLSManager initializes the manager of TLS configuration. m is always
|
||||
// non-nil while any returned error indicates that the TLS configuration isn't
|
||||
// valid. Thus TLS may be initialized later, e.g. via the web UI. logger must
|
||||
// not be nil.
|
||||
func newTLSManager(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
conf tlsConfigSettings,
|
||||
servePlainDNS bool,
|
||||
) (m *tlsManager, err error) {
|
||||
// valid. Thus TLS may be initialized later, e.g. via the web UI. conf must
|
||||
// not be nil. Note that [tlsManager.web] must be initialized later on by using
|
||||
// [tlsManager.setWebAPI].
|
||||
func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager, err error) {
|
||||
m = &tlsManager{
|
||||
logger: logger,
|
||||
status: &tlsConfigStatus{},
|
||||
conf: conf,
|
||||
servePlainDNS: servePlainDNS,
|
||||
logger: conf.logger,
|
||||
configModified: conf.configModified,
|
||||
status: &tlsConfigStatus{},
|
||||
conf: conf.tlsSettings,
|
||||
servePlainDNS: conf.servePlainDNS,
|
||||
}
|
||||
|
||||
m.rootCerts = aghtls.SystemRootCAs()
|
||||
|
||||
if len(conf.tlsSettings.OverrideTLSCiphers) > 0 {
|
||||
m.customCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
|
||||
if err != nil {
|
||||
// Should not happen because upstreams are already validated. See
|
||||
// [validateTLSCipherIDs].
|
||||
panic(err)
|
||||
}
|
||||
|
||||
m.logger.InfoContext(ctx, "overriding ciphers", "ciphers", config.TLS.OverrideTLSCiphers)
|
||||
} else {
|
||||
m.logger.InfoContext(ctx, "using default ciphers")
|
||||
}
|
||||
|
||||
if m.conf.Enabled {
|
||||
@@ -79,6 +126,15 @@ func newTLSManager(
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// setWebAPI stores the provided web API. It must be called before
|
||||
// [tlsManager.start], [tlsManager.reload], [tlsManager.handleTLSConfigure], or
|
||||
// [tlsManager.validateTLSSettings].
|
||||
//
|
||||
// TODO(s.chzhen): Remove it once cyclic dependency is resolved.
|
||||
func (m *tlsManager) setWebAPI(webAPI *webAPI) {
|
||||
m.web = webAPI
|
||||
}
|
||||
|
||||
// load reloads the TLS configuration from files or data from the config file.
|
||||
func (m *tlsManager) load(ctx context.Context) (err error) {
|
||||
err = m.loadTLSConf(ctx, &m.conf, m.status)
|
||||
@@ -126,7 +182,7 @@ func (m *tlsManager) start(_ context.Context) {
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||
m.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||
}
|
||||
|
||||
// reload updates the configuration and restarts the TLS manager.
|
||||
@@ -178,7 +234,7 @@ func (m *tlsManager) reload(ctx context.Context) {
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||
m.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||
}
|
||||
|
||||
// reconfigureDNSServer updates the DNS server configuration using the stored
|
||||
@@ -191,6 +247,7 @@ func (m *tlsManager) reconfigureDNSServer() (err error) {
|
||||
&config.DNS,
|
||||
config.Clients.Sources,
|
||||
tlsConf,
|
||||
m,
|
||||
httpRegister,
|
||||
globalContext.clients.storage,
|
||||
)
|
||||
@@ -368,6 +425,8 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
|
||||
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
setts, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
@@ -379,7 +438,9 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
setts.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
|
||||
if err = validateTLSSettings(setts); err != nil {
|
||||
if err = m.validateTLSSettings(setts); err != nil {
|
||||
m.logger.InfoContext(ctx, "validating tls settings", slogutil.KeyError, err)
|
||||
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
@@ -388,7 +449,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip the error check, since we are only interested in the value of
|
||||
// status.WarningValidation.
|
||||
status := &tlsConfigStatus{}
|
||||
_ = m.loadTLSConf(r.Context(), &setts.tlsConfigSettings, status)
|
||||
_ = m.loadTLSConf(ctx, &setts.tlsConfigSettings, status)
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: setts,
|
||||
tlsConfigStatus: status,
|
||||
@@ -458,7 +519,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
req.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
|
||||
if err = validateTLSSettings(req); err != nil {
|
||||
if err = m.validateTLSSettings(req); err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
@@ -489,7 +550,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
}()
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
m.configModified()
|
||||
|
||||
err = m.reconfigureDNSServer()
|
||||
if err != nil {
|
||||
@@ -516,36 +577,54 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
// same reason.
|
||||
if restartHTTPS {
|
||||
go func() {
|
||||
globalContext.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
|
||||
m.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// validateTLSSettings returns error if the setts are not valid.
|
||||
func validateTLSSettings(setts tlsConfigSettingsExt) (err error) {
|
||||
if setts.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.HTTPConfig.Address.Port()),
|
||||
tcpPort(setts.PortHTTPS),
|
||||
tcpPort(setts.PortDNSOverTLS),
|
||||
tcpPort(setts.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(setts.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
func (m *tlsManager) validateTLSSettings(setts tlsConfigSettingsExt) (err error) {
|
||||
if !setts.Enabled {
|
||||
if setts.ServePlainDNS == aghalg.NBFalse {
|
||||
// TODO(a.garipov): Support full disabling of all DNS.
|
||||
return errors.Error("plain DNS is required in case encryption protocols are disabled")
|
||||
}
|
||||
} else if setts.ServePlainDNS == aghalg.NBFalse {
|
||||
// TODO(a.garipov): Support full disabling of all DNS.
|
||||
return errors.Error("plain DNS is required in case encryption protocols are disabled")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if !webCheckPortAvailable(setts.PortHTTPS) {
|
||||
return fmt.Errorf("port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS)
|
||||
var (
|
||||
tlsConf tlsConfigSettings
|
||||
webAPIAddr netip.Addr
|
||||
webAPIPort uint16
|
||||
plainDNSPort uint16
|
||||
)
|
||||
|
||||
func() {
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
|
||||
tlsConf = config.TLS
|
||||
webAPIAddr = config.HTTPConfig.Address.Addr()
|
||||
webAPIPort = config.HTTPConfig.Address.Port()
|
||||
plainDNSPort = config.DNS.Port
|
||||
}()
|
||||
|
||||
err = validatePorts(
|
||||
tcpPort(webAPIPort),
|
||||
tcpPort(setts.PortHTTPS),
|
||||
tcpPort(setts.PortDNSOverTLS),
|
||||
tcpPort(setts.PortDNSCrypt),
|
||||
udpPort(plainDNSPort),
|
||||
udpPort(setts.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return m.checkPortAvailability(tlsConf, setts.tlsConfigSettings, webAPIAddr)
|
||||
}
|
||||
|
||||
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
|
||||
@@ -557,10 +636,11 @@ func validatePorts(
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(
|
||||
tcpPorts,
|
||||
tcpPort(bindPort),
|
||||
tcpPort(dohPort),
|
||||
tcpPort(dotPort),
|
||||
tcpPort(dnscryptTCPPort),
|
||||
bindPort,
|
||||
dohPort,
|
||||
dotPort,
|
||||
dnscryptTCPPort,
|
||||
tcpPort(dnsPort),
|
||||
)
|
||||
|
||||
err = tcpPorts.Validate()
|
||||
@@ -569,7 +649,7 @@ func validatePorts(
|
||||
}
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
|
||||
addPorts(udpPorts, dnsPort, doqPort)
|
||||
|
||||
err = udpPorts.Validate()
|
||||
if err != nil {
|
||||
@@ -604,7 +684,7 @@ func (m *tlsManager) validateCertChain(
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
DNSName: srvName,
|
||||
Roots: globalContext.tlsRoots,
|
||||
Roots: m.rootCerts,
|
||||
Intermediates: pool,
|
||||
}
|
||||
_, err = main.Verify(opts)
|
||||
@@ -615,6 +695,67 @@ func (m *tlsManager) validateCertChain(
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPortAvailability checks [tlsConfigSettings.PortHTTPS],
|
||||
// [tlsConfigSettings.PortDNSOverTLS], and [tlsConfigSettings.PortDNSOverQUIC]
|
||||
// are available for use. It checks the current configuration and, if needed,
|
||||
// attempts to bind to the port. The function returns human-readable error
|
||||
// messages for the frontend. This is best-effort check to prevent an "address
|
||||
// already in use" error.
|
||||
//
|
||||
// TODO(a.garipov): Adapt for HTTP/3.
|
||||
func (m *tlsManager) checkPortAvailability(
|
||||
currConf tlsConfigSettings,
|
||||
newConf tlsConfigSettings,
|
||||
addr netip.Addr,
|
||||
) (err error) {
|
||||
const (
|
||||
networkTCP = "tcp"
|
||||
networkUDP = "udp"
|
||||
|
||||
protoHTTPS = "HTTPS"
|
||||
protoDoT = "DNS-over-TLS"
|
||||
protoDoQ = "DNS-over-QUIC"
|
||||
)
|
||||
|
||||
needBindingCheck := []struct {
|
||||
network string
|
||||
proto string
|
||||
currPort uint16
|
||||
newPort uint16
|
||||
}{{
|
||||
network: networkTCP,
|
||||
proto: protoHTTPS,
|
||||
currPort: currConf.PortHTTPS,
|
||||
newPort: newConf.PortHTTPS,
|
||||
}, {
|
||||
network: networkTCP,
|
||||
proto: protoDoT,
|
||||
currPort: currConf.PortDNSOverTLS,
|
||||
newPort: newConf.PortDNSOverTLS,
|
||||
}, {
|
||||
network: networkUDP,
|
||||
proto: protoDoQ,
|
||||
currPort: currConf.PortDNSOverQUIC,
|
||||
newPort: newConf.PortDNSOverQUIC,
|
||||
}}
|
||||
|
||||
var errs []error
|
||||
for _, v := range needBindingCheck {
|
||||
port := v.newPort
|
||||
if v.currPort == port {
|
||||
continue
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(addr, port)
|
||||
err = aghnet.CheckPort(v.network, addrPort)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("port %d for %s is not available", port, v.proto))
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// errNoIPInCert is the error that is returned from [tlsManager.parseCertChain]
|
||||
// if the leaf certificate doesn't contain IPs.
|
||||
const errNoIPInCert errors.Error = `certificates has no IP addresses; ` +
|
||||
@@ -718,27 +859,12 @@ func (m *tlsManager) validateCertificates(
|
||||
) (err error) {
|
||||
// Check only the public certificate separately from the key.
|
||||
if len(certChain) > 0 {
|
||||
var certs []*x509.Certificate
|
||||
certs, status.ValidCert, err = m.parseCertChain(ctx, certChain)
|
||||
if !status.ValidCert {
|
||||
var ok bool
|
||||
ok, err = m.validateCertificate(ctx, status, certChain, serverName)
|
||||
if !ok {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
mainCert := certs[0]
|
||||
status.Subject = mainCert.Subject.String()
|
||||
status.Issuer = mainCert.Issuer.String()
|
||||
status.NotAfter = mainCert.NotAfter
|
||||
status.NotBefore = mainCert.NotBefore
|
||||
status.DNSNames = mainCert.DNSNames
|
||||
|
||||
if chainErr := m.validateCertChain(ctx, certs, serverName); chainErr != nil {
|
||||
// Let self-signed certs through and don't return this error to set
|
||||
// its message into the status.WarningValidation afterwards.
|
||||
err = chainErr
|
||||
} else {
|
||||
status.ValidChain = true
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the private key by parsing it.
|
||||
@@ -766,6 +892,41 @@ func (m *tlsManager) validateCertificates(
|
||||
return err
|
||||
}
|
||||
|
||||
// validateCertificate processes certificate data. status must not be nil, as
|
||||
// it is used to accumulate the validation results. Other parameters are
|
||||
// optional. If ok is true, the returned error, if any, is not critical.
|
||||
func (m *tlsManager) validateCertificate(
|
||||
ctx context.Context,
|
||||
status *tlsConfigStatus,
|
||||
certChain []byte,
|
||||
serverName string,
|
||||
) (ok bool, err error) {
|
||||
var certs []*x509.Certificate
|
||||
certs, status.ValidCert, err = m.parseCertChain(ctx, certChain)
|
||||
if !status.ValidCert {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return false, err
|
||||
}
|
||||
|
||||
mainCert := certs[0]
|
||||
status.Subject = mainCert.Subject.String()
|
||||
status.Issuer = mainCert.Issuer.String()
|
||||
status.NotAfter = mainCert.NotAfter
|
||||
status.NotBefore = mainCert.NotBefore
|
||||
status.DNSNames = mainCert.DNSNames
|
||||
|
||||
err = m.validateCertChain(ctx, certs, serverName)
|
||||
if err != nil {
|
||||
// Let self-signed certs through and don't return this error to set
|
||||
// its message into the status.WarningValidation afterwards.
|
||||
return true, err
|
||||
}
|
||||
|
||||
status.ValidChain = true
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Key types.
|
||||
const (
|
||||
keyTypeECDSA = "ECDSA"
|
||||
@@ -828,17 +989,18 @@ func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if data.PrivateKey != "" {
|
||||
var key []byte
|
||||
key, err = base64.StdEncoding.DecodeString(data.PrivateKey)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("failed to base64-decode private key: %w", err)
|
||||
}
|
||||
if data.PrivateKey == "" {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
data.PrivateKey = string(key)
|
||||
if data.PrivateKeyPath != "" {
|
||||
return data, fmt.Errorf("private key data and file can't be set together")
|
||||
}
|
||||
key, err := base64.StdEncoding.DecodeString(data.PrivateKey)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("failed to base64-decode private key: %w", err)
|
||||
}
|
||||
|
||||
data.PrivateKey = string(key)
|
||||
if data.PrivateKeyPath != "" {
|
||||
return data, fmt.Errorf("private key data and file can't be set together")
|
||||
}
|
||||
|
||||
return data, nil
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TODO(s.chzhen): Consider moving to testdata.
|
||||
var testCertChainData = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
|
||||
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
|
||||
@@ -66,7 +67,11 @@ func TestValidateCertificates(t *testing.T) {
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
logger := slogutil.NewDiscardLogger()
|
||||
|
||||
m, err := newTLSManager(ctx, logger, tlsConfigSettings{}, false)
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: logger,
|
||||
configModified: func() {},
|
||||
servePlainDNS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("bad_certificate", func(t *testing.T) {
|
||||
@@ -112,7 +117,6 @@ func TestValidateCertificates(t *testing.T) {
|
||||
// - [homeContext.clients.storage]
|
||||
// - [homeContext.dnsServer]
|
||||
// - [homeContext.mux]
|
||||
// - [homeContext.web]
|
||||
//
|
||||
// TODO(s.chzhen): Remove this once the TLS manager no longer accesses global
|
||||
// variables. Make tests that use this helper concurrent.
|
||||
@@ -123,14 +127,12 @@ func storeGlobals(tb testing.TB) {
|
||||
storage := globalContext.clients.storage
|
||||
dnsServer := globalContext.dnsServer
|
||||
mux := globalContext.mux
|
||||
web := globalContext.web
|
||||
|
||||
tb.Cleanup(func() {
|
||||
config = prevConfig
|
||||
globalContext.clients.storage = storage
|
||||
globalContext.dnsServer = dnsServer
|
||||
globalContext.mux = mux
|
||||
globalContext.web = web
|
||||
})
|
||||
}
|
||||
|
||||
@@ -221,9 +223,6 @@ func TestTLSManager_Reload(t *testing.T) {
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
snBefore int64 = 1
|
||||
snAfter int64 = 2
|
||||
@@ -236,15 +235,25 @@ func TestTLSManager_Reload(t *testing.T) {
|
||||
certDER, key := newCertAndKey(t, snBefore)
|
||||
writeCertAndKey(t, certDER, certPath, key, keyPath)
|
||||
|
||||
m, err := newTLSManager(ctx, logger, tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificatePath: certPath,
|
||||
PrivateKeyPath: keyPath,
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: logger,
|
||||
configModified: func() {},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificatePath: certPath,
|
||||
PrivateKeyPath: keyPath,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
servePlainDNS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.setWebAPI(web)
|
||||
|
||||
conf := &tlsConfigSettings{}
|
||||
m.WriteDiskConfig(conf)
|
||||
assertCertSerialNumber(t, conf, snBefore)
|
||||
@@ -265,13 +274,18 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
m, err := newTLSManager(ctx, logger, tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificateChain: string(testCertChainData),
|
||||
PrivateKey: string(testPrivateKeyData),
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: logger,
|
||||
configModified: func() {},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificateChain: string(testCertChainData),
|
||||
PrivateKey: string(testPrivateKeyData),
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
servePlainDNS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@@ -291,26 +305,42 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
|
||||
func TestValidateTLSSettings(t *testing.T) {
|
||||
storeGlobals(t)
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
var (
|
||||
logger = slogutil.NewDiscardLogger()
|
||||
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
||||
err error
|
||||
)
|
||||
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: logger,
|
||||
configModified: func() {},
|
||||
servePlainDNS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, ln.Close)
|
||||
|
||||
addr := testutil.RequireTypeAssert[*net.TCPAddr](t, ln.Addr())
|
||||
|
||||
busyPort := addr.Port
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false)
|
||||
web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.setWebAPI(web)
|
||||
|
||||
tcpLn, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, tcpLn.Close)
|
||||
|
||||
tcpAddr := testutil.RequireTypeAssert[*net.TCPAddr](t, tcpLn.Addr())
|
||||
busyTCPPort := tcpAddr.Port
|
||||
|
||||
udpLn, err := net.ListenPacket("udp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, udpLn.Close)
|
||||
|
||||
udpAddr := testutil.RequireTypeAssert[*net.UDPAddr](t, udpLn.LocalAddr())
|
||||
busyUDPPort := udpAddr.Port
|
||||
|
||||
testCases := []struct {
|
||||
setts tlsConfigSettingsExt
|
||||
name string
|
||||
@@ -329,11 +359,29 @@ func TestValidateTLSSettings(t *testing.T) {
|
||||
setts: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
PortHTTPS: uint16(busyPort),
|
||||
PortHTTPS: uint16(busyTCPPort),
|
||||
},
|
||||
},
|
||||
name: "busy_port",
|
||||
wantErr: fmt.Sprintf("port %d is not available, cannot enable HTTPS on it", busyPort),
|
||||
name: "busy_https_port",
|
||||
wantErr: fmt.Sprintf("port %d for HTTPS is not available", busyTCPPort),
|
||||
}, {
|
||||
setts: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
PortDNSOverTLS: uint16(busyTCPPort),
|
||||
},
|
||||
},
|
||||
name: "busy_dot_port",
|
||||
wantErr: fmt.Sprintf("port %d for DNS-over-TLS is not available", busyTCPPort),
|
||||
}, {
|
||||
setts: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
PortDNSOverQUIC: uint16(busyUDPPort),
|
||||
},
|
||||
},
|
||||
name: "busy_doq_port",
|
||||
wantErr: fmt.Sprintf("port %d for DNS-over-QUIC is not available", busyUDPPort),
|
||||
}, {
|
||||
setts: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
@@ -348,7 +396,7 @@ func TestValidateTLSSettings(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = validateTLSSettings(tc.setts)
|
||||
err = m.validateTLSSettings(tc.setts)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
@@ -357,26 +405,33 @@ func TestValidateTLSSettings(t *testing.T) {
|
||||
func TestTLSManager_HandleTLSValidate(t *testing.T) {
|
||||
storeGlobals(t)
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
var (
|
||||
logger = slogutil.NewDiscardLogger()
|
||||
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
||||
err error
|
||||
)
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
m, err := newTLSManager(ctx, logger, tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificateChain: string(testCertChainData),
|
||||
PrivateKey: string(testPrivateKeyData),
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: logger,
|
||||
configModified: func() {},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificateChain: string(testCertChainData),
|
||||
PrivateKey: string(testPrivateKeyData),
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
servePlainDNS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.setWebAPI(web)
|
||||
|
||||
setts := &tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
@@ -438,9 +493,6 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")}
|
||||
config.DNS.Port = 0
|
||||
|
||||
@@ -455,15 +507,25 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
||||
writeCertAndKey(t, certDER, certPath, key, keyPath)
|
||||
|
||||
// Initialize the TLS manager and assert its configuration.
|
||||
m, err := newTLSManager(ctx, logger, tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificatePath: certPath,
|
||||
PrivateKeyPath: keyPath,
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: logger,
|
||||
configModified: func() {},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificatePath: certPath,
|
||||
PrivateKeyPath: keyPath,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
servePlainDNS: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.setWebAPI(web)
|
||||
|
||||
conf := &tlsConfigSettings{}
|
||||
m.WriteDiskConfig(conf)
|
||||
assertCertSerialNumber(t, conf, wantSerialNumber)
|
||||
@@ -509,10 +571,10 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
||||
//
|
||||
// TODO(s.chzhen): Remove when [httpsServer.cond] is removed.
|
||||
assert.Eventually(t, func() bool {
|
||||
globalContext.web.httpsServer.condLock.Lock()
|
||||
defer globalContext.web.httpsServer.condLock.Unlock()
|
||||
web.httpsServer.condLock.Lock()
|
||||
defer web.httpsServer.condLock.Unlock()
|
||||
|
||||
cert = globalContext.web.httpsServer.cert
|
||||
cert = web.httpsServer.cert
|
||||
if cert.Leaf == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -12,10 +12,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||
@@ -158,27 +156,6 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
||||
return w
|
||||
}
|
||||
|
||||
// webCheckPortAvailable checks if port, which is considered an HTTPS port, is
|
||||
// available, unless the HTTPS server isn't active.
|
||||
//
|
||||
// TODO(a.garipov): Adapt for HTTP/3.
|
||||
func webCheckPortAvailable(port uint16) (ok bool) {
|
||||
if globalContext.web.httpsServer.server != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), port)
|
||||
|
||||
err := aghnet.CheckPort("tcp", addrPort)
|
||||
if err != nil {
|
||||
log.Info("web: warning: checking https port: %s", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server
|
||||
// if necessary.
|
||||
func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
|
||||
@@ -329,8 +306,8 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
|
||||
Handler: hdlr,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{web.httpsServer.cert},
|
||||
RootCAs: globalContext.tlsRoots,
|
||||
CipherSuites: globalContext.tlsCipherIDs,
|
||||
RootCAs: web.tlsManager.rootCerts,
|
||||
CipherSuites: web.tlsManager.customCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
ReadTimeout: web.conf.ReadTimeout,
|
||||
@@ -363,8 +340,8 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) {
|
||||
Addr: address,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{web.httpsServer.cert},
|
||||
RootCAs: globalContext.tlsRoots,
|
||||
CipherSuites: globalContext.tlsCipherIDs,
|
||||
RootCAs: web.tlsManager.rootCerts,
|
||||
CipherSuites: web.tlsManager.customCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
Handler: withMiddlewares(globalContext.mux, limitRequestBody),
|
||||
|
||||
Reference in New Issue
Block a user