Compare commits

...

4 Commits

Author SHA1 Message Date
Stanislav Chzhen
53cb84efc0 all: session storage usage 2025-04-22 15:42:12 +03:00
Eugene Burkov
c7c62ad3b6 Pull request 2395: Update all
Merge in DNS/adguard-home from upd-all to master

Squashed commit of the following:

commit c4bba4531813bdeb79536f7f601ade80a16a9163
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Apr 21 18:17:06 2025 +0300

    client: upd filters
2025-04-21 18:46:47 +03:00
Stanislav Chzhen
003e7ce0d5 Pull request 2393: 7773-fix-unencrypted_doh
Updates #7773.

Squashed commit of the following:

commit d9ca09c1d9b251998107fc87bd6daeb5999ea803
Merge: b67a71a7a a8fdf1c55
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Apr 21 15:56:57 2025 +0300

    Merge branch 'master' into 7773-fix-unencrypted_doh

commit b67a71a7a9686d36cbf64a3f7561886bff7d9c5c
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Apr 18 16:01:49 2025 +0300

    home: imp docs

commit dab9b0582ff1ebc4637d5ec1ea3bc81190ed4066
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Apr 18 15:09:36 2025 +0300

    home: fix unencrypted doh
2025-04-21 16:05:16 +03:00
Stanislav Chzhen
a8fdf1c553 Pull request 2386: AGDNS-2743-aghuser-session
Merge in DNS/adguard-home from AGDNS-2743-aghuser-session to master

Squashed commit of the following:

commit 74fd4bc11eaf784880855fa2c710a747428db146
Merge: 844e865f6 7d479baba
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Apr 18 18:14:36 2025 +0300

    Merge branch 'master' into AGDNS-2743-aghuser-session

commit 844e865f647efb4de7f057c392894c8f65bab422
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Apr 18 15:18:44 2025 +0300

    aghuser: imp fmt

commit 584288e0a3ddbe6d7ae31c80c22b8f397cfd0cae
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 17 20:16:54 2025 +0300

    aghuser: imp tests

commit ea4c8735585f6d30d6dedf2a40a8dd6b07609d07
Merge: c3fd8fe5e 3521e8ed9
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 17 20:10:06 2025 +0300

    Merge branch 'master' into AGDNS-2743-aghuser-session

commit c3fd8fe5eabaf2022a971197c018e140c254006d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 17 15:23:45 2025 +0300

    aghuser: imp tests

commit dfd9aba337227a8d3edc6f5a68f3f039afd1ca0b
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Apr 16 21:40:14 2025 +0300

    aghuser: imp code

commit b6e75223bf7960f3a2e94c1a3ed7cc33539b9806
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Apr 14 21:49:20 2025 +0300

    aghuser: imp code

commit 56d6f9d478eec399c376992ffb0f1ca5b797986d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Apr 11 16:58:11 2025 +0300

    aghuser: user db

commit 6fdc2f60bf7f93e72d917abb12af8e4867143b6d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 10 14:11:22 2025 +0300

    all: upd scripts

commit 575946756f3f622360c5feafe3e721eee010e230
Merge: 7e1fac4ec 1cc6c00e4
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 10 14:05:46 2025 +0300

    Merge branch 'master' into AGDNS-2743-aghuser-session

commit 7e1fac4ecb1bde0013bca3f6b64e82d81a78c9c3
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 10 14:05:35 2025 +0300

    aghuser: session storage

commit acfb040f0bdff501c7304ea100b9faf1c07291ae
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Apr 8 15:54:24 2025 +0300

    aghuser: session
2025-04-18 18:34:10 +03:00
16 changed files with 918 additions and 455 deletions

View File

