Pull request 2386: AGDNS-2743-aghuser-session
Merge in DNS/adguard-home from AGDNS-2743-aghuser-session to master Squashed commit of the following: commit 74fd4bc11eaf784880855fa2c710a747428db146 Merge: 844e865f67d479babaAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Apr 18 18:14:36 2025 +0300 Merge branch 'master' into AGDNS-2743-aghuser-session commit 844e865f647efb4de7f057c392894c8f65bab422 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Apr 18 15:18:44 2025 +0300 aghuser: imp fmt commit 584288e0a3ddbe6d7ae31c80c22b8f397cfd0cae Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 17 20:16:54 2025 +0300 aghuser: imp tests commit ea4c8735585f6d30d6dedf2a40a8dd6b07609d07 Merge: c3fd8fe5e3521e8ed9Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 17 20:10:06 2025 +0300 Merge branch 'master' into AGDNS-2743-aghuser-session commit c3fd8fe5eabaf2022a971197c018e140c254006d Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 17 15:23:45 2025 +0300 aghuser: imp tests commit dfd9aba337227a8d3edc6f5a68f3f039afd1ca0b Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Apr 16 21:40:14 2025 +0300 aghuser: imp code commit b6e75223bf7960f3a2e94c1a3ed7cc33539b9806 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Apr 14 21:49:20 2025 +0300 aghuser: imp code commit 56d6f9d478eec399c376992ffb0f1ca5b797986d Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Apr 11 16:58:11 2025 +0300 aghuser: user db commit 6fdc2f60bf7f93e72d917abb12af8e4867143b6d Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 10 14:11:22 2025 +0300 all: upd scripts commit 575946756f3f622360c5feafe3e721eee010e230 Merge: 7e1fac4ec1cc6c00e4Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 10 14:05:46 2025 +0300 Merge branch 'master' into AGDNS-2743-aghuser-session commit 7e1fac4ecb1bde0013bca3f6b64e82d81a78c9c3 Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Apr 10 14:05:35 2025 +0300 aghuser: session storage commit acfb040f0bdff501c7304ea100b9faf1c07291ae Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Apr 8 15:54:24 2025 +0300 aghuser: session
This commit is contained in:
@@ -10,7 +10,8 @@ import (
|
|||||||
// Login is the type for web user logins.
|
// Login is the type for web user logins.
|
||||||
type Login string
|
type Login string
|
||||||
|
|
||||||
// NewLogin returns a web user login.
|
// NewLogin returns a web user login. The length of s must not be greater than
|
||||||
|
// [math.MaxUint16].
|
||||||
//
|
//
|
||||||
// TODO(s.chzhen): Add more constraints as needed.
|
// TODO(s.chzhen): Add more constraints as needed.
|
||||||
func NewLogin(s string) (l Login, err error) {
|
func NewLogin(s string) (l Login, err error) {
|
||||||
|
|||||||
35
internal/aghuser/session.go
Normal file
35
internal/aghuser/session.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package aghuser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionToken is the type for the web user session token.
|
||||||
|
type SessionToken [16]byte
|
||||||
|
|
||||||
|
// NewSessionToken returns a cryptographically secure randomly generated web
|
||||||
|
// user session token. If an error occurs during random generation, it will
|
||||||
|
// cause the program to crash.
|
||||||
|
func NewSessionToken() (t SessionToken) {
|
||||||
|
_, _ = rand.Read(t[:])
|
||||||
|
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session represents a web user session.
|
||||||
|
type Session struct {
|
||||||
|
// Expire indicates when the session will expire.
|
||||||
|
Expire time.Time
|
||||||
|
|
||||||
|
// UserLogin is the login of the web user associated with the session.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Remove this field and associate the user by UserID.
|
||||||
|
UserLogin Login
|
||||||
|
|
||||||
|
// Token is the session token.
|
||||||
|
Token SessionToken
|
||||||
|
|
||||||
|
// UserID is the identifier of the web user associated with the session.
|
||||||
|
UserID UserID
|
||||||
|
}
|
||||||
449
internal/aghuser/sessionstorage.go
Normal file
449
internal/aghuser/sessionstorage.go
Normal file
@@ -0,0 +1,449 @@
|
|||||||
|
package aghuser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
|
"go.etcd.io/bbolt"
|
||||||
|
berrors "go.etcd.io/bbolt/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionStorage is an interface that defines methods for handling web user
|
||||||
|
// sessions. All methods must be safe for concurrent use.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Add DeleteAll method.
|
||||||
|
type SessionStorage interface {
|
||||||
|
// New creates a new session for the web user.
|
||||||
|
New(ctx context.Context, u *User) (s *Session, err error)
|
||||||
|
|
||||||
|
// FindByToken returns the stored session for the web user based on the session
|
||||||
|
// token.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Consider function signature change to reflect the
|
||||||
|
// in-memory implementation, as it currently always returns nil for error.
|
||||||
|
FindByToken(ctx context.Context, t SessionToken) (s *Session, err error)
|
||||||
|
|
||||||
|
// DeleteByToken removes a stored web user session by the provided token.
|
||||||
|
DeleteByToken(ctx context.Context, t SessionToken) (err error)
|
||||||
|
|
||||||
|
// Close releases the web user sessions database resources.
|
||||||
|
Close() (err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultSessionStorageConfig represents the web user session storage
|
||||||
|
// configuration structure.
|
||||||
|
type DefaultSessionStorageConfig struct {
|
||||||
|
// Logger is used for logging the operation of the session storage. It must
|
||||||
|
// not be nil.
|
||||||
|
Logger *slog.Logger
|
||||||
|
|
||||||
|
// Clock is used to get the current time. It must not be nil.
|
||||||
|
Clock timeutil.Clock
|
||||||
|
|
||||||
|
// UserDB contains the web user information such as ID, login, and password.
|
||||||
|
// It must not be nil.
|
||||||
|
UserDB DB
|
||||||
|
|
||||||
|
// DBPath is the path to the database file where session data is stored. It
|
||||||
|
// must not be empty.
|
||||||
|
DBPath string
|
||||||
|
|
||||||
|
// SessionTTL is the default Time-To-Live duration for web user sessions.
|
||||||
|
// It specifies how long a session should last and is a required field.
|
||||||
|
SessionTTL time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultSessionStorage is the default bbolt database implementation of the
|
||||||
|
// [SessionStorage] interface.
|
||||||
|
type DefaultSessionStorage struct {
|
||||||
|
// db is an instance of the bbolt database where web user sessions are
|
||||||
|
// stored by [SessionToken] in the [bucketNameSessions] bucket.
|
||||||
|
db *bbolt.DB
|
||||||
|
|
||||||
|
// logger is used for logging the operation of the session storage.
|
||||||
|
logger *slog.Logger
|
||||||
|
|
||||||
|
// mu protects sessions.
|
||||||
|
mu *sync.Mutex
|
||||||
|
|
||||||
|
// clock is used to get the current time.
|
||||||
|
clock timeutil.Clock
|
||||||
|
|
||||||
|
// userDB contains the web user information such as ID, login, and password.
|
||||||
|
userDB DB
|
||||||
|
|
||||||
|
// sessions maps a session token to a web user session.
|
||||||
|
sessions map[SessionToken]*Session
|
||||||
|
|
||||||
|
// sessionTTL is the default Time-To-Live value for web user sessions.
|
||||||
|
sessionTTL time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultSessionStorage returns the new properly initialized
|
||||||
|
// *DefaultSessionStorage.
|
||||||
|
func NewDefaultSessionStorage(
|
||||||
|
ctx context.Context,
|
||||||
|
conf *DefaultSessionStorageConfig,
|
||||||
|
) (ds *DefaultSessionStorage, err error) {
|
||||||
|
ds = &DefaultSessionStorage{
|
||||||
|
clock: conf.Clock,
|
||||||
|
userDB: conf.UserDB,
|
||||||
|
logger: conf.Logger,
|
||||||
|
mu: &sync.Mutex{},
|
||||||
|
sessions: map[SessionToken]*Session{},
|
||||||
|
sessionTTL: conf.SessionTTL,
|
||||||
|
}
|
||||||
|
|
||||||
|
dbFilename := conf.DBPath
|
||||||
|
// TODO(s.chzhen): Pass logger with options.
|
||||||
|
ds.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, nil)
|
||||||
|
if err != nil {
|
||||||
|
ds.logger.ErrorContext(ctx, "opening db %q: %w", dbFilename, err)
|
||||||
|
if errors.Is(err, berrors.ErrInvalid) {
|
||||||
|
const s = "AdGuard Home cannot be initialized due to an incompatible file system.\n" +
|
||||||
|
"Please read the explanation here: https://adguard-dns.io/kb/adguard-home/getting-started/#limitations"
|
||||||
|
slogutil.PrintLines(ctx, ds.logger, slog.LevelError, "", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ds.loadSessions(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("loading sessions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ds, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadSessions loads web user sessions from the bbolt database.
|
||||||
|
func (ds *DefaultSessionStorage) loadSessions(ctx context.Context) (err error) {
|
||||||
|
tx, err := ds.db.Begin(true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("starting transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
needRollback := true
|
||||||
|
defer func() {
|
||||||
|
if needRollback {
|
||||||
|
err = errors.WithDeferred(err, tx.Rollback())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bkt := tx.Bucket([]byte(bboltBucketSessions))
|
||||||
|
if bkt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
removed, err := ds.processSessions(ctx, bkt)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("processing sessions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if removed == 0 {
|
||||||
|
ds.logger.DebugContext(ctx, "loading sessions from db", "stored", len(ds.sessions))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
needRollback = false
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("committing transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ds.logger.DebugContext(
|
||||||
|
ctx,
|
||||||
|
"loading sessions from db",
|
||||||
|
"stored", len(ds.sessions),
|
||||||
|
"removed", removed,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processSessions iterates over the sessions bucket and loads or removes
|
||||||
|
// sessions as needed.
|
||||||
|
func (ds *DefaultSessionStorage) processSessions(
|
||||||
|
ctx context.Context,
|
||||||
|
bkt *bbolt.Bucket,
|
||||||
|
) (removed int, err error) {
|
||||||
|
invalidSessions := [][]byte{}
|
||||||
|
|
||||||
|
err = bkt.ForEach(ds.bboltSessionHandler(ctx, &invalidSessions))
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("iterating over sessions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
for _, s := range invalidSessions {
|
||||||
|
if err = bkt.Delete(s); err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = errors.Join(errs...); err != nil {
|
||||||
|
return 0, fmt.Errorf("deleting sessions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(invalidSessions), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// bboltSessionHandler returns a function for [bbolt.Bucket.ForEach] that
|
||||||
|
// iterates over stored sessions, deserializes them, and logs any errors
|
||||||
|
// encountered. The returned error is always nil, as these errors are
|
||||||
|
// considered non-critical to stop the iteration process.
|
||||||
|
func (ds *DefaultSessionStorage) bboltSessionHandler(
|
||||||
|
ctx context.Context,
|
||||||
|
invalidSessions *[][]byte,
|
||||||
|
) (fn func(k, v []byte) (err error)) {
|
||||||
|
now := ds.clock.Now()
|
||||||
|
|
||||||
|
return func(k, v []byte) (err error) {
|
||||||
|
s, err := bboltDecode(v)
|
||||||
|
if err != nil {
|
||||||
|
*invalidSessions = append(*invalidSessions, k)
|
||||||
|
ds.logger.DebugContext(ctx, "deserializing session", slogutil.KeyError, err)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if now.After(s.Expire) {
|
||||||
|
*invalidSessions = append(*invalidSessions, k)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := ds.userDB.ByLogin(ctx, s.UserLogin)
|
||||||
|
if err != nil {
|
||||||
|
// Should not happen, as it currently always returns nil for error.
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u == nil {
|
||||||
|
*invalidSessions = append(*invalidSessions, k)
|
||||||
|
ds.logger.DebugContext(ctx, "no saved user by name", "name", s.UserLogin)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t := SessionToken(k)
|
||||||
|
s.Token = t
|
||||||
|
s.UserID = u.ID
|
||||||
|
ds.sessions[t] = s
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// bboltBucketSessions is the name of the bucket storing web user sessions in
|
||||||
|
// the bbolt database.
|
||||||
|
const bboltBucketSessions = "sessions-2"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// bboltSessionExpireLen is the length of the expire field in the binary
|
||||||
|
// entry stored in bbolt.
|
||||||
|
bboltSessionExpireLen = 4
|
||||||
|
|
||||||
|
// bboltSessionNameLen is the length of the name field in the binary entry
|
||||||
|
// stored in bbolt.
|
||||||
|
bboltSessionNameLen = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// bboltDecode deserializes decodes a binary data into a session.
|
||||||
|
func bboltDecode(data []byte) (s *Session, err error) {
|
||||||
|
if len(data) < bboltSessionExpireLen+bboltSessionNameLen {
|
||||||
|
return nil, fmt.Errorf("length of the data is less than expected: got %d", len(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
expireData := data[:bboltSessionExpireLen]
|
||||||
|
nameLenData := data[bboltSessionExpireLen : bboltSessionExpireLen+bboltSessionNameLen]
|
||||||
|
nameData := data[bboltSessionExpireLen+bboltSessionNameLen:]
|
||||||
|
|
||||||
|
nameLen := binary.BigEndian.Uint16(nameLenData)
|
||||||
|
if len(nameData) != int(nameLen) {
|
||||||
|
return nil, fmt.Errorf("login: expected length %d, got %d", nameLen, len(nameData))
|
||||||
|
}
|
||||||
|
|
||||||
|
expire := binary.BigEndian.Uint32(expireData)
|
||||||
|
|
||||||
|
return &Session{
|
||||||
|
Expire: time.Unix(int64(expire), 0),
|
||||||
|
UserLogin: Login(nameData),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// bboltEncode serializes a session properties into a binary data.
|
||||||
|
func bboltEncode(s *Session) (data []byte) {
|
||||||
|
data = make([]byte, bboltSessionExpireLen+bboltSessionNameLen+len(s.UserLogin))
|
||||||
|
|
||||||
|
expireData := data[:bboltSessionExpireLen]
|
||||||
|
nameLenData := data[bboltSessionExpireLen : bboltSessionExpireLen+bboltSessionNameLen]
|
||||||
|
nameData := data[bboltSessionExpireLen+bboltSessionNameLen:]
|
||||||
|
|
||||||
|
expire := uint32(s.Expire.Unix())
|
||||||
|
binary.BigEndian.PutUint32(expireData, expire)
|
||||||
|
binary.BigEndian.PutUint16(nameLenData, uint16(len(s.UserLogin)))
|
||||||
|
copy(nameData, []byte(s.UserLogin))
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ SessionStorage = (*DefaultSessionStorage)(nil)
|
||||||
|
|
||||||
|
// New implements the [SessionStorage] interface for *DefaultSessionStorage.
|
||||||
|
func (ds *DefaultSessionStorage) New(ctx context.Context, u *User) (s *Session, err error) {
|
||||||
|
s = &Session{
|
||||||
|
Token: NewSessionToken(),
|
||||||
|
UserID: u.ID,
|
||||||
|
UserLogin: u.Login,
|
||||||
|
Expire: ds.clock.Now().Add(ds.sessionTTL),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ds.store(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("storing session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ds.mu.Lock()
|
||||||
|
defer ds.mu.Unlock()
|
||||||
|
|
||||||
|
ds.sessions[s.Token] = s
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// store saves a web user session in the bbolt database.
|
||||||
|
func (ds *DefaultSessionStorage) store(s *Session) (err error) {
|
||||||
|
tx, err := ds.db.Begin(true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("starting transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
needRollback := true
|
||||||
|
defer func() {
|
||||||
|
if needRollback {
|
||||||
|
err = errors.WithDeferred(err, tx.Rollback())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bkt, err := tx.CreateBucketIfNotExists([]byte(bboltBucketSessions))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating bucket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bkt.Put(s.Token[:], bboltEncode(s))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("putting data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
needRollback = false
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("committing transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindByToken implements the [SessionStorage] interface for *DefaultSessionStorage.
|
||||||
|
func (ds *DefaultSessionStorage) FindByToken(ctx context.Context, t SessionToken) (s *Session, err error) {
|
||||||
|
ds.mu.Lock()
|
||||||
|
defer ds.mu.Unlock()
|
||||||
|
|
||||||
|
s, ok := ds.sessions[t]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := ds.clock.Now()
|
||||||
|
if now.After(s.Expire) {
|
||||||
|
err = ds.deleteByToken(ctx, t)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("expired session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByToken implements the [SessionStorage] interface for
|
||||||
|
// *DefaultSessionStorage.
|
||||||
|
func (ds *DefaultSessionStorage) DeleteByToken(ctx context.Context, t SessionToken) (err error) {
|
||||||
|
ds.mu.Lock()
|
||||||
|
defer ds.mu.Unlock()
|
||||||
|
|
||||||
|
// Don't wrap the error because it's informative enough as is.
|
||||||
|
return ds.deleteByToken(ctx, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteByToken removes stored session by token. ds.mu is expected to be
|
||||||
|
// locked.
|
||||||
|
func (ds *DefaultSessionStorage) deleteByToken(ctx context.Context, t SessionToken) (err error) {
|
||||||
|
err = ds.remove(ctx, t)
|
||||||
|
if err != nil {
|
||||||
|
ds.logger.ErrorContext(ctx, "deleting session", slogutil.KeyError, err)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(ds.sessions, t)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove deletes a web user session from the bbolt database.
|
||||||
|
func (ds *DefaultSessionStorage) remove(ctx context.Context, t SessionToken) (err error) {
|
||||||
|
tx, err := ds.db.Begin(true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("starting transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
needRollback := true
|
||||||
|
defer func() {
|
||||||
|
if needRollback {
|
||||||
|
err = errors.WithDeferred(err, tx.Rollback())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bkt := tx.Bucket([]byte(bboltBucketSessions))
|
||||||
|
if bkt == nil {
|
||||||
|
return errors.Error("no bucket")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bkt.Delete(t[:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("removing data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
needRollback = false
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("committing transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ds.logger.DebugContext(ctx, "removed session from db")
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements the [SessionStorage] interface for *DefaultSessionStorage.
|
||||||
|
func (ds *DefaultSessionStorage) Close() (err error) {
|
||||||
|
err = ds.db.Close()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("closing db: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
162
internal/aghuser/sessionstorage_test.go
Normal file
162
internal/aghuser/sessionstorage_test.go
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
package aghuser_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghuser"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil/faketime"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// addSession is a helper function that saves and returns a session for a newly
|
||||||
|
// generated [aghuser.User] by login.
|
||||||
|
func addSession(
|
||||||
|
tb testing.TB,
|
||||||
|
ctx context.Context,
|
||||||
|
ds aghuser.SessionStorage,
|
||||||
|
login aghuser.Login,
|
||||||
|
) (s *aghuser.Session) {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
s, err := ds.New(ctx, &aghuser.User{
|
||||||
|
ID: aghuser.MustNewUserID(),
|
||||||
|
Login: login,
|
||||||
|
})
|
||||||
|
require.NoError(tb, err)
|
||||||
|
require.NotNil(tb, s)
|
||||||
|
|
||||||
|
var got *aghuser.Session
|
||||||
|
got, err = ds.FindByToken(ctx, s.Token)
|
||||||
|
require.NoError(tb, err)
|
||||||
|
require.NotNil(tb, got)
|
||||||
|
|
||||||
|
assert.Equal(tb, login, got.UserLogin)
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultSessionStorage(t *testing.T) {
|
||||||
|
const (
|
||||||
|
userLoginFirst aghuser.Login = "user_one"
|
||||||
|
userLoginSecond aghuser.Login = "user_two"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
||||||
|
logger = slogutil.NewDiscardLogger()
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sessionTTL = time.Minute
|
||||||
|
timeStep = time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Set up a mock clock to test expired sessions. Each call to [clock.Now]
|
||||||
|
// will return the [date] incremented by [timeStep].
|
||||||
|
date := time.Now()
|
||||||
|
clock := &faketime.Clock{
|
||||||
|
OnNow: func() (now time.Time) {
|
||||||
|
date = date.Add(timeStep)
|
||||||
|
|
||||||
|
return date
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
dbFile, err := os.CreateTemp(t.TempDir(), "sessions.db")
|
||||||
|
require.NoError(t, err)
|
||||||
|
testutil.CleanupAndRequireSuccess(t, dbFile.Close)
|
||||||
|
|
||||||
|
userDB := aghuser.NewDefaultDB()
|
||||||
|
|
||||||
|
err = userDB.Create(ctx, &aghuser.User{
|
||||||
|
Login: userLoginFirst,
|
||||||
|
ID: aghuser.MustNewUserID(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = userDB.Create(ctx, &aghuser.User{
|
||||||
|
Login: userLoginSecond,
|
||||||
|
ID: aghuser.MustNewUserID(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ds *aghuser.DefaultSessionStorage
|
||||||
|
|
||||||
|
sessionFirst *aghuser.Session
|
||||||
|
sessionSecond *aghuser.Session
|
||||||
|
)
|
||||||
|
|
||||||
|
require.True(t, t.Run("prepare_session_storage", func(t *testing.T) {
|
||||||
|
ds, err = aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{
|
||||||
|
Clock: clock,
|
||||||
|
UserDB: userDB,
|
||||||
|
Logger: logger,
|
||||||
|
DBPath: dbFile.Name(),
|
||||||
|
SessionTTL: sessionTTL,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sessionFirst = addSession(t, ctx, ds, userLoginFirst)
|
||||||
|
|
||||||
|
// Advance time to ensure the first session expires before creating the
|
||||||
|
// second session.
|
||||||
|
date = date.Add(time.Hour)
|
||||||
|
|
||||||
|
sessionSecond = addSession(t, ctx, ds, userLoginSecond)
|
||||||
|
|
||||||
|
err = ds.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}))
|
||||||
|
|
||||||
|
require.True(t, t.Run("load_sessions", func(t *testing.T) {
|
||||||
|
ds, err = aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{
|
||||||
|
Clock: clock,
|
||||||
|
UserDB: userDB,
|
||||||
|
Logger: logger,
|
||||||
|
DBPath: dbFile.Name(),
|
||||||
|
SessionTTL: sessionTTL,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var got *aghuser.Session
|
||||||
|
got, err = ds.FindByToken(ctx, sessionFirst.Token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Nil(t, got)
|
||||||
|
|
||||||
|
got, err = ds.FindByToken(ctx, sessionSecond.Token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, got)
|
||||||
|
|
||||||
|
assert.Equal(t, userLoginSecond, got.UserLogin)
|
||||||
|
|
||||||
|
err = ds.DeleteByToken(ctx, sessionSecond.Token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err = ds.FindByToken(ctx, sessionSecond.Token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Nil(t, got)
|
||||||
|
}))
|
||||||
|
|
||||||
|
require.True(t, t.Run("expired_session", func(t *testing.T) {
|
||||||
|
testutil.CleanupAndRequireSuccess(t, ds.Close)
|
||||||
|
|
||||||
|
sessionFirst = addSession(t, ctx, ds, userLoginFirst)
|
||||||
|
|
||||||
|
date = date.Add(time.Hour)
|
||||||
|
|
||||||
|
var got *aghuser.Session
|
||||||
|
got, err = ds.FindByToken(ctx, sessionFirst.Token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Nil(t, got)
|
||||||
|
}))
|
||||||
|
}
|
||||||
@@ -32,13 +32,13 @@ func MustNewUserID() (uid UserID) {
|
|||||||
|
|
||||||
// User represents a web user.
|
// User represents a web user.
|
||||||
type User struct {
|
type User struct {
|
||||||
// ID is the unique identifier for the web user. It must not be empty.
|
// Password stores the password information for the web user. It must not
|
||||||
ID UserID
|
// be nil.
|
||||||
|
Password Password
|
||||||
|
|
||||||
// Login is the login name of the web user. It must not be empty.
|
// Login is the login name of the web user. It must not be empty.
|
||||||
Login Login
|
Login Login
|
||||||
|
|
||||||
// Password stores the password information for the web user. It must not
|
// ID is the unique identifier for the web user. It must not be empty.
|
||||||
// be nil.
|
ID UserID
|
||||||
Password Password
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -199,6 +199,7 @@ run_linter gocognit --over='10' \
|
|||||||
./internal/aghhttp/ \
|
./internal/aghhttp/ \
|
||||||
./internal/aghrenameio/ \
|
./internal/aghrenameio/ \
|
||||||
./internal/aghtest/ \
|
./internal/aghtest/ \
|
||||||
|
./internal/aghuser/ \
|
||||||
./internal/arpdb/ \
|
./internal/arpdb/ \
|
||||||
./internal/client/ \
|
./internal/client/ \
|
||||||
./internal/configmigrate/ \
|
./internal/configmigrate/ \
|
||||||
@@ -250,6 +251,7 @@ run_linter fieldalignment \
|
|||||||
./internal/aghrenameio/ \
|
./internal/aghrenameio/ \
|
||||||
./internal/aghtest/ \
|
./internal/aghtest/ \
|
||||||
./internal/aghtls/ \
|
./internal/aghtls/ \
|
||||||
|
./internal/aghuser/ \
|
||||||
./internal/arpdb/ \
|
./internal/arpdb/ \
|
||||||
./internal/client/ \
|
./internal/client/ \
|
||||||
./internal/configmigrate/ \
|
./internal/configmigrate/ \
|
||||||
@@ -280,6 +282,7 @@ run_linter gosec --exclude G115 --quiet \
|
|||||||
./internal/aghos/ \
|
./internal/aghos/ \
|
||||||
./internal/aghrenameio/ \
|
./internal/aghrenameio/ \
|
||||||
./internal/aghtest/ \
|
./internal/aghtest/ \
|
||||||
|
./internal/aghuser/ \
|
||||||
./internal/arpdb/ \
|
./internal/arpdb/ \
|
||||||
./internal/client/ \
|
./internal/client/ \
|
||||||
./internal/configmigrate/ \
|
./internal/configmigrate/ \
|
||||||
|
|||||||
Reference in New Issue
Block a user