home: add tests

This commit is contained in:
Stanislav Chzhen
2025-05-13 22:23:41 +03:00
parent 88706e9cf2
commit 6109e3575f
2 changed files with 237 additions and 1 deletions

View File

@@ -1,6 +1,8 @@
package home
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
@@ -9,11 +11,13 @@ import (
"path/filepath"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
)
// TODO(s.chzhen): !! Add more tests.
@@ -74,7 +78,7 @@ func TestAuth_ServeHTTP_first_run(t *testing.T) {
code: http.StatusOK,
}, {
url: "/control/login",
method: http.MethodGet,
method: http.MethodPost,
code: http.StatusFound,
}, {
url: "/control/logout",
@@ -117,6 +121,235 @@ func TestAuth_ServeHTTP_first_run(t *testing.T) {
}
}
func TestAuth_ServeHTTP(t *testing.T) {
storeGlobals(t)
const (
authNone = iota
authBasic
authCookie
)
const (
testTTL = 60
userName = "name"
userPassword = "password"
)
var (
logger = slogutil.NewDiscardLogger()
ctx = testutil.ContextWithTimeout(t, testTimeout)
err error
)
passwordHash, err := bcrypt.GenerateFromPassword([]byte(userPassword), bcrypt.DefaultCost)
require.NoError(t, err)
sessionsDB := filepath.Join(t.TempDir(), "sessions.db")
users := []webUser{{
Name: userName,
PasswordHash: string(passwordHash),
}}
auth := InitAuth(sessionsDB, users, testTTL, nil, nil)
globalContext.auth = auth
mux := http.NewServeMux()
globalContext.mux = mux
tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{
logger: logger,
configModified: func() {},
})
require.NoError(t, err)
web, err := initWeb(ctx, options{}, nil, nil, logger, tlsMgr, false)
require.NoError(t, err)
globalContext.web = web
creds, err := json.Marshal(&loginJSON{Name: userName, Password: userPassword})
require.NoError(t, err)
r := httptest.NewRequest(http.MethodPost, "/control/login", bytes.NewReader(creds))
r.Header.Set(httphdr.ContentType, aghhttp.HdrValApplicationJSON)
w := httptest.NewRecorder()
mux.ServeHTTP(w, r)
var loginCookie *http.Cookie
for _, c := range w.Result().Cookies() {
if c.Name == sessionCookieName {
loginCookie = c
}
}
require.NotNil(t, loginCookie)
testCases := []struct {
url string
method string
authMethod int
wantCode int
}{{
url: "/",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusFound,
}, {
url: "/control/i18n/change_language",
method: http.MethodPost,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/i18n/change_language",
method: http.MethodPost,
authMethod: authBasic,
wantCode: http.StatusInternalServerError,
}, {
url: "/control/i18n/change_language",
method: http.MethodPost,
authMethod: authCookie,
wantCode: http.StatusInternalServerError,
}, {
url: "/control/i18n/current_language",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/i18n/current_language",
method: http.MethodGet,
authMethod: authBasic,
wantCode: http.StatusOK,
}, {
url: "/control/i18n/current_language",
method: http.MethodGet,
authMethod: authCookie,
wantCode: http.StatusOK,
}, {
url: "/control/logout",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/logout",
method: http.MethodGet,
authMethod: authBasic,
wantCode: http.StatusFound,
}, {
url: "/control/profile",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/profile",
method: http.MethodGet,
authMethod: authBasic,
wantCode: http.StatusOK,
}, {
url: "/control/profile",
method: http.MethodGet,
authMethod: authCookie,
wantCode: http.StatusOK,
}, {
url: "/control/profile/update",
method: http.MethodPut,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/profile/update",
method: http.MethodPut,
authMethod: authBasic,
wantCode: http.StatusBadRequest,
}, {
url: "/control/profile/update",
method: http.MethodPut,
authMethod: authCookie,
wantCode: http.StatusBadRequest,
}, {
url: "/control/status",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/status",
method: http.MethodGet,
authMethod: authBasic,
wantCode: http.StatusOK,
}, {
url: "/control/status",
method: http.MethodGet,
authMethod: authCookie,
wantCode: http.StatusOK,
}, {
url: "/control/update",
method: http.MethodPost,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/version.json",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/version.json",
method: http.MethodGet,
authMethod: authBasic,
wantCode: http.StatusOK,
}, {
url: "/control/version.json",
method: http.MethodGet,
authMethod: authCookie,
wantCode: http.StatusOK,
}}
for _, tc := range testCases {
t.Run(tc.url, func(t *testing.T) {
r = httptest.NewRequest(tc.method, tc.url, nil)
switch tc.authMethod {
case authNone:
// Go on.
case authBasic:
r.SetBasicAuth(userName, userPassword)
case authCookie:
r.AddCookie(loginCookie)
default:
panic("unrecognized auth method")
}
h, pattern := mux.Handler(r)
require.NotEmpty(t, pattern)
w = httptest.NewRecorder()
h.ServeHTTP(w, r)
assert.Equal(t, tc.wantCode, w.Code)
})
}
t.Run("logout", func(t *testing.T) {
r = httptest.NewRequest(http.MethodGet, "/control/status", nil)
r.AddCookie(loginCookie)
w = httptest.NewRecorder()
mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
r = httptest.NewRequest(http.MethodGet, "/control/logout", nil)
r.AddCookie(loginCookie)
w = httptest.NewRecorder()
mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusFound, w.Code)
r = httptest.NewRequest(http.MethodGet, "/control/status", nil)
r.AddCookie(loginCookie)
w = httptest.NewRecorder()
mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusForbidden, w.Code)
})
}
// implements http.ResponseWriter
type testResponseWriter struct {
hdr http.Header

View File

@@ -114,6 +114,7 @@ func TestValidateCertificates(t *testing.T) {
//
// The global variables are:
// - [configuration]
// - [homeContext.auth]
// - [homeContext.clients.storage]
// - [homeContext.dnsServer]
// - [homeContext.firstRun]
@@ -126,6 +127,7 @@ func storeGlobals(tb testing.TB) {
tb.Helper()
prevConfig := config
auth := globalContext.auth
storage := globalContext.clients.storage
dnsServer := globalContext.dnsServer
firstRun := globalContext.firstRun
@@ -134,6 +136,7 @@ func storeGlobals(tb testing.TB) {
tb.Cleanup(func() {
config = prevConfig
globalContext.auth = auth
globalContext.clients.storage = storage
globalContext.dnsServer = dnsServer
globalContext.firstRun = firstRun