@@ -28,6 +28,12 @@ export default {
"homepage": "https://badmojr.github.io/1Hosts/",
"source": "https://adguardteam.github.io/HostlistsRegistry/assets/filter_24.txt"
},
"1hosts_pro": {
"name": "1Hosts (Pro)",
"categoryId": "general",
"homepage": "https://badmojr.github.io/1Hosts/",
"source": "https://adguardteam.github.io/HostlistsRegistry/assets/filter_64.txt"
},
"CHN_adrules": {
"name": "CHN: AdRules DNS List",
"categoryId": "regional",

View File

@@ -10,7 +10,8 @@ import (
// Login is the type for web user logins.
type Login string
// NewLogin returns a web user login.
// NewLogin returns a web user login. The length of s must not be greater than
// [math.MaxUint16].
//
// TODO(s.chzhen): Add more constraints as needed.
func NewLogin(s string) (l Login, err error) {

View File

@@ -0,0 +1,35 @@
package aghuser
import (
"crypto/rand"
"time"
)
// SessionToken is the type for the web user session token.
type SessionToken [16]byte
// NewSessionToken returns a cryptographically secure randomly generated web
// user session token. If an error occurs during random generation, it will
// cause the program to crash.
func NewSessionToken() (t SessionToken) {
_, _ = rand.Read(t[:])
return t
}
// Session represents a web user session.
type Session struct {
// Expire indicates when the session will expire.
Expire time.Time
// UserLogin is the login of the web user associated with the session.
//
// TODO(s.chzhen): Remove this field and associate the user by UserID.
UserLogin Login
// Token is the session token.
Token SessionToken
// UserID is the identifier of the web user associated with the session.
UserID UserID
}

View File

@@ -0,0 +1,453 @@
package aghuser
import (
"context"
"encoding/binary"
"fmt"
"log/slog"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/timeutil"
"go.etcd.io/bbolt"
berrors "go.etcd.io/bbolt/errors"
)
// SessionStorage is an interface that defines methods for handling web user
// sessions. All methods must be safe for concurrent use.
//
// TODO(s.chzhen): Add DeleteAll method.
type SessionStorage interface {
// New creates a new session for the web user.
New(ctx context.Context, u *User) (s *Session, err error)
// FindByToken returns the stored session for the web user based on the session
// token.
//
// TODO(s.chzhen): Consider function signature change to reflect the
// in-memory implementation, as it currently always returns nil for error.
FindByToken(ctx context.Context, t SessionToken) (s *Session, err error)
// DeleteByToken removes a stored web user session by the provided token.
DeleteByToken(ctx context.Context, t SessionToken) (err error)
// Close releases the web user sessions database resources.
Close() (err error)
}
// DefaultSessionStorageConfig represents the web user session storage
// configuration structure.
type DefaultSessionStorageConfig struct {
// Logger is used for logging the operation of the session storage. It must
// not be nil.
Logger *slog.Logger
// Clock is used to get the current time. It must not be nil.
Clock timeutil.Clock
// UserDB contains the web user information such as ID, login, and password.
// It must not be nil.
UserDB DB
// DBPath is the path to the database file where session data is stored. It
// must not be empty.
DBPath string
// SessionTTL is the default Time-To-Live duration for web user sessions.
// It specifies how long a session should last and is a required field.
SessionTTL time.Duration
}
// DefaultSessionStorage is the default bbolt database implementation of the
// [SessionStorage] interface.
type DefaultSessionStorage struct {
// db is an instance of the bbolt database where web user sessions are
// stored by [SessionToken] in the [bucketNameSessions] bucket.
db *bbolt.DB
// logger is used for logging the operation of the session storage.
logger *slog.Logger
// mu protects sessions.
mu *sync.Mutex
// clock is used to get the current time.
clock timeutil.Clock
// userDB contains the web user information such as ID, login, and password.
userDB DB
// sessions maps a session token to a web user session.
sessions map[SessionToken]*Session
// sessionTTL is the default Time-To-Live value for web user sessions.
sessionTTL time.Duration
}
// NewDefaultSessionStorage returns the new properly initialized
// *DefaultSessionStorage.
func NewDefaultSessionStorage(
ctx context.Context,
conf *DefaultSessionStorageConfig,
) (ds *DefaultSessionStorage, err error) {
ds = &DefaultSessionStorage{
clock: conf.Clock,
userDB: conf.UserDB,
logger: conf.Logger,
mu: &sync.Mutex{},
sessions: map[SessionToken]*Session{},
sessionTTL: conf.SessionTTL,
}
dbFilename := conf.DBPath
// TODO(s.chzhen): Pass logger with options.
ds.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, nil)
if err != nil {
ds.logger.ErrorContext(ctx, "opening db %q: %w", dbFilename, err)
if errors.Is(err, berrors.ErrInvalid) {
const s = "AdGuard Home cannot be initialized due to an incompatible file system.\n" +
"Please read the explanation here: https://adguard-dns.io/kb/adguard-home/getting-started/#limitations"
slogutil.PrintLines(ctx, ds.logger, slog.LevelError, "", s)
}
return nil, err
}
err = ds.loadSessions(ctx)
if err != nil {
return nil, fmt.Errorf("loading sessions: %w", err)
}
return ds, nil
}
// loadSessions loads web user sessions from the bbolt database.
func (ds *DefaultSessionStorage) loadSessions(ctx context.Context) (err error) {
tx, err := ds.db.Begin(true)
if err != nil {
return fmt.Errorf("starting transaction: %w", err)
}
needRollback := true
defer func() {
if needRollback {
err = errors.WithDeferred(err, tx.Rollback())
}
}()
bkt := tx.Bucket([]byte(bboltBucketSessions))
if bkt == nil {
return nil
}
removed, err := ds.processSessions(ctx, bkt)
if err != nil {
return fmt.Errorf("processing sessions: %w", err)
}
if removed == 0 {
ds.logger.DebugContext(ctx, "loading sessions from db", "stored", len(ds.sessions))
return nil
}
needRollback = false
err = tx.Commit()
if err != nil {
return fmt.Errorf("committing transaction: %w", err)
}
ds.logger.DebugContext(
ctx,
"loading sessions from db",
"stored", len(ds.sessions),
"removed", removed,
)
return nil
}
// processSessions iterates over the sessions bucket and loads or removes
// sessions as needed.
func (ds *DefaultSessionStorage) processSessions(
ctx context.Context,
bkt *bbolt.Bucket,
) (removed int, err error) {
invalidSessions := [][]byte{}
err = bkt.ForEach(ds.bboltSessionHandler(ctx, &invalidSessions))
if err != nil {
return 0, fmt.Errorf("iterating over sessions: %w", err)
}
var errs []error
for _, s := range invalidSessions {
if err = bkt.Delete(s); err != nil {
errs = append(errs, err)
}
}
if err = errors.Join(errs...); err != nil {
return 0, fmt.Errorf("deleting sessions: %w", err)
}
return len(invalidSessions), nil
}
// bboltSessionHandler returns a function for [bbolt.Bucket.ForEach] that
// iterates over stored sessions, deserializes them, and logs any errors
// encountered. The returned error is always nil, as these errors are
// considered non-critical to stop the iteration process.
func (ds *DefaultSessionStorage) bboltSessionHandler(
ctx context.Context,
invalidSessions *[][]byte,
) (fn func(k, v []byte) (err error)) {
now := ds.clock.Now()
return func(k, v []byte) (err error) {
s, err := bboltDecode(v)
if err != nil {
*invalidSessions = append(*invalidSessions, k)
ds.logger.DebugContext(ctx, "deserializing session", slogutil.KeyError, err)
return nil
}
if now.After(s.Expire) {
*invalidSessions = append(*invalidSessions, k)
return nil
}
u, err := ds.userDB.ByLogin(ctx, s.UserLogin)
if err != nil {
// Should not happen, as it currently always returns nil for error.
panic(err)
}
if u == nil {
*invalidSessions = append(*invalidSessions, k)
ds.logger.DebugContext(ctx, "no saved user by name", "name", s.UserLogin)
return nil
}
t := SessionToken(k)
s.Token = t
s.UserID = u.ID
ds.sessions[t] = s
return nil
}
}
// bboltBucketSessions is the name of the bucket storing web user sessions in
// the bbolt database.
const bboltBucketSessions = "sessions-2"
const (
// bboltSessionExpireLen is the length of the expire field in the binary
// entry stored in bbolt.
bboltSessionExpireLen = 4
// bboltSessionNameLen is the length of the name field in the binary entry
// stored in bbolt.
bboltSessionNameLen = 2
)
// bboltDecode deserializes decodes a binary data into a session.
func bboltDecode(data []byte) (s *Session, err error) {
if len(data) < bboltSessionExpireLen+bboltSessionNameLen {
return nil, fmt.Errorf("length of the data is less than expected: got %d", len(data))
}
expireData := data[:bboltSessionExpireLen]
nameLenData := data[bboltSessionExpireLen : bboltSessionExpireLen+bboltSessionNameLen]
nameData := data[bboltSessionExpireLen+bboltSessionNameLen:]
nameLen := binary.BigEndian.Uint16(nameLenData)
if len(nameData) != int(nameLen) {
return nil, fmt.Errorf("login: expected length %d, got %d", nameLen, len(nameData))
}
expire := binary.BigEndian.Uint32(expireData)
return &Session{
Expire: time.Unix(int64(expire), 0),
UserLogin: Login(nameData),
}, nil
}
// bboltEncode serializes a session properties into a binary data.
func bboltEncode(s *Session) (data []byte) {
data = make([]byte, bboltSessionExpireLen+bboltSessionNameLen+len(s.UserLogin))
expireData := data[:bboltSessionExpireLen]
nameLenData := data[bboltSessionExpireLen : bboltSessionExpireLen+bboltSessionNameLen]
nameData := data[bboltSessionExpireLen+bboltSessionNameLen:]
expire := uint32(s.Expire.Unix())
binary.BigEndian.PutUint32(expireData, expire)
binary.BigEndian.PutUint16(nameLenData, uint16(len(s.UserLogin)))
copy(nameData, []byte(s.UserLogin))
return data
}
// type check
var _ SessionStorage = (*DefaultSessionStorage)(nil)
// New implements the [SessionStorage] interface for *DefaultSessionStorage.
func (ds *DefaultSessionStorage) New(ctx context.Context, u *User) (s *Session, err error) {
s = &Session{
Token: NewSessionToken(),
UserID: u.ID,
UserLogin: u.Login,
Expire: ds.clock.Now().Add(ds.sessionTTL),
}
err = ds.store(s)
if err != nil {
return nil, fmt.Errorf("storing session: %w", err)
}
ds.mu.Lock()
defer ds.mu.Unlock()
ds.sessions[s.Token] = s
return s, nil
}
// store saves a web user session in the bbolt database.
func (ds *DefaultSessionStorage) store(s *Session) (err error) {
tx, err := ds.db.Begin(true)
if err != nil {
return fmt.Errorf("starting transaction: %w", err)
}
needRollback := true
defer func() {
if needRollback {
err = errors.WithDeferred(err, tx.Rollback())
}
}()
bkt, err := tx.CreateBucketIfNotExists([]byte(bboltBucketSessions))
if err != nil {
return fmt.Errorf("creating bucket: %w", err)
}
err = bkt.Put(s.Token[:], bboltEncode(s))
if err != nil {
return fmt.Errorf("putting data: %w", err)
}
needRollback = false
err = tx.Commit()
if err != nil {
return fmt.Errorf("committing transaction: %w", err)
}
return nil
}
// FindByToken implements the [SessionStorage] interface for
// *DefaultSessionStorage.
func (ds *DefaultSessionStorage) FindByToken(
ctx context.Context,
t SessionToken,
) (s *Session, err error) {
ds.mu.Lock()
defer ds.mu.Unlock()
s, ok := ds.sessions[t]
if !ok {
return nil, nil
}
now := ds.clock.Now()
if now.After(s.Expire) {
err = ds.deleteByToken(ctx, t)
if err != nil {
return nil, fmt.Errorf("expired session: %w", err)
}
return nil, nil
}
return s, nil
}
// DeleteByToken implements the [SessionStorage] interface for
// *DefaultSessionStorage.
func (ds *DefaultSessionStorage) DeleteByToken(ctx context.Context, t SessionToken) (err error) {
ds.mu.Lock()
defer ds.mu.Unlock()
// Don't wrap the error because it's informative enough as is.
return ds.deleteByToken(ctx, t)
}
// deleteByToken removes stored session by token. ds.mu is expected to be
// locked.
func (ds *DefaultSessionStorage) deleteByToken(ctx context.Context, t SessionToken) (err error) {
err = ds.remove(ctx, t)
if err != nil {
ds.logger.ErrorContext(ctx, "deleting session", slogutil.KeyError, err)
return err
}
delete(ds.sessions, t)
return nil
}
// remove deletes a web user session from the bbolt database.
func (ds *DefaultSessionStorage) remove(ctx context.Context, t SessionToken) (err error) {
tx, err := ds.db.Begin(true)
if err != nil {
return fmt.Errorf("starting transaction: %w", err)
}
needRollback := true
defer func() {
if needRollback {
err = errors.WithDeferred(err, tx.Rollback())
}
}()
bkt := tx.Bucket([]byte(bboltBucketSessions))
if bkt == nil {
return errors.Error("no bucket")
}
err = bkt.Delete(t[:])
if err != nil {
return fmt.Errorf("removing data: %w", err)
}
needRollback = false
err = tx.Commit()
if err != nil {
return fmt.Errorf("committing transaction: %w", err)
}
ds.logger.DebugContext(ctx, "removed session from db")
return err
}
// Close implements the [SessionStorage] interface for *DefaultSessionStorage.
func (ds *DefaultSessionStorage) Close() (err error) {
err = ds.db.Close()
if err != nil {
return fmt.Errorf("closing db: %w", err)
}
return nil
}

View File

@@ -0,0 +1,162 @@
package aghuser_test
import (
"context"
"os"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghuser"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/faketime"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// addSession is a helper function that saves and returns a session for a newly
// generated [aghuser.User] by login.
func addSession(
tb testing.TB,
ctx context.Context,
ds aghuser.SessionStorage,
login aghuser.Login,
) (s *aghuser.Session) {
tb.Helper()
s, err := ds.New(ctx, &aghuser.User{
ID: aghuser.MustNewUserID(),
Login: login,
})
require.NoError(tb, err)
require.NotNil(tb, s)
var got *aghuser.Session
got, err = ds.FindByToken(ctx, s.Token)
require.NoError(tb, err)
require.NotNil(tb, got)
assert.Equal(tb, login, got.UserLogin)
return s
}
func TestDefaultSessionStorage(t *testing.T) {
const (
userLoginFirst aghuser.Login = "user_one"
userLoginSecond aghuser.Login = "user_two"
)
var (
ctx = testutil.ContextWithTimeout(t, testTimeout)
logger = slogutil.NewDiscardLogger()
)
const (
sessionTTL = time.Minute
timeStep = time.Second
)
// Set up a mock clock to test expired sessions. Each call to [clock.Now]
// will return the [date] incremented by [timeStep].
date := time.Now()
clock := &faketime.Clock{
OnNow: func() (now time.Time) {
date = date.Add(timeStep)
return date
},
}
dbFile, err := os.CreateTemp(t.TempDir(), "sessions.db")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, dbFile.Close)
userDB := aghuser.NewDefaultDB()
err = userDB.Create(ctx, &aghuser.User{
Login: userLoginFirst,
ID: aghuser.MustNewUserID(),
})
require.NoError(t, err)
err = userDB.Create(ctx, &aghuser.User{
Login: userLoginSecond,
ID: aghuser.MustNewUserID(),
})
require.NoError(t, err)
var (
ds *aghuser.DefaultSessionStorage
sessionFirst *aghuser.Session
sessionSecond *aghuser.Session
)
require.True(t, t.Run("prepare_session_storage", func(t *testing.T) {
ds, err = aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{
Clock: clock,
UserDB: userDB,
Logger: logger,
DBPath: dbFile.Name(),
SessionTTL: sessionTTL,
})
require.NoError(t, err)
sessionFirst = addSession(t, ctx, ds, userLoginFirst)
// Advance time to ensure the first session expires before creating the
// second session.
date = date.Add(time.Hour)
sessionSecond = addSession(t, ctx, ds, userLoginSecond)
err = ds.Close()
require.NoError(t, err)
}))
require.True(t, t.Run("load_sessions", func(t *testing.T) {
ds, err = aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{
Clock: clock,
UserDB: userDB,
Logger: logger,
DBPath: dbFile.Name(),
SessionTTL: sessionTTL,
})
require.NoError(t, err)
var got *aghuser.Session
got, err = ds.FindByToken(ctx, sessionFirst.Token)
require.NoError(t, err)
assert.Nil(t, got)
got, err = ds.FindByToken(ctx, sessionSecond.Token)
require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, userLoginSecond, got.UserLogin)
err = ds.DeleteByToken(ctx, sessionSecond.Token)
require.NoError(t, err)
got, err = ds.FindByToken(ctx, sessionSecond.Token)
require.NoError(t, err)
assert.Nil(t, got)
}))
require.True(t, t.Run("expired_session", func(t *testing.T) {
testutil.CleanupAndRequireSuccess(t, ds.Close)
sessionFirst = addSession(t, ctx, ds, userLoginFirst)
date = date.Add(time.Hour)
var got *aghuser.Session
got, err = ds.FindByToken(ctx, sessionFirst.Token)
require.NoError(t, err)
assert.Nil(t, got)
}))
}

