all: session storage usage
This commit is contained in:
@@ -1,317 +1,131 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghuser"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"go.etcd.io/bbolt"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// sessionTokenSize is the length of session token in bytes.
|
||||
const sessionTokenSize = 16
|
||||
// webUser represents a user of the Web UI.
|
||||
type webUser struct {
|
||||
// Name represents the login name of the web user.
|
||||
Name string `yaml:"name"`
|
||||
|
||||
type session struct {
|
||||
userName string
|
||||
// expire is the expiration time, in seconds.
|
||||
expire uint32
|
||||
// PasswordHash is the hashed representation of the web user password.
|
||||
PasswordHash string `yaml:"password"`
|
||||
|
||||
// UserID is the unique identifier of the web user.
|
||||
//
|
||||
// TODO(s.chzhen): !! Use this.
|
||||
UserID aghuser.UserID `yaml:"-"`
|
||||
}
|
||||
|
||||
func (s *session) serialize() []byte {
|
||||
const (
|
||||
expireLen = 4
|
||||
nameLen = 2
|
||||
)
|
||||
data := make([]byte, expireLen+nameLen+len(s.userName))
|
||||
binary.BigEndian.PutUint32(data[0:4], s.expire)
|
||||
binary.BigEndian.PutUint16(data[4:6], uint16(len(s.userName)))
|
||||
copy(data[6:], []byte(s.userName))
|
||||
return data
|
||||
}
|
||||
|
||||
func (s *session) deserialize(data []byte) bool {
|
||||
if len(data) < 4+2 {
|
||||
return false
|
||||
// toUser returns the new properly initialized *aghuser.User using stored
|
||||
// properties. It panics if there is an error generating the user ID.
|
||||
func (wu *webUser) toUser() (u *aghuser.User) {
|
||||
uid := wu.UserID
|
||||
if uid == (aghuser.UserID{}) {
|
||||
uid = aghuser.MustNewUserID()
|
||||
}
|
||||
s.expire = binary.BigEndian.Uint32(data[0:4])
|
||||
nameLen := binary.BigEndian.Uint16(data[4:6])
|
||||
data = data[6:]
|
||||
|
||||
if len(data) < int(nameLen) {
|
||||
return false
|
||||
return &aghuser.User{
|
||||
Password: aghuser.NewDefaultPassword(wu.PasswordHash),
|
||||
Login: aghuser.Login(wu.Name),
|
||||
ID: uid,
|
||||
}
|
||||
s.userName = string(data)
|
||||
return true
|
||||
}
|
||||
|
||||
// Auth is the global authentication object.
|
||||
type Auth struct {
|
||||
trustedProxies netutil.SubnetSet
|
||||
db *bbolt.DB
|
||||
logger *slog.Logger
|
||||
rateLimiter *authRateLimiter
|
||||
sessions map[string]*session
|
||||
users []webUser
|
||||
lock sync.Mutex
|
||||
sessionTTL uint32
|
||||
sessions aghuser.SessionStorage
|
||||
trustedProxies netutil.SubnetSet
|
||||
users aghuser.DB
|
||||
}
|
||||
|
||||
// webUser represents a user of the Web UI.
|
||||
//
|
||||
// TODO(s.chzhen): Improve naming.
|
||||
type webUser struct {
|
||||
Name string `yaml:"name"`
|
||||
PasswordHash string `yaml:"password"`
|
||||
}
|
||||
|
||||
// InitAuth initializes the global authentication object.
|
||||
// InitAuth initializes the global authentication object. baseLogger,
|
||||
// rateLimiter, trustedProxies must not be nil. dbFilename and sessionTTL
|
||||
// should not be empty.
|
||||
func InitAuth(
|
||||
ctx context.Context,
|
||||
baseLogger *slog.Logger,
|
||||
dbFilename string,
|
||||
users []webUser,
|
||||
sessionTTL uint32,
|
||||
sessionTTL time.Duration,
|
||||
rateLimiter *authRateLimiter,
|
||||
trustedProxies netutil.SubnetSet,
|
||||
) (a *Auth) {
|
||||
log.Info("Initializing auth module: %s", dbFilename)
|
||||
|
||||
a = &Auth{
|
||||
sessionTTL: sessionTTL,
|
||||
rateLimiter: rateLimiter,
|
||||
sessions: make(map[string]*session),
|
||||
users: users,
|
||||
trustedProxies: trustedProxies,
|
||||
}
|
||||
var err error
|
||||
|
||||
a.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, nil)
|
||||
if err != nil {
|
||||
log.Error("auth: open DB: %s: %s", dbFilename, err)
|
||||
if err.Error() == "invalid argument" {
|
||||
log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations")
|
||||
) (a *Auth, err error) {
|
||||
userDB := aghuser.NewDefaultDB()
|
||||
for i, u := range users {
|
||||
err = userDB.Create(ctx, u.toUser())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("users: at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
a.loadSessions()
|
||||
log.Info("auth: initialized. users:%d sessions:%d", len(a.users), len(a.sessions))
|
||||
|
||||
return a
|
||||
s, err := aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{
|
||||
Logger: baseLogger.With(slogutil.KeyPrefix, "session_storage"),
|
||||
Clock: timeutil.SystemClock{},
|
||||
UserDB: aghuser.NewDefaultDB(),
|
||||
DBPath: dbFilename,
|
||||
SessionTTL: sessionTTL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating session storage: %w", err)
|
||||
}
|
||||
|
||||
return &Auth{
|
||||
logger: baseLogger.With(slogutil.KeyPrefix, "auth"),
|
||||
rateLimiter: rateLimiter,
|
||||
trustedProxies: trustedProxies,
|
||||
sessions: s,
|
||||
users: userDB,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the authentication database.
|
||||
func (a *Auth) Close() {
|
||||
_ = a.db.Close()
|
||||
}
|
||||
|
||||
func bucketName() []byte {
|
||||
return []byte("sessions-2")
|
||||
}
|
||||
|
||||
// loadSessions loads sessions from the database file and removes expired
|
||||
// sessions.
|
||||
func (a *Auth) loadSessions() {
|
||||
tx, err := a.db.Begin(true)
|
||||
func (a *Auth) Close(ctx context.Context) {
|
||||
err := a.sessions.Close()
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Begin: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt := tx.Bucket(bucketName())
|
||||
if bkt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
removed := 0
|
||||
|
||||
if tx.Bucket([]byte("sessions")) != nil {
|
||||
_ = tx.DeleteBucket([]byte("sessions"))
|
||||
removed = 1
|
||||
}
|
||||
|
||||
now := uint32(time.Now().UTC().Unix())
|
||||
forEach := func(k, v []byte) error {
|
||||
s := session{}
|
||||
if !s.deserialize(v) || s.expire <= now {
|
||||
err = bkt.Delete(k)
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Delete: %s", err)
|
||||
} else {
|
||||
removed++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
a.sessions[hex.EncodeToString(k)] = &s
|
||||
return nil
|
||||
}
|
||||
_ = bkt.ForEach(forEach)
|
||||
if removed != 0 {
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Error("bolt.Commit(): %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("auth: loaded %d sessions from DB (removed %d expired)", len(a.sessions), removed)
|
||||
}
|
||||
|
||||
// addSession adds a new session to the list of sessions and saves it in the
|
||||
// database file.
|
||||
func (a *Auth) addSession(data []byte, s *session) {
|
||||
name := hex.EncodeToString(data)
|
||||
a.lock.Lock()
|
||||
a.sessions[name] = s
|
||||
a.lock.Unlock()
|
||||
if a.storeSession(data, s) {
|
||||
log.Debug("auth: created session %s: expire=%d", name, s.expire)
|
||||
a.logger.ErrorContext(ctx, "closing session storage", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
// storeSession saves a session in the database file.
|
||||
func (a *Auth) storeSession(data []byte, s *session) bool {
|
||||
tx, err := a.db.Begin(true)
|
||||
// isValidSession returns true if the session is valid.
|
||||
func (a *Auth) isValidSession(ctx context.Context, cookieSess string) (ok bool) {
|
||||
sess, err := hex.DecodeString(cookieSess)
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Begin: %s", err)
|
||||
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt, err := tx.CreateBucketIfNotExists(bucketName())
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.CreateBucketIfNotExists: %s", err)
|
||||
a.logger.ErrorContext(ctx, "checking session: decoding cookie", slogutil.KeyError, err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
err = bkt.Put(data, s.serialize())
|
||||
var t aghuser.SessionToken
|
||||
copy(t[:], sess)
|
||||
|
||||
s, err := a.sessions.FindByToken(ctx, t)
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Put: %s", err)
|
||||
a.logger.ErrorContext(ctx, "checking session", slogutil.KeyError, err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Commit: %s", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return s != nil
|
||||
}
|
||||
|
||||
// removeSessionFromFile removes a stored session from the DB file on disk.
|
||||
func (a *Auth) removeSessionFromFile(sess []byte) {
|
||||
tx, err := a.db.Begin(true)
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Begin: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt := tx.Bucket(bucketName())
|
||||
if bkt == nil {
|
||||
log.Error("auth: bbolt.Bucket")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = bkt.Delete(sess)
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Put: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Commit: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("auth: removed session from DB")
|
||||
}
|
||||
|
||||
// checkSessionResult is the result of checking a session.
|
||||
type checkSessionResult int
|
||||
|
||||
// checkSessionResult constants.
|
||||
const (
|
||||
checkSessionOK checkSessionResult = 0
|
||||
checkSessionNotFound checkSessionResult = -1
|
||||
checkSessionExpired checkSessionResult = 1
|
||||
)
|
||||
|
||||
// checkSession checks if the session is valid.
|
||||
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
|
||||
now := uint32(time.Now().UTC().Unix())
|
||||
update := false
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
s, ok := a.sessions[sess]
|
||||
if !ok {
|
||||
return checkSessionNotFound
|
||||
}
|
||||
|
||||
if s.expire <= now {
|
||||
delete(a.sessions, sess)
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.removeSessionFromFile(key)
|
||||
|
||||
return checkSessionExpired
|
||||
}
|
||||
|
||||
newExpire := now + a.sessionTTL
|
||||
if s.expire/(24*60*60) != newExpire/(24*60*60) {
|
||||
// update expiration time once a day
|
||||
update = true
|
||||
s.expire = newExpire
|
||||
}
|
||||
|
||||
if update {
|
||||
key, _ := hex.DecodeString(sess)
|
||||
if a.storeSession(key, s) {
|
||||
log.Debug("auth: updated session %s: expire=%d", sess, s.expire)
|
||||
}
|
||||
}
|
||||
|
||||
return checkSessionOK
|
||||
}
|
||||
|
||||
// removeSession removes the session from the active sessions and the disk.
|
||||
func (a *Auth) removeSession(sess string) {
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.lock.Lock()
|
||||
delete(a.sessions, sess)
|
||||
a.lock.Unlock()
|
||||
a.removeSessionFromFile(key)
|
||||
}
|
||||
|
||||
// addUser adds a new user with the given password.
|
||||
func (a *Auth) addUser(u *webUser, password string) (err error) {
|
||||
// addUser adds a new user with the given password. u must not be nil.
|
||||
func (a *Auth) addUser(ctx context.Context, u *webUser, password string) (err error) {
|
||||
if len(password) == 0 {
|
||||
return errors.Error("empty password")
|
||||
}
|
||||
@@ -323,97 +137,129 @@ func (a *Auth) addUser(u *webUser, password string) (err error) {
|
||||
|
||||
u.PasswordHash = string(hash)
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
err = a.users.Create(ctx, u.toUser())
|
||||
if err != nil {
|
||||
// Should not happen.
|
||||
panic(err)
|
||||
}
|
||||
|
||||
a.users = append(a.users, *u)
|
||||
|
||||
log.Debug("auth: added user with login %q", u.Name)
|
||||
a.logger.DebugContext(ctx, "added user", "login", u.Name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findUser returns a user if there is one.
|
||||
func (a *Auth) findUser(login, password string) (u webUser, ok bool) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
for _, u = range a.users {
|
||||
if u.Name == login &&
|
||||
bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil {
|
||||
return u, true
|
||||
}
|
||||
// findUser returns a user if one exists with the provided login and the
|
||||
// password matches.
|
||||
func (a *Auth) findUser(ctx context.Context, login, password string) (user *aghuser.User) {
|
||||
user, err := a.users.ByLogin(ctx, aghuser.Login(login))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return webUser{}, false
|
||||
ok := user.Password.Authenticate(ctx, password)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
// getCurrentUser returns the current user. It returns an empty User if the
|
||||
// user is not found.
|
||||
func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
|
||||
// getCurrentUser searches for a user using a cookie or credentials from basic
|
||||
// authentication.
|
||||
func (a *Auth) getCurrentUser(r *http.Request) (user *aghuser.User) {
|
||||
ctx := r.Context()
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// There's no Cookie, check Basic authentication.
|
||||
user, pass, ok := r.BasicAuth()
|
||||
if ok {
|
||||
u, _ = globalContext.auth.findUser(user, pass)
|
||||
|
||||
return u
|
||||
return a.findUser(ctx, user, pass)
|
||||
}
|
||||
|
||||
return webUser{}
|
||||
return nil
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
sess, err := hex.DecodeString(cookie.Value)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(
|
||||
ctx,
|
||||
"searching for user: decoding cookie value",
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
s, ok := a.sessions[cookie.Value]
|
||||
if !ok {
|
||||
return webUser{}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, u = range a.users {
|
||||
if u.Name == s.userName {
|
||||
return u
|
||||
}
|
||||
var t aghuser.SessionToken
|
||||
copy(t[:], sess)
|
||||
|
||||
s, err := a.sessions.FindByToken(ctx, t)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, "searching for user", slogutil.KeyError, err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return webUser{}
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &aghuser.User{
|
||||
Login: s.UserLogin,
|
||||
ID: s.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
// removeSession deletes the session from the active sessions and the disk. It
|
||||
// also logs any occurring errors.
|
||||
func (a *Auth) removeSession(ctx context.Context, cookieSess string) {
|
||||
sess, err := hex.DecodeString(cookieSess)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, "removing session: decoding cookie", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var t aghuser.SessionToken
|
||||
copy(t[:], sess)
|
||||
|
||||
err = a.sessions.DeleteByToken(ctx, t)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, "removing session by token", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
// usersList returns a copy of a users list.
|
||||
func (a *Auth) usersList() (users []webUser) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
func (a *Auth) usersList(ctx context.Context) (webUsers []webUser) {
|
||||
users, err := a.users.All(ctx)
|
||||
if err != nil {
|
||||
// Should not happen.
|
||||
panic(err)
|
||||
}
|
||||
|
||||
users = make([]webUser, len(a.users))
|
||||
copy(users, a.users)
|
||||
webUsers = make([]webUser, 0, len(users))
|
||||
for _, u := range users {
|
||||
webUsers = append(webUsers, webUser{
|
||||
Name: string(u.Login),
|
||||
PasswordHash: string(u.Password.Hash()),
|
||||
UserID: u.ID,
|
||||
})
|
||||
}
|
||||
|
||||
return users
|
||||
return webUsers
|
||||
}
|
||||
|
||||
// authRequired returns true if a authentication is required.
|
||||
func (a *Auth) authRequired() bool {
|
||||
func (a *Auth) authRequired(ctx context.Context) (ok bool) {
|
||||
if GLMode {
|
||||
return true
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
users, err := a.users.All(ctx)
|
||||
if err != nil {
|
||||
// Should not happen.
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return len(a.users) != 0
|
||||
}
|
||||
|
||||
// newSessionToken returns cryptographically secure randomly generated slice of
|
||||
// bytes of sessionTokenSize length.
|
||||
//
|
||||
// TODO(e.burkov): Think about using byte array instead of byte slice.
|
||||
func newSessionToken() (data []byte) {
|
||||
randData := make([]byte, sessionTokenSize)
|
||||
|
||||
// Since Go 1.24, crypto/rand.Read doesn't return an error and crashes
|
||||
// unrecoverably instead.
|
||||
_, _ = rand.Read(randData)
|
||||
|
||||
return randData
|
||||
return len(users) != 0
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user