Compare commits

...

2 Commits

Author SHA1 Message Date
Stanislav Chzhen
6109e3575f home: add tests 2025-05-13 22:23:41 +03:00
Stanislav Chzhen
88706e9cf2 home: auth tests 2025-05-07 15:46:22 +03:00
2 changed files with 346 additions and 1 deletions

View File

@@ -1,19 +1,355 @@
package home
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"net/textproto"
"net/url"
"path/filepath"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
)
// TODO(s.chzhen): !! Add more tests.
func TestAuth_ServeHTTP_first_run(t *testing.T) {
storeGlobals(t)
globalContext.firstRun = true
mux := http.NewServeMux()
globalContext.mux = mux
var (
logger = slogutil.NewDiscardLogger()
ctx = testutil.ContextWithTimeout(t, testTimeout)
err error
)
web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false)
require.NoError(t, err)
globalContext.web = web
testCases := []struct {
url string
method string
code int
}{{
url: "/",
method: http.MethodGet,
code: http.StatusFound,
}, {
url: "/apple/doh.mobileconfig",
method: http.MethodGet,
code: http.StatusFound,
}, {
url: "/apple/dot.mobileconfig",
method: http.MethodGet,
code: http.StatusFound,
}, {
url: "/control/i18n/change_language",
method: http.MethodGet,
code: http.StatusFound,
}, {
url: "/control/i18n/current_language",
method: http.MethodGet,
code: http.StatusFound,
}, {
url: "/control/install/check_config",
method: http.MethodPost,
code: http.StatusBadRequest,
}, {
url: "/control/install/configure",
method: http.MethodPost,
code: http.StatusBadRequest,
}, {
url: "/control/install/get_addresses",
method: http.MethodGet,
code: http.StatusOK,
}, {
url: "/control/login",
method: http.MethodPost,
code: http.StatusFound,
}, {
url: "/control/logout",
method: http.MethodGet,
code: http.StatusFound,
}, {
url: "/control/profile",
method: http.MethodGet,
code: http.StatusFound,
}, {
url: "/control/profile/update",
method: http.MethodGet,
code: http.StatusFound,
}, {
url: "/control/status",
method: http.MethodGet,
code: http.StatusFound,
}, {
url: "/control/update",
method: http.MethodGet,
code: http.StatusFound,
}, {
url: "/control/version.json",
method: http.MethodGet,
code: http.StatusFound,
}}
for _, tc := range testCases {
t.Run(tc.url, func(t *testing.T) {
r := httptest.NewRequest(tc.method, tc.url, nil)
h, pattern := mux.Handler(r)
require.NotEmpty(t, pattern)
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
assert.Equal(t, tc.code, w.Code)
})
}
}
func TestAuth_ServeHTTP(t *testing.T) {
storeGlobals(t)
const (
authNone = iota
authBasic
authCookie
)
const (
testTTL = 60
userName = "name"
userPassword = "password"
)
var (
logger = slogutil.NewDiscardLogger()
ctx = testutil.ContextWithTimeout(t, testTimeout)
err error
)
passwordHash, err := bcrypt.GenerateFromPassword([]byte(userPassword), bcrypt.DefaultCost)
require.NoError(t, err)
sessionsDB := filepath.Join(t.TempDir(), "sessions.db")
users := []webUser{{
Name: userName,
PasswordHash: string(passwordHash),
}}
auth := InitAuth(sessionsDB, users, testTTL, nil, nil)
globalContext.auth = auth
mux := http.NewServeMux()
globalContext.mux = mux
tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{
logger: logger,
configModified: func() {},
})
require.NoError(t, err)
web, err := initWeb(ctx, options{}, nil, nil, logger, tlsMgr, false)
require.NoError(t, err)
globalContext.web = web
creds, err := json.Marshal(&loginJSON{Name: userName, Password: userPassword})
require.NoError(t, err)
r := httptest.NewRequest(http.MethodPost, "/control/login", bytes.NewReader(creds))
r.Header.Set(httphdr.ContentType, aghhttp.HdrValApplicationJSON)
w := httptest.NewRecorder()
mux.ServeHTTP(w, r)
var loginCookie *http.Cookie
for _, c := range w.Result().Cookies() {
if c.Name == sessionCookieName {
loginCookie = c
}
}
require.NotNil(t, loginCookie)
testCases := []struct {
url string
method string
authMethod int
wantCode int
}{{
url: "/",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusFound,
}, {
url: "/control/i18n/change_language",
method: http.MethodPost,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/i18n/change_language",
method: http.MethodPost,
authMethod: authBasic,
wantCode: http.StatusInternalServerError,
}, {
url: "/control/i18n/change_language",
method: http.MethodPost,
authMethod: authCookie,
wantCode: http.StatusInternalServerError,
}, {
url: "/control/i18n/current_language",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/i18n/current_language",
method: http.MethodGet,
authMethod: authBasic,
wantCode: http.StatusOK,
}, {
url: "/control/i18n/current_language",
method: http.MethodGet,
authMethod: authCookie,
wantCode: http.StatusOK,
}, {
url: "/control/logout",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/logout",
method: http.MethodGet,
authMethod: authBasic,
wantCode: http.StatusFound,
}, {
url: "/control/profile",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/profile",
method: http.MethodGet,
authMethod: authBasic,
wantCode: http.StatusOK,
}, {
url: "/control/profile",
method: http.MethodGet,
authMethod: authCookie,
wantCode: http.StatusOK,
}, {
url: "/control/profile/update",
method: http.MethodPut,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/profile/update",
method: http.MethodPut,
authMethod: authBasic,
wantCode: http.StatusBadRequest,
}, {
url: "/control/profile/update",
method: http.MethodPut,
authMethod: authCookie,
wantCode: http.StatusBadRequest,
}, {
url: "/control/status",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/status",
method: http.MethodGet,
authMethod: authBasic,
wantCode: http.StatusOK,
}, {
url: "/control/status",
method: http.MethodGet,
authMethod: authCookie,
wantCode: http.StatusOK,
}, {
url: "/control/update",
method: http.MethodPost,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/version.json",
method: http.MethodGet,
authMethod: authNone,
wantCode: http.StatusForbidden,
}, {
url: "/control/version.json",
method: http.MethodGet,
authMethod: authBasic,
wantCode: http.StatusOK,
}, {
url: "/control/version.json",
method: http.MethodGet,
authMethod: authCookie,
wantCode: http.StatusOK,
}}
for _, tc := range testCases {
t.Run(tc.url, func(t *testing.T) {
r = httptest.NewRequest(tc.method, tc.url, nil)
switch tc.authMethod {
case authNone:
// Go on.
case authBasic:
r.SetBasicAuth(userName, userPassword)
case authCookie:
r.AddCookie(loginCookie)
default:
panic("unrecognized auth method")
}
h, pattern := mux.Handler(r)
require.NotEmpty(t, pattern)
w = httptest.NewRecorder()
h.ServeHTTP(w, r)
assert.Equal(t, tc.wantCode, w.Code)
})
}
t.Run("logout", func(t *testing.T) {
r = httptest.NewRequest(http.MethodGet, "/control/status", nil)
r.AddCookie(loginCookie)
w = httptest.NewRecorder()
mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
r = httptest.NewRequest(http.MethodGet, "/control/logout", nil)
r.AddCookie(loginCookie)
w = httptest.NewRecorder()
mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusFound, w.Code)
r = httptest.NewRequest(http.MethodGet, "/control/status", nil)
r.AddCookie(loginCookie)
w = httptest.NewRecorder()
mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusForbidden, w.Code)
})
}
// implements http.ResponseWriter
type testResponseWriter struct {
hdr http.Header

View File

@@ -113,10 +113,13 @@ func TestValidateCertificates(t *testing.T) {
// restores them once the test is complete.
//
// The global variables are:
// - [configuration.dns]
// - [configuration]
// - [homeContext.auth]
// - [homeContext.clients.storage]
// - [homeContext.dnsServer]
// - [homeContext.firstRun]
// - [homeContext.mux]
// - [homeContext.web]
//
// TODO(s.chzhen): Remove this once the TLS manager no longer accesses global
// variables. Make tests that use this helper concurrent.
@@ -124,15 +127,21 @@ func storeGlobals(tb testing.TB) {
tb.Helper()
prevConfig := config
auth := globalContext.auth
storage := globalContext.clients.storage
dnsServer := globalContext.dnsServer
firstRun := globalContext.firstRun
mux := globalContext.mux
web := globalContext.web
tb.Cleanup(func() {
config = prevConfig
globalContext.auth = auth
globalContext.clients.storage = storage
globalContext.dnsServer = dnsServer
globalContext.firstRun = firstRun
globalContext.mux = mux
globalContext.web = web
})
}