View File

@@ -32,13 +32,13 @@ func MustNewUserID() (uid UserID) {
// User represents a web user.
type User struct {
// ID is the unique identifier for the web user. It must not be empty.
ID UserID
// Password stores the password information for the web user. It must not
// be nil.
Password Password
// Login is the login name of the web user. It must not be empty.
Login Login
// Password stores the password information for the web user. It must not
// be nil.
Password Password
// ID is the unique identifier for the web user. It must not be empty.
ID UserID
}

View File

@@ -1,317 +1,131 @@
package home
import (
"crypto/rand"
"encoding/binary"
"context"
"encoding/hex"
"fmt"
"log/slog"
"net/http"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghuser"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"go.etcd.io/bbolt"
"github.com/AdguardTeam/golibs/timeutil"
"golang.org/x/crypto/bcrypt"
)
// sessionTokenSize is the length of session token in bytes.
const sessionTokenSize = 16
// webUser represents a user of the Web UI.
type webUser struct {
// Name represents the login name of the web user.
Name string `yaml:"name"`
type session struct {
userName string
// expire is the expiration time, in seconds.
expire uint32
// PasswordHash is the hashed representation of the web user password.
PasswordHash string `yaml:"password"`
// UserID is the unique identifier of the web user.
//
// TODO(s.chzhen): !! Use this.
UserID aghuser.UserID `yaml:"-"`
}
func (s *session) serialize() []byte {
const (
expireLen = 4
nameLen = 2
)
data := make([]byte, expireLen+nameLen+len(s.userName))
binary.BigEndian.PutUint32(data[0:4], s.expire)
binary.BigEndian.PutUint16(data[4:6], uint16(len(s.userName)))
copy(data[6:], []byte(s.userName))
return data
}
func (s *session) deserialize(data []byte) bool {
if len(data) < 4+2 {
return false
// toUser returns the new properly initialized *aghuser.User using stored
// properties. It panics if there is an error generating the user ID.
func (wu *webUser) toUser() (u *aghuser.User) {
uid := wu.UserID
if uid == (aghuser.UserID{}) {
uid = aghuser.MustNewUserID()
}
s.expire = binary.BigEndian.Uint32(data[0:4])
nameLen := binary.BigEndian.Uint16(data[4:6])
data = data[6:]
if len(data) < int(nameLen) {
return false
return &aghuser.User{
Password: aghuser.NewDefaultPassword(wu.PasswordHash),
Login: aghuser.Login(wu.Name),
ID: uid,
}
s.userName = string(data)
return true
}
// Auth is the global authentication object.
type Auth struct {
trustedProxies netutil.SubnetSet
db *bbolt.DB
logger *slog.Logger
rateLimiter *authRateLimiter
sessions map[string]*session
users []webUser
lock sync.Mutex
sessionTTL uint32
sessions aghuser.SessionStorage
trustedProxies netutil.SubnetSet
users aghuser.DB
}
// webUser represents a user of the Web UI.
//
// TODO(s.chzhen): Improve naming.
type webUser struct {
Name string `yaml:"name"`
PasswordHash string `yaml:"password"`
}
// InitAuth initializes the global authentication object.
// InitAuth initializes the global authentication object. baseLogger,
// rateLimiter, trustedProxies must not be nil. dbFilename and sessionTTL
// should not be empty.
func InitAuth(
ctx context.Context,
baseLogger *slog.Logger,
dbFilename string,
users []webUser,
sessionTTL uint32,
sessionTTL time.Duration,
rateLimiter *authRateLimiter,
trustedProxies netutil.SubnetSet,
) (a *Auth) {
log.Info("Initializing auth module: %s", dbFilename)
a = &Auth{
sessionTTL: sessionTTL,
rateLimiter: rateLimiter,
sessions: make(map[string]*session),
users: users,
trustedProxies: trustedProxies,
}
var err error
a.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, nil)
if err != nil {
log.Error("auth: open DB: %s: %s", dbFilename, err)
if err.Error() == "invalid argument" {
log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations")
) (a *Auth, err error) {
userDB := aghuser.NewDefaultDB()
for i, u := range users {
err = userDB.Create(ctx, u.toUser())
if err != nil {
return nil, fmt.Errorf("users: at index %d: %w", i, err)
}
return nil
}
a.loadSessions()
log.Info("auth: initialized. users:%d sessions:%d", len(a.users), len(a.sessions))
return a
s, err := aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{
Logger: baseLogger.With(slogutil.KeyPrefix, "session_storage"),
Clock: timeutil.SystemClock{},
UserDB: aghuser.NewDefaultDB(),
DBPath: dbFilename,
SessionTTL: sessionTTL,
})
if err != nil {
return nil, fmt.Errorf("creating session storage: %w", err)
}
return &Auth{
logger: baseLogger.With(slogutil.KeyPrefix, "auth"),
rateLimiter: rateLimiter,
trustedProxies: trustedProxies,
sessions: s,
users: userDB,
}, nil
}
// Close closes the authentication database.
func (a *Auth) Close() {
_ = a.db.Close()
}
func bucketName() []byte {
return []byte("sessions-2")
}
// loadSessions loads sessions from the database file and removes expired
// sessions.
func (a *Auth) loadSessions() {
tx, err := a.db.Begin(true)
func (a *Auth) Close(ctx context.Context) {
err := a.sessions.Close()
if err != nil {
log.Error("auth: bbolt.Begin: %s", err)
return
}
defer func() {
_ = tx.Rollback()
}()
bkt := tx.Bucket(bucketName())
if bkt == nil {
return
}
removed := 0
if tx.Bucket([]byte("sessions")) != nil {
_ = tx.DeleteBucket([]byte("sessions"))
removed = 1
}
now := uint32(time.Now().UTC().Unix())
forEach := func(k, v []byte) error {
s := session{}
if !s.deserialize(v) || s.expire <= now {
err = bkt.Delete(k)
if err != nil {
log.Error("auth: bbolt.Delete: %s", err)
} else {
removed++
}
return nil
}
a.sessions[hex.EncodeToString(k)] = &s
return nil
}
_ = bkt.ForEach(forEach)
if removed != 0 {
err = tx.Commit()
if err != nil {
log.Error("bolt.Commit(): %s", err)
}
}
log.Debug("auth: loaded %d sessions from DB (removed %d expired)", len(a.sessions), removed)
}
// addSession adds a new session to the list of sessions and saves it in the
// database file.
func (a *Auth) addSession(data []byte, s *session) {
name := hex.EncodeToString(data)
a.lock.Lock()
a.sessions[name] = s
a.lock.Unlock()
if a.storeSession(data, s) {
log.Debug("auth: created session %s: expire=%d", name, s.expire)
a.logger.ErrorContext(ctx, "closing session storage", slogutil.KeyError, err)
}
}
// storeSession saves a session in the database file.
func (a *Auth) storeSession(data []byte, s *session) bool {
tx, err := a.db.Begin(true)
// isValidSession returns true if the session is valid.
func (a *Auth) isValidSession(ctx context.Context, cookieSess string) (ok bool) {
sess, err := hex.DecodeString(cookieSess)
if err != nil {
log.Error("auth: bbolt.Begin: %s", err)
return false
}
defer func() {
_ = tx.Rollback()
}()
bkt, err := tx.CreateBucketIfNotExists(bucketName())
if err != nil {
log.Error("auth: bbolt.CreateBucketIfNotExists: %s", err)
a.logger.ErrorContext(ctx, "checking session: decoding cookie", slogutil.KeyError, err)
return false
}
err = bkt.Put(data, s.serialize())
var t aghuser.SessionToken
copy(t[:], sess)
s, err := a.sessions.FindByToken(ctx, t)
if err != nil {
log.Error("auth: bbolt.Put: %s", err)
a.logger.ErrorContext(ctx, "checking session", slogutil.KeyError, err)
return false
}
err = tx.Commit()
if err != nil {
log.Error("auth: bbolt.Commit: %s", err)
return false
}
return true
return s != nil
}
// removeSessionFromFile removes a stored session from the DB file on disk.
func (a *Auth) removeSessionFromFile(sess []byte) {
tx, err := a.db.Begin(true)
if err != nil {
log.Error("auth: bbolt.Begin: %s", err)
return
}
defer func() {
_ = tx.Rollback()
}()
bkt := tx.Bucket(bucketName())
if bkt == nil {
log.Error("auth: bbolt.Bucket")
return
}
err = bkt.Delete(sess)
if err != nil {
log.Error("auth: bbolt.Put: %s", err)
return
}
err = tx.Commit()
if err != nil {
log.Error("auth: bbolt.Commit: %s", err)
return
}
log.Debug("auth: removed session from DB")
}
// checkSessionResult is the result of checking a session.
type checkSessionResult int
// checkSessionResult constants.
const (
checkSessionOK checkSessionResult = 0
checkSessionNotFound checkSessionResult = -1
checkSessionExpired checkSessionResult = 1
)
// checkSession checks if the session is valid.
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
now := uint32(time.Now().UTC().Unix())
update := false
a.lock.Lock()
defer a.lock.Unlock()
s, ok := a.sessions[sess]
if !ok {
return checkSessionNotFound
}
if s.expire <= now {
delete(a.sessions, sess)
key, _ := hex.DecodeString(sess)
a.removeSessionFromFile(key)
return checkSessionExpired
}
newExpire := now + a.sessionTTL
if s.expire/(24*60*60) != newExpire/(24*60*60) {
// update expiration time once a day
update = true
s.expire = newExpire
}
if update {
key, _ := hex.DecodeString(sess)
if a.storeSession(key, s) {
log.Debug("auth: updated session %s: expire=%d", sess, s.expire)
}
}
return checkSessionOK
}
// removeSession removes the session from the active sessions and the disk.
func (a *Auth) removeSession(sess string) {
key, _ := hex.DecodeString(sess)
a.lock.Lock()
delete(a.sessions, sess)
a.lock.Unlock()
a.removeSessionFromFile(key)
}
// addUser adds a new user with the given password.
func (a *Auth) addUser(u *webUser, password string) (err error) {
// addUser adds a new user with the given password. u must not be nil.
func (a *Auth) addUser(ctx context.Context, u *webUser, password string) (err error) {
if len(password) == 0 {
return errors.Error("empty password")
}
@@ -323,97 +137,129 @@ func (a *Auth) addUser(u *webUser, password string) (err error) {
u.PasswordHash = string(hash)
a.lock.Lock()
defer a.lock.Unlock()
err = a.users.Create(ctx, u.toUser())
if err != nil {
// Should not happen.
panic(err)
}
a.users = append(a.users, *u)
log.Debug("auth: added user with login %q", u.Name)
a.logger.DebugContext(ctx, "added user", "login", u.Name)
return nil
}
// findUser returns a user if there is one.
func (a *Auth) findUser(login, password string) (u webUser, ok bool) {
a.lock.Lock()
defer a.lock.Unlock()
for _, u = range a.users {
if u.Name == login &&
bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil {
return u, true
}
// findUser returns a user if one exists with the provided login and the
// password matches.
func (a *Auth) findUser(ctx context.Context, login, password string) (user *aghuser.User) {
user, err := a.users.ByLogin(ctx, aghuser.Login(login))
if err != nil {
return nil
}
return webUser{}, false
ok := user.Password.Authenticate(ctx, password)
if !ok {
return nil
}
return user
}
// getCurrentUser returns the current user. It returns an empty User if the
// user is not found.
func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
// getCurrentUser searches for a user using a cookie or credentials from basic
// authentication.
func (a *Auth) getCurrentUser(r *http.Request) (user *aghuser.User) {
ctx := r.Context()
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
// There's no Cookie, check Basic authentication.
user, pass, ok := r.BasicAuth()
if ok {
u, _ = globalContext.auth.findUser(user, pass)
return u
return a.findUser(ctx, user, pass)
}
return webUser{}
return nil
}
a.lock.Lock()
defer a.lock.Unlock()
sess, err := hex.DecodeString(cookie.Value)
if err != nil {
a.logger.ErrorContext(
ctx,
"searching for user: decoding cookie value",
slogutil.KeyError, err,
)
s, ok := a.sessions[cookie.Value]
if !ok {
return webUser{}
return nil
}
for _, u = range a.users {
if u.Name == s.userName {
return u
}
var t aghuser.SessionToken
copy(t[:], sess)
s, err := a.sessions.FindByToken(ctx, t)
if err != nil {
a.logger.ErrorContext(ctx, "searching for user", slogutil.KeyError, err)
return nil
}
return webUser{}
if s == nil {
return nil
}
return &aghuser.User{
Login: s.UserLogin,
ID: s.UserID,
}
}
// removeSession deletes the session from the active sessions and the disk. It
// also logs any occurring errors.
func (a *Auth) removeSession(ctx context.Context, cookieSess string) {
sess, err := hex.DecodeString(cookieSess)
if err != nil {
a.logger.ErrorContext(ctx, "removing session: decoding cookie", slogutil.KeyError, err)
return
}
var t aghuser.SessionToken
copy(t[:], sess)
err = a.sessions.DeleteByToken(ctx, t)
if err != nil {
a.logger.ErrorContext(ctx, "removing session by token", slogutil.KeyError, err)
}
}
// usersList returns a copy of a users list.
func (a *Auth) usersList() (users []webUser) {
a.lock.Lock()
defer a.lock.Unlock()
func (a *Auth) usersList(ctx context.Context) (webUsers []webUser) {
users, err := a.users.All(ctx)
if err != nil {
// Should not happen.
panic(err)
}
users = make([]webUser, len(a.users))
copy(users, a.users)
webUsers = make([]webUser, 0, len(users))
for _, u := range users {
webUsers = append(webUsers, webUser{
Name: string(u.Login),
PasswordHash: string(u.Password.Hash()),
UserID: u.ID,
})
}
return users
return webUsers
}
// authRequired returns true if a authentication is required.
func (a *Auth) authRequired() bool {
func (a *Auth) authRequired(ctx context.Context) (ok bool) {
if GLMode {
return true
}
a.lock.Lock()
defer a.lock.Unlock()
users, err := a.users.All(ctx)
if err != nil {
// Should not happen.
panic(err)
}
return len(a.users) != 0
}
// newSessionToken returns cryptographically secure randomly generated slice of
// bytes of sessionTokenSize length.
//
// TODO(e.burkov): Think about using byte array instead of byte slice.
func newSessionToken() (data []byte) {
randData := make([]byte, sessionTokenSize)
// Since Go 1.24, crypto/rand.Read doesn't return an error and crashes
// unrecoverably instead.
_, _ = rand.Read(randData)
return randData
return len(users) != 0
}

View File

@@ -1,69 +0,0 @@
package home
import (
"encoding/hex"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAuth(t *testing.T) {
dir := t.TempDir()
fn := filepath.Join(dir, "sessions.db")
users := []webUser{{
Name: "name",
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
}}
a := InitAuth(fn, nil, 60, nil, nil)
s := session{}
user := webUser{Name: "name"}
err := a.addUser(&user, "password")
require.NoError(t, err)
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
a.removeSession("notfound")
sess := newSessionToken()
sessStr := hex.EncodeToString(sess)
now := time.Now().UTC().Unix()
// check expiration
s.expire = uint32(now)
a.addSession(sess, &s)
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
// add session with TTL = 2 sec
s = session{}
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.addSession(sess, &s)
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
a.Close()
// load saved session
a = InitAuth(fn, users, 60, nil, nil)
// the session is still alive
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
// reset our expiration time because checkSession() has just updated it
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.storeSession(sess, &s)
a.Close()
u, ok := a.findUser("name", "password")
assert.True(t, ok)
assert.NotEmpty(t, u.Name)
time.Sleep(3 * time.Second)
// load and remove expired sessions
a = InitAuth(fn, users, 60, nil, nil)
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
a.Close()
}

View File

@@ -1,6 +1,7 @@
package home
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
@@ -32,10 +33,14 @@ type loginJSON struct {
}
// newCookie creates a new authentication cookie.
func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error) {
func (a *Auth) newCookie(
ctx context.Context,
req loginJSON,
addr string,
) (c *http.Cookie, err error) {
rateLimiter := a.rateLimiter
u, ok := a.findUser(req.Name, req.Password)
if !ok {
u := a.findUser(ctx, req.Name, req.Password)
if u == nil {
if rateLimiter != nil {
rateLimiter.inc(addr)
}
@@ -47,19 +52,16 @@ func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error)
rateLimiter.remove(addr)
}
sess := newSessionToken()
now := time.Now().UTC()
a.addSession(sess, &session{
userName: u.Name,
expire: uint32(now.Unix()) + a.sessionTTL,
})
s, err := a.sessions.New(ctx, u)
if err != nil {
return nil, fmt.Errorf("creating session: %w", err)
}
return &http.Cookie{
Name: sessionCookieName,
Value: hex.EncodeToString(sess),
Value: hex.EncodeToString(s.Token[:]),
Path: "/",
Expires: now.Add(cookieTTL),
Expires: time.Now().Add(cookieTTL),
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
@@ -172,7 +174,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
}
cookie, err := globalContext.auth.newCookie(req, remoteIP)
cookie, err := globalContext.auth.newCookie(r.Context(), req, remoteIP)
if err != nil {
logIP := remoteIP
if globalContext.auth.trustedProxies.Contains(ip.Unmap()) {
@@ -209,7 +211,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
return
}
globalContext.auth.removeSession(c.Value)
globalContext.auth.removeSession(r.Context(), c.Value)
c = &http.Cookie{
Name: sessionCookieName,
@@ -242,28 +244,7 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
return false
}
// redirect to login page if not authenticated
isAuthenticated := false
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
// Check Basic authentication.
user, pass, hasBasic := r.BasicAuth()
if hasBasic {
_, isAuthenticated = globalContext.auth.findUser(user, pass)
if !isAuthenticated {
log.Info("%s: invalid basic authorization value", pref)
}
}
} else {
res := globalContext.auth.checkSession(cookie.Value)
isAuthenticated = res == checkSessionOK
if !isAuthenticated {
log.Debug("%s: invalid cookie value: %q", pref, cookie)
}
}
if isAuthenticated {
if u := globalContext.auth.getCurrentUser(r); u != nil {
return false
}
@@ -289,14 +270,14 @@ func optionalAuth(
h func(http.ResponseWriter, *http.Request),
) (wrapped func(http.ResponseWriter, *http.Request)) {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
p := r.URL.Path
authRequired := globalContext.auth != nil && globalContext.auth.authRequired()
authRequired := globalContext.auth != nil && globalContext.auth.authRequired(ctx)
if p == "/login.html" {
cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil {
// Redirect to the dashboard if already authenticated.
res := globalContext.auth.checkSession(cookie.Value)
if res == checkSessionOK {
if globalContext.auth.isValidSession(ctx, cookie.Value) {
http.Redirect(w, r, "", http.StatusFound)
return

View File

@@ -7,8 +7,10 @@ import (
"net/url"
"path/filepath"
"testing"
"time"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -33,13 +35,20 @@ func (w *testResponseWriter) WriteHeader(statusCode int) {
}
func TestAuthHTTP(t *testing.T) {
var (
ctx = testutil.ContextWithTimeout(t, testTimeout)
logger = slogutil.NewDiscardLogger()
err error
)
dir := t.TempDir()
fn := filepath.Join(dir, "sessions.db")
users := []webUser{
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
}
globalContext.auth = InitAuth(fn, users, 60, nil, nil)
globalContext.auth, err = InitAuth(ctx, logger, fn, users, time.Minute, nil, nil)
require.NoError(t, err)
handlerCalled := false
handler := func(_ http.ResponseWriter, _ *http.Request) {
@@ -68,7 +77,11 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled)
// perform login
cookie, err := globalContext.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
cookie, err := globalContext.auth.newCookie(
ctx,
loginJSON{Name: "name", Password: "password"},
"",
)
require.NoError(t, err)
require.NotNil(t, cookie)
@@ -114,7 +127,7 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled)
r.Header.Del(httphdr.Cookie)
globalContext.auth.Close()
globalContext.auth.Close(ctx)
}
func TestRealIP(t *testing.T) {

View File

@@ -2,6 +2,7 @@ package home
import (
"bytes"
"context"
"fmt"
"net/netip"
"os"
@@ -748,7 +749,8 @@ func (c *configuration) write(tlsMgr *tlsManager) (err error) {
defer c.Unlock()
if globalContext.auth != nil {
config.Users = globalContext.auth.usersList()
// TODO(s.chzhen): Pass context.
config.Users = globalContext.auth.usersList(context.TODO())
}
if tlsMgr != nil {

View File

@@ -392,6 +392,8 @@ const PasswordMinRunes = 8
// Apply new configuration, start DNS server, restart Web server
func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
req, restartHTTP, err := decodeApplyConfigReq(r.Body)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -439,7 +441,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
u := &webUser{
Name: req.Username,
}
err = globalContext.auth.addUser(u, req.Password)
err = globalContext.auth.addUser(ctx, u, req.Password)
if err != nil {
globalContext.firstRun = true
copyInstallSettings(config, curConfig)
@@ -452,7 +454,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
// moment we'll allow setting up TLS in the initial configuration or the
// configuration itself will use HTTPS protocol, because the underlying
// functions potentially restart the HTTPS server.
err = startMods(r.Context(), web.baseLogger, web.tlsManager)
err = startMods(ctx, web.baseLogger, web.tlsManager)
if err != nil {
globalContext.firstRun = true
copyInstallSettings(config, curConfig)
@@ -488,11 +490,11 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
// and with its own context, because it waits until all requests are handled
// and will be blocked by it's own caller.
go func(timeout time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer slogutil.RecoverAndLog(ctx, web.logger)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer slogutil.RecoverAndLog(shutdownCtx, web.logger)
defer cancel()
shutdownSrv(ctx, web.logger, web.httpServer)
shutdownSrv(shutdownCtx, web.logger, web.httpServer)
}(shutdownTimeout)
}

View File

@@ -317,13 +317,7 @@ func newDNSTLSConfig(
return &dnsforward.TLSConfig{}, nil
}
cert, err := tls.X509KeyPair(conf.CertificateChainData, conf.PrivateKeyData)
if err != nil {
return nil, fmt.Errorf("parsing tls key pair: %w", err)
}
dnsConf = &dnsforward.TLSConfig{
Cert: &cert,
ServerName: conf.ServerName,
StrictSNICheck: conf.StrictSNICheck,
}
@@ -340,6 +334,28 @@ func newDNSTLSConfig(
dnsConf.QUICListenAddrs = ipsToUDPAddrs(addrs, conf.PortDNSOverQUIC)
}
cert, err := tls.X509KeyPair(conf.CertificateChainData, conf.PrivateKeyData)
if err != nil {
const format = "parsing tls key pair: %w"
if conf.AllowUnencryptedDoH {
// TODO(s.chzhen): Use [slog.Logger].
log.Info("warning: %s: %s", format, err)
return dnsConf, nil
}
return nil, fmt.Errorf(format, err)
}
// Unencrypted DoH is managed by AdGuard Home itself, not by dnsproxy.
// Therefore, avoid setting the certificate property to prevent dnsproxy
// from starting encrypted listeners. See [dnsforward.Server.prepareTLS].
if conf.AllowUnencryptedDoH {
return dnsConf, nil
}
dnsConf.Cert = &cert
return dnsConf, nil
}

View File

@@ -668,7 +668,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
GLMode = opts.glinetMode
// Init auth module.
globalContext.auth, err = initUsers()
globalContext.auth, err = initUsers(ctx, slogLogger)
fatalOnError(err)
web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
@@ -786,7 +786,8 @@ func checkPermissions(
}
// initUsers initializes context auth module. Clears config users field.
func initUsers() (auth *Auth, err error) {
// baseLogger must not be nil.
func initUsers(ctx context.Context, baseLogger *slog.Logger) (auth *Auth, err error) {
sessFilename := filepath.Join(globalContext.getDataDir(), "sessions.db")
var rateLimiter *authRateLimiter
@@ -799,10 +800,17 @@ func initUsers() (auth *Auth, err error) {
trustedProxies := netutil.SliceSubnetSet(netutil.UnembedPrefixes(config.DNS.TrustedProxies))
sessionTTL := time.Duration(config.HTTPConfig.SessionTTL).Seconds()
auth = InitAuth(sessFilename, config.Users, uint32(sessionTTL), rateLimiter, trustedProxies)
if auth == nil {
return nil, errors.Error("initializing auth module failed")
auth, err = InitAuth(
ctx,
baseLogger,
sessFilename,
config.Users,
time.Duration(config.HTTPConfig.SessionTTL),
rateLimiter,
trustedProxies,
)
if err != nil {
return nil, fmt.Errorf("initializing auth module: %w", err)
}
config.Users = nil
@@ -916,7 +924,7 @@ func cleanup(ctx context.Context) {
globalContext.web = nil
}
if globalContext.auth != nil {
globalContext.auth.Close()
globalContext.auth.Close(ctx)
globalContext.auth = nil
}

View File

@@ -47,7 +47,11 @@ type profileJSON struct {
// handleGetProfile is the handler for GET /control/profile endpoint.
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
name := ""
u := globalContext.auth.getCurrentUser(r)
if u != nil {
name = string(u.Login)
}
var resp profileJSON
func() {
@@ -55,7 +59,7 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
defer config.RUnlock()
resp = profileJSON{
Name: u.Name,
Name: name,
Language: config.Language,
Theme: config.Theme,
}

View File

@@ -199,6 +199,7 @@ run_linter gocognit --over='10' \
./internal/aghhttp/ \
./internal/aghrenameio/ \
./internal/aghtest/ \
./internal/aghuser/ \
./internal/arpdb/ \
./internal/client/ \
./internal/configmigrate/ \
@@ -250,6 +251,7 @@ run_linter fieldalignment \
./internal/aghrenameio/ \
./internal/aghtest/ \
./internal/aghtls/ \
./internal/aghuser/ \
./internal/arpdb/ \
./internal/client/ \
./internal/configmigrate/ \
@@ -280,6 +282,7 @@ run_linter gosec --exclude G115 --quiet \
./internal/aghos/ \
./internal/aghrenameio/ \
./internal/aghtest/ \
./internal/aghuser/ \
./internal/arpdb/ \
./internal/client/ \
./internal/configmigrate/ \