From a8fdf1c5535dec03a90e252b4ae692bff8e3a569 Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Fri, 18 Apr 2025 18:34:10 +0300 Subject: [PATCH] Pull request 2386: AGDNS-2743-aghuser-session Merge in DNS/adguard-home from AGDNS-2743-aghuser-session to master Squashed commit of the following: commit 74fd4bc11eaf784880855fa2c710a747428db146 Merge: 844e865f6 7d479baba Author: Stanislav Chzhen Date: Fri Apr 18 18:14:36 2025 +0300 Merge branch 'master' into AGDNS-2743-aghuser-session commit 844e865f647efb4de7f057c392894c8f65bab422 Author: Stanislav Chzhen Date: Fri Apr 18 15:18:44 2025 +0300 aghuser: imp fmt commit 584288e0a3ddbe6d7ae31c80c22b8f397cfd0cae Author: Stanislav Chzhen Date: Thu Apr 17 20:16:54 2025 +0300 aghuser: imp tests commit ea4c8735585f6d30d6dedf2a40a8dd6b07609d07 Merge: c3fd8fe5e 3521e8ed9 Author: Stanislav Chzhen Date: Thu Apr 17 20:10:06 2025 +0300 Merge branch 'master' into AGDNS-2743-aghuser-session commit c3fd8fe5eabaf2022a971197c018e140c254006d Author: Stanislav Chzhen Date: Thu Apr 17 15:23:45 2025 +0300 aghuser: imp tests commit dfd9aba337227a8d3edc6f5a68f3f039afd1ca0b Author: Stanislav Chzhen Date: Wed Apr 16 21:40:14 2025 +0300 aghuser: imp code commit b6e75223bf7960f3a2e94c1a3ed7cc33539b9806 Author: Stanislav Chzhen Date: Mon Apr 14 21:49:20 2025 +0300 aghuser: imp code commit 56d6f9d478eec399c376992ffb0f1ca5b797986d Author: Stanislav Chzhen Date: Fri Apr 11 16:58:11 2025 +0300 aghuser: user db commit 6fdc2f60bf7f93e72d917abb12af8e4867143b6d Author: Stanislav Chzhen Date: Thu Apr 10 14:11:22 2025 +0300 all: upd scripts commit 575946756f3f622360c5feafe3e721eee010e230 Merge: 7e1fac4ec 1cc6c00e4 Author: Stanislav Chzhen Date: Thu Apr 10 14:05:46 2025 +0300 Merge branch 'master' into AGDNS-2743-aghuser-session commit 7e1fac4ecb1bde0013bca3f6b64e82d81a78c9c3 Author: Stanislav Chzhen Date: Thu Apr 10 14:05:35 2025 +0300 aghuser: session storage commit acfb040f0bdff501c7304ea100b9faf1c07291ae Author: Stanislav Chzhen Date: Tue Apr 8 15:54:24 2025 +0300 aghuser: session --- internal/aghuser/aghuser.go | 3 +- internal/aghuser/session.go | 35 ++ internal/aghuser/sessionstorage.go | 449 ++++++++++++++++++++++++ internal/aghuser/sessionstorage_test.go | 162 +++++++++ internal/aghuser/user.go | 10 +- scripts/make/go-lint.sh | 3 + 6 files changed, 656 insertions(+), 6 deletions(-) create mode 100644 internal/aghuser/session.go create mode 100644 internal/aghuser/sessionstorage.go create mode 100644 internal/aghuser/sessionstorage_test.go diff --git a/internal/aghuser/aghuser.go b/internal/aghuser/aghuser.go index eed8617b..783a8d3a 100644 --- a/internal/aghuser/aghuser.go +++ b/internal/aghuser/aghuser.go @@ -10,7 +10,8 @@ import ( // Login is the type for web user logins. type Login string -// NewLogin returns a web user login. +// NewLogin returns a web user login. The length of s must not be greater than +// [math.MaxUint16]. // // TODO(s.chzhen): Add more constraints as needed. func NewLogin(s string) (l Login, err error) { diff --git a/internal/aghuser/session.go b/internal/aghuser/session.go new file mode 100644 index 00000000..1bfe9023 --- /dev/null +++ b/internal/aghuser/session.go @@ -0,0 +1,35 @@ +package aghuser + +import ( + "crypto/rand" + "time" +) + +// SessionToken is the type for the web user session token. +type SessionToken [16]byte + +// NewSessionToken returns a cryptographically secure randomly generated web +// user session token. If an error occurs during random generation, it will +// cause the program to crash. +func NewSessionToken() (t SessionToken) { + _, _ = rand.Read(t[:]) + + return t +} + +// Session represents a web user session. +type Session struct { + // Expire indicates when the session will expire. + Expire time.Time + + // UserLogin is the login of the web user associated with the session. + // + // TODO(s.chzhen): Remove this field and associate the user by UserID. + UserLogin Login + + // Token is the session token. + Token SessionToken + + // UserID is the identifier of the web user associated with the session. + UserID UserID +} diff --git a/internal/aghuser/sessionstorage.go b/internal/aghuser/sessionstorage.go new file mode 100644 index 00000000..c4d99277 --- /dev/null +++ b/internal/aghuser/sessionstorage.go @@ -0,0 +1,449 @@ +package aghuser + +import ( + "context" + "encoding/binary" + "fmt" + "log/slog" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/timeutil" + "go.etcd.io/bbolt" + berrors "go.etcd.io/bbolt/errors" +) + +// SessionStorage is an interface that defines methods for handling web user +// sessions. All methods must be safe for concurrent use. +// +// TODO(s.chzhen): Add DeleteAll method. +type SessionStorage interface { + // New creates a new session for the web user. + New(ctx context.Context, u *User) (s *Session, err error) + + // FindByToken returns the stored session for the web user based on the session + // token. + // + // TODO(s.chzhen): Consider function signature change to reflect the + // in-memory implementation, as it currently always returns nil for error. + FindByToken(ctx context.Context, t SessionToken) (s *Session, err error) + + // DeleteByToken removes a stored web user session by the provided token. + DeleteByToken(ctx context.Context, t SessionToken) (err error) + + // Close releases the web user sessions database resources. + Close() (err error) +} + +// DefaultSessionStorageConfig represents the web user session storage +// configuration structure. +type DefaultSessionStorageConfig struct { + // Logger is used for logging the operation of the session storage. It must + // not be nil. + Logger *slog.Logger + + // Clock is used to get the current time. It must not be nil. + Clock timeutil.Clock + + // UserDB contains the web user information such as ID, login, and password. + // It must not be nil. + UserDB DB + + // DBPath is the path to the database file where session data is stored. It + // must not be empty. + DBPath string + + // SessionTTL is the default Time-To-Live duration for web user sessions. + // It specifies how long a session should last and is a required field. + SessionTTL time.Duration +} + +// DefaultSessionStorage is the default bbolt database implementation of the +// [SessionStorage] interface. +type DefaultSessionStorage struct { + // db is an instance of the bbolt database where web user sessions are + // stored by [SessionToken] in the [bucketNameSessions] bucket. + db *bbolt.DB + + // logger is used for logging the operation of the session storage. + logger *slog.Logger + + // mu protects sessions. + mu *sync.Mutex + + // clock is used to get the current time. + clock timeutil.Clock + + // userDB contains the web user information such as ID, login, and password. + userDB DB + + // sessions maps a session token to a web user session. + sessions map[SessionToken]*Session + + // sessionTTL is the default Time-To-Live value for web user sessions. + sessionTTL time.Duration +} + +// NewDefaultSessionStorage returns the new properly initialized +// *DefaultSessionStorage. +func NewDefaultSessionStorage( + ctx context.Context, + conf *DefaultSessionStorageConfig, +) (ds *DefaultSessionStorage, err error) { + ds = &DefaultSessionStorage{ + clock: conf.Clock, + userDB: conf.UserDB, + logger: conf.Logger, + mu: &sync.Mutex{}, + sessions: map[SessionToken]*Session{}, + sessionTTL: conf.SessionTTL, + } + + dbFilename := conf.DBPath + // TODO(s.chzhen): Pass logger with options. + ds.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, nil) + if err != nil { + ds.logger.ErrorContext(ctx, "opening db %q: %w", dbFilename, err) + if errors.Is(err, berrors.ErrInvalid) { + const s = "AdGuard Home cannot be initialized due to an incompatible file system.\n" + + "Please read the explanation here: https://adguard-dns.io/kb/adguard-home/getting-started/#limitations" + slogutil.PrintLines(ctx, ds.logger, slog.LevelError, "", s) + } + + return nil, err + } + + err = ds.loadSessions(ctx) + if err != nil { + return nil, fmt.Errorf("loading sessions: %w", err) + } + + return ds, nil +} + +// loadSessions loads web user sessions from the bbolt database. +func (ds *DefaultSessionStorage) loadSessions(ctx context.Context) (err error) { + tx, err := ds.db.Begin(true) + if err != nil { + return fmt.Errorf("starting transaction: %w", err) + } + + needRollback := true + defer func() { + if needRollback { + err = errors.WithDeferred(err, tx.Rollback()) + } + }() + + bkt := tx.Bucket([]byte(bboltBucketSessions)) + if bkt == nil { + return nil + } + + removed, err := ds.processSessions(ctx, bkt) + if err != nil { + return fmt.Errorf("processing sessions: %w", err) + } + + if removed == 0 { + ds.logger.DebugContext(ctx, "loading sessions from db", "stored", len(ds.sessions)) + + return nil + } + + needRollback = false + err = tx.Commit() + if err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + ds.logger.DebugContext( + ctx, + "loading sessions from db", + "stored", len(ds.sessions), + "removed", removed, + ) + + return nil +} + +// processSessions iterates over the sessions bucket and loads or removes +// sessions as needed. +func (ds *DefaultSessionStorage) processSessions( + ctx context.Context, + bkt *bbolt.Bucket, +) (removed int, err error) { + invalidSessions := [][]byte{} + + err = bkt.ForEach(ds.bboltSessionHandler(ctx, &invalidSessions)) + if err != nil { + return 0, fmt.Errorf("iterating over sessions: %w", err) + } + + var errs []error + for _, s := range invalidSessions { + if err = bkt.Delete(s); err != nil { + errs = append(errs, err) + } + } + + if err = errors.Join(errs...); err != nil { + return 0, fmt.Errorf("deleting sessions: %w", err) + } + + return len(invalidSessions), nil +} + +// bboltSessionHandler returns a function for [bbolt.Bucket.ForEach] that +// iterates over stored sessions, deserializes them, and logs any errors +// encountered. The returned error is always nil, as these errors are +// considered non-critical to stop the iteration process. +func (ds *DefaultSessionStorage) bboltSessionHandler( + ctx context.Context, + invalidSessions *[][]byte, +) (fn func(k, v []byte) (err error)) { + now := ds.clock.Now() + + return func(k, v []byte) (err error) { + s, err := bboltDecode(v) + if err != nil { + *invalidSessions = append(*invalidSessions, k) + ds.logger.DebugContext(ctx, "deserializing session", slogutil.KeyError, err) + + return nil + } + + if now.After(s.Expire) { + *invalidSessions = append(*invalidSessions, k) + + return nil + } + + u, err := ds.userDB.ByLogin(ctx, s.UserLogin) + if err != nil { + // Should not happen, as it currently always returns nil for error. + panic(err) + } + + if u == nil { + *invalidSessions = append(*invalidSessions, k) + ds.logger.DebugContext(ctx, "no saved user by name", "name", s.UserLogin) + + return nil + } + + t := SessionToken(k) + s.Token = t + s.UserID = u.ID + ds.sessions[t] = s + + return nil + } +} + +// bboltBucketSessions is the name of the bucket storing web user sessions in +// the bbolt database. +const bboltBucketSessions = "sessions-2" + +const ( + // bboltSessionExpireLen is the length of the expire field in the binary + // entry stored in bbolt. + bboltSessionExpireLen = 4 + + // bboltSessionNameLen is the length of the name field in the binary entry + // stored in bbolt. + bboltSessionNameLen = 2 +) + +// bboltDecode deserializes decodes a binary data into a session. +func bboltDecode(data []byte) (s *Session, err error) { + if len(data) < bboltSessionExpireLen+bboltSessionNameLen { + return nil, fmt.Errorf("length of the data is less than expected: got %d", len(data)) + } + + expireData := data[:bboltSessionExpireLen] + nameLenData := data[bboltSessionExpireLen : bboltSessionExpireLen+bboltSessionNameLen] + nameData := data[bboltSessionExpireLen+bboltSessionNameLen:] + + nameLen := binary.BigEndian.Uint16(nameLenData) + if len(nameData) != int(nameLen) { + return nil, fmt.Errorf("login: expected length %d, got %d", nameLen, len(nameData)) + } + + expire := binary.BigEndian.Uint32(expireData) + + return &Session{ + Expire: time.Unix(int64(expire), 0), + UserLogin: Login(nameData), + }, nil +} + +// bboltEncode serializes a session properties into a binary data. +func bboltEncode(s *Session) (data []byte) { + data = make([]byte, bboltSessionExpireLen+bboltSessionNameLen+len(s.UserLogin)) + + expireData := data[:bboltSessionExpireLen] + nameLenData := data[bboltSessionExpireLen : bboltSessionExpireLen+bboltSessionNameLen] + nameData := data[bboltSessionExpireLen+bboltSessionNameLen:] + + expire := uint32(s.Expire.Unix()) + binary.BigEndian.PutUint32(expireData, expire) + binary.BigEndian.PutUint16(nameLenData, uint16(len(s.UserLogin))) + copy(nameData, []byte(s.UserLogin)) + + return data +} + +// type check +var _ SessionStorage = (*DefaultSessionStorage)(nil) + +// New implements the [SessionStorage] interface for *DefaultSessionStorage. +func (ds *DefaultSessionStorage) New(ctx context.Context, u *User) (s *Session, err error) { + s = &Session{ + Token: NewSessionToken(), + UserID: u.ID, + UserLogin: u.Login, + Expire: ds.clock.Now().Add(ds.sessionTTL), + } + + err = ds.store(s) + if err != nil { + return nil, fmt.Errorf("storing session: %w", err) + } + + ds.mu.Lock() + defer ds.mu.Unlock() + + ds.sessions[s.Token] = s + + return s, nil +} + +// store saves a web user session in the bbolt database. +func (ds *DefaultSessionStorage) store(s *Session) (err error) { + tx, err := ds.db.Begin(true) + if err != nil { + return fmt.Errorf("starting transaction: %w", err) + } + + needRollback := true + defer func() { + if needRollback { + err = errors.WithDeferred(err, tx.Rollback()) + } + }() + + bkt, err := tx.CreateBucketIfNotExists([]byte(bboltBucketSessions)) + if err != nil { + return fmt.Errorf("creating bucket: %w", err) + } + + err = bkt.Put(s.Token[:], bboltEncode(s)) + if err != nil { + return fmt.Errorf("putting data: %w", err) + } + + needRollback = false + err = tx.Commit() + if err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil +} + +// FindByToken implements the [SessionStorage] interface for *DefaultSessionStorage. +func (ds *DefaultSessionStorage) FindByToken(ctx context.Context, t SessionToken) (s *Session, err error) { + ds.mu.Lock() + defer ds.mu.Unlock() + + s, ok := ds.sessions[t] + if !ok { + return nil, nil + } + + now := ds.clock.Now() + if now.After(s.Expire) { + err = ds.deleteByToken(ctx, t) + if err != nil { + return nil, fmt.Errorf("expired session: %w", err) + } + + return nil, nil + } + + return s, nil +} + +// DeleteByToken implements the [SessionStorage] interface for +// *DefaultSessionStorage. +func (ds *DefaultSessionStorage) DeleteByToken(ctx context.Context, t SessionToken) (err error) { + ds.mu.Lock() + defer ds.mu.Unlock() + + // Don't wrap the error because it's informative enough as is. + return ds.deleteByToken(ctx, t) +} + +// deleteByToken removes stored session by token. ds.mu is expected to be +// locked. +func (ds *DefaultSessionStorage) deleteByToken(ctx context.Context, t SessionToken) (err error) { + err = ds.remove(ctx, t) + if err != nil { + ds.logger.ErrorContext(ctx, "deleting session", slogutil.KeyError, err) + + return err + } + + delete(ds.sessions, t) + + return nil +} + +// remove deletes a web user session from the bbolt database. +func (ds *DefaultSessionStorage) remove(ctx context.Context, t SessionToken) (err error) { + tx, err := ds.db.Begin(true) + if err != nil { + return fmt.Errorf("starting transaction: %w", err) + } + + needRollback := true + defer func() { + if needRollback { + err = errors.WithDeferred(err, tx.Rollback()) + } + }() + + bkt := tx.Bucket([]byte(bboltBucketSessions)) + if bkt == nil { + return errors.Error("no bucket") + } + + err = bkt.Delete(t[:]) + if err != nil { + return fmt.Errorf("removing data: %w", err) + } + + needRollback = false + err = tx.Commit() + if err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + ds.logger.DebugContext(ctx, "removed session from db") + + return err +} + +// Close implements the [SessionStorage] interface for *DefaultSessionStorage. +func (ds *DefaultSessionStorage) Close() (err error) { + err = ds.db.Close() + if err != nil { + return fmt.Errorf("closing db: %w", err) + } + + return nil +} diff --git a/internal/aghuser/sessionstorage_test.go b/internal/aghuser/sessionstorage_test.go new file mode 100644 index 00000000..f50f7586 --- /dev/null +++ b/internal/aghuser/sessionstorage_test.go @@ -0,0 +1,162 @@ +package aghuser_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghuser" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/golibs/testutil/faketime" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// addSession is a helper function that saves and returns a session for a newly +// generated [aghuser.User] by login. +func addSession( + tb testing.TB, + ctx context.Context, + ds aghuser.SessionStorage, + login aghuser.Login, +) (s *aghuser.Session) { + tb.Helper() + + s, err := ds.New(ctx, &aghuser.User{ + ID: aghuser.MustNewUserID(), + Login: login, + }) + require.NoError(tb, err) + require.NotNil(tb, s) + + var got *aghuser.Session + got, err = ds.FindByToken(ctx, s.Token) + require.NoError(tb, err) + require.NotNil(tb, got) + + assert.Equal(tb, login, got.UserLogin) + + return s +} + +func TestDefaultSessionStorage(t *testing.T) { + const ( + userLoginFirst aghuser.Login = "user_one" + userLoginSecond aghuser.Login = "user_two" + ) + + var ( + ctx = testutil.ContextWithTimeout(t, testTimeout) + logger = slogutil.NewDiscardLogger() + ) + + const ( + sessionTTL = time.Minute + timeStep = time.Second + ) + + // Set up a mock clock to test expired sessions. Each call to [clock.Now] + // will return the [date] incremented by [timeStep]. + date := time.Now() + clock := &faketime.Clock{ + OnNow: func() (now time.Time) { + date = date.Add(timeStep) + + return date + }, + } + + dbFile, err := os.CreateTemp(t.TempDir(), "sessions.db") + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, dbFile.Close) + + userDB := aghuser.NewDefaultDB() + + err = userDB.Create(ctx, &aghuser.User{ + Login: userLoginFirst, + ID: aghuser.MustNewUserID(), + }) + require.NoError(t, err) + + err = userDB.Create(ctx, &aghuser.User{ + Login: userLoginSecond, + ID: aghuser.MustNewUserID(), + }) + require.NoError(t, err) + + var ( + ds *aghuser.DefaultSessionStorage + + sessionFirst *aghuser.Session + sessionSecond *aghuser.Session + ) + + require.True(t, t.Run("prepare_session_storage", func(t *testing.T) { + ds, err = aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{ + Clock: clock, + UserDB: userDB, + Logger: logger, + DBPath: dbFile.Name(), + SessionTTL: sessionTTL, + }) + require.NoError(t, err) + + sessionFirst = addSession(t, ctx, ds, userLoginFirst) + + // Advance time to ensure the first session expires before creating the + // second session. + date = date.Add(time.Hour) + + sessionSecond = addSession(t, ctx, ds, userLoginSecond) + + err = ds.Close() + require.NoError(t, err) + })) + + require.True(t, t.Run("load_sessions", func(t *testing.T) { + ds, err = aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{ + Clock: clock, + UserDB: userDB, + Logger: logger, + DBPath: dbFile.Name(), + SessionTTL: sessionTTL, + }) + require.NoError(t, err) + + var got *aghuser.Session + got, err = ds.FindByToken(ctx, sessionFirst.Token) + require.NoError(t, err) + + assert.Nil(t, got) + + got, err = ds.FindByToken(ctx, sessionSecond.Token) + require.NoError(t, err) + require.NotNil(t, got) + + assert.Equal(t, userLoginSecond, got.UserLogin) + + err = ds.DeleteByToken(ctx, sessionSecond.Token) + require.NoError(t, err) + + got, err = ds.FindByToken(ctx, sessionSecond.Token) + require.NoError(t, err) + + assert.Nil(t, got) + })) + + require.True(t, t.Run("expired_session", func(t *testing.T) { + testutil.CleanupAndRequireSuccess(t, ds.Close) + + sessionFirst = addSession(t, ctx, ds, userLoginFirst) + + date = date.Add(time.Hour) + + var got *aghuser.Session + got, err = ds.FindByToken(ctx, sessionFirst.Token) + require.NoError(t, err) + + assert.Nil(t, got) + })) +} diff --git a/internal/aghuser/user.go b/internal/aghuser/user.go index 63dc4345..404c7ea0 100644 --- a/internal/aghuser/user.go +++ b/internal/aghuser/user.go @@ -32,13 +32,13 @@ func MustNewUserID() (uid UserID) { // User represents a web user. type User struct { - // ID is the unique identifier for the web user. It must not be empty. - ID UserID + // Password stores the password information for the web user. It must not + // be nil. + Password Password // Login is the login name of the web user. It must not be empty. Login Login - // Password stores the password information for the web user. It must not - // be nil. - Password Password + // ID is the unique identifier for the web user. It must not be empty. + ID UserID } diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index 9a4c5361..84a4e5b8 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -199,6 +199,7 @@ run_linter gocognit --over='10' \ ./internal/aghhttp/ \ ./internal/aghrenameio/ \ ./internal/aghtest/ \ + ./internal/aghuser/ \ ./internal/arpdb/ \ ./internal/client/ \ ./internal/configmigrate/ \ @@ -250,6 +251,7 @@ run_linter fieldalignment \ ./internal/aghrenameio/ \ ./internal/aghtest/ \ ./internal/aghtls/ \ + ./internal/aghuser/ \ ./internal/arpdb/ \ ./internal/client/ \ ./internal/configmigrate/ \ @@ -280,6 +282,7 @@ run_linter gosec --exclude G115 --quiet \ ./internal/aghos/ \ ./internal/aghrenameio/ \ ./internal/aghtest/ \ + ./internal/aghuser/ \ ./internal/arpdb/ \ ./internal/client/ \ ./internal/configmigrate/ \