home: add tests
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
@@ -9,11 +11,13 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// TODO(s.chzhen): !! Add more tests.
|
||||
@@ -74,7 +78,7 @@ func TestAuth_ServeHTTP_first_run(t *testing.T) {
|
||||
code: http.StatusOK,
|
||||
}, {
|
||||
url: "/control/login",
|
||||
method: http.MethodGet,
|
||||
method: http.MethodPost,
|
||||
code: http.StatusFound,
|
||||
}, {
|
||||
url: "/control/logout",
|
||||
@@ -117,6 +121,235 @@ func TestAuth_ServeHTTP_first_run(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth_ServeHTTP(t *testing.T) {
|
||||
storeGlobals(t)
|
||||
|
||||
const (
|
||||
authNone = iota
|
||||
authBasic
|
||||
authCookie
|
||||
)
|
||||
|
||||
const (
|
||||
testTTL = 60
|
||||
userName = "name"
|
||||
userPassword = "password"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = slogutil.NewDiscardLogger()
|
||||
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
||||
err error
|
||||
)
|
||||
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(userPassword), bcrypt.DefaultCost)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionsDB := filepath.Join(t.TempDir(), "sessions.db")
|
||||
|
||||
users := []webUser{{
|
||||
Name: userName,
|
||||
PasswordHash: string(passwordHash),
|
||||
}}
|
||||
auth := InitAuth(sessionsDB, users, testTTL, nil, nil)
|
||||
globalContext.auth = auth
|
||||
|
||||
mux := http.NewServeMux()
|
||||
globalContext.mux = mux
|
||||
|
||||
tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: logger,
|
||||
configModified: func() {},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
web, err := initWeb(ctx, options{}, nil, nil, logger, tlsMgr, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
globalContext.web = web
|
||||
|
||||
creds, err := json.Marshal(&loginJSON{Name: userName, Password: userPassword})
|
||||
require.NoError(t, err)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/control/login", bytes.NewReader(creds))
|
||||
r.Header.Set(httphdr.ContentType, aghhttp.HdrValApplicationJSON)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, r)
|
||||
|
||||
var loginCookie *http.Cookie
|
||||
for _, c := range w.Result().Cookies() {
|
||||
if c.Name == sessionCookieName {
|
||||
loginCookie = c
|
||||
}
|
||||
}
|
||||
require.NotNil(t, loginCookie)
|
||||
|
||||
testCases := []struct {
|
||||
url string
|
||||
method string
|
||||
authMethod int
|
||||
wantCode int
|
||||
}{{
|
||||
url: "/",
|
||||
method: http.MethodGet,
|
||||
authMethod: authNone,
|
||||
wantCode: http.StatusFound,
|
||||
}, {
|
||||
url: "/control/i18n/change_language",
|
||||
method: http.MethodPost,
|
||||
authMethod: authNone,
|
||||
wantCode: http.StatusForbidden,
|
||||
}, {
|
||||
url: "/control/i18n/change_language",
|
||||
method: http.MethodPost,
|
||||
authMethod: authBasic,
|
||||
wantCode: http.StatusInternalServerError,
|
||||
}, {
|
||||
url: "/control/i18n/change_language",
|
||||
method: http.MethodPost,
|
||||
authMethod: authCookie,
|
||||
wantCode: http.StatusInternalServerError,
|
||||
}, {
|
||||
url: "/control/i18n/current_language",
|
||||
method: http.MethodGet,
|
||||
authMethod: authNone,
|
||||
wantCode: http.StatusForbidden,
|
||||
}, {
|
||||
url: "/control/i18n/current_language",
|
||||
method: http.MethodGet,
|
||||
authMethod: authBasic,
|
||||
wantCode: http.StatusOK,
|
||||
}, {
|
||||
url: "/control/i18n/current_language",
|
||||
method: http.MethodGet,
|
||||
authMethod: authCookie,
|
||||
wantCode: http.StatusOK,
|
||||
}, {
|
||||
url: "/control/logout",
|
||||
method: http.MethodGet,
|
||||
authMethod: authNone,
|
||||
wantCode: http.StatusForbidden,
|
||||
}, {
|
||||
url: "/control/logout",
|
||||
method: http.MethodGet,
|
||||
authMethod: authBasic,
|
||||
wantCode: http.StatusFound,
|
||||
}, {
|
||||
url: "/control/profile",
|
||||
method: http.MethodGet,
|
||||
authMethod: authNone,
|
||||
wantCode: http.StatusForbidden,
|
||||
}, {
|
||||
url: "/control/profile",
|
||||
method: http.MethodGet,
|
||||
authMethod: authBasic,
|
||||
wantCode: http.StatusOK,
|
||||
}, {
|
||||
url: "/control/profile",
|
||||
method: http.MethodGet,
|
||||
authMethod: authCookie,
|
||||
wantCode: http.StatusOK,
|
||||
}, {
|
||||
url: "/control/profile/update",
|
||||
method: http.MethodPut,
|
||||
authMethod: authNone,
|
||||
wantCode: http.StatusForbidden,
|
||||
}, {
|
||||
url: "/control/profile/update",
|
||||
method: http.MethodPut,
|
||||
authMethod: authBasic,
|
||||
wantCode: http.StatusBadRequest,
|
||||
}, {
|
||||
url: "/control/profile/update",
|
||||
method: http.MethodPut,
|
||||
authMethod: authCookie,
|
||||
wantCode: http.StatusBadRequest,
|
||||
}, {
|
||||
url: "/control/status",
|
||||
method: http.MethodGet,
|
||||
authMethod: authNone,
|
||||
wantCode: http.StatusForbidden,
|
||||
}, {
|
||||
url: "/control/status",
|
||||
method: http.MethodGet,
|
||||
authMethod: authBasic,
|
||||
wantCode: http.StatusOK,
|
||||
}, {
|
||||
url: "/control/status",
|
||||
method: http.MethodGet,
|
||||
authMethod: authCookie,
|
||||
wantCode: http.StatusOK,
|
||||
}, {
|
||||
url: "/control/update",
|
||||
method: http.MethodPost,
|
||||
authMethod: authNone,
|
||||
wantCode: http.StatusForbidden,
|
||||
}, {
|
||||
url: "/control/version.json",
|
||||
method: http.MethodGet,
|
||||
authMethod: authNone,
|
||||
wantCode: http.StatusForbidden,
|
||||
}, {
|
||||
url: "/control/version.json",
|
||||
method: http.MethodGet,
|
||||
authMethod: authBasic,
|
||||
wantCode: http.StatusOK,
|
||||
}, {
|
||||
url: "/control/version.json",
|
||||
method: http.MethodGet,
|
||||
authMethod: authCookie,
|
||||
wantCode: http.StatusOK,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.url, func(t *testing.T) {
|
||||
r = httptest.NewRequest(tc.method, tc.url, nil)
|
||||
switch tc.authMethod {
|
||||
case authNone:
|
||||
// Go on.
|
||||
case authBasic:
|
||||
r.SetBasicAuth(userName, userPassword)
|
||||
case authCookie:
|
||||
r.AddCookie(loginCookie)
|
||||
default:
|
||||
panic("unrecognized auth method")
|
||||
}
|
||||
|
||||
h, pattern := mux.Handler(r)
|
||||
require.NotEmpty(t, pattern)
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, tc.wantCode, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("logout", func(t *testing.T) {
|
||||
r = httptest.NewRequest(http.MethodGet, "/control/status", nil)
|
||||
r.AddCookie(loginCookie)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
mux.ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
r = httptest.NewRequest(http.MethodGet, "/control/logout", nil)
|
||||
r.AddCookie(loginCookie)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
mux.ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusFound, w.Code)
|
||||
|
||||
r = httptest.NewRequest(http.MethodGet, "/control/status", nil)
|
||||
r.AddCookie(loginCookie)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
mux.ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
// implements http.ResponseWriter
|
||||
type testResponseWriter struct {
|
||||
hdr http.Header
|
||||
|
||||
@@ -114,6 +114,7 @@ func TestValidateCertificates(t *testing.T) {
|
||||
//
|
||||
// The global variables are:
|
||||
// - [configuration]
|
||||
// - [homeContext.auth]
|
||||
// - [homeContext.clients.storage]
|
||||
// - [homeContext.dnsServer]
|
||||
// - [homeContext.firstRun]
|
||||
@@ -126,6 +127,7 @@ func storeGlobals(tb testing.TB) {
|
||||
tb.Helper()
|
||||
|
||||
prevConfig := config
|
||||
auth := globalContext.auth
|
||||
storage := globalContext.clients.storage
|
||||
dnsServer := globalContext.dnsServer
|
||||
firstRun := globalContext.firstRun
|
||||
@@ -134,6 +136,7 @@ func storeGlobals(tb testing.TB) {
|
||||
|
||||
tb.Cleanup(func() {
|
||||
config = prevConfig
|
||||
globalContext.auth = auth
|
||||
globalContext.clients.storage = storage
|
||||
globalContext.dnsServer = dnsServer
|
||||
globalContext.firstRun = firstRun
|
||||
|
||||
Reference in New Issue
Block a user