diff --git a/internal/home/authhttp_internal_test.go b/internal/home/authhttp_internal_test.go index b2217cbd..3385a07a 100644 --- a/internal/home/authhttp_internal_test.go +++ b/internal/home/authhttp_internal_test.go @@ -1,6 +1,8 @@ package home import ( + "bytes" + "encoding/json" "net/http" "net/http/httptest" "net/netip" @@ -9,11 +11,13 @@ import ( "path/filepath" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" ) // TODO(s.chzhen): !! Add more tests. @@ -74,7 +78,7 @@ func TestAuth_ServeHTTP_first_run(t *testing.T) { code: http.StatusOK, }, { url: "/control/login", - method: http.MethodGet, + method: http.MethodPost, code: http.StatusFound, }, { url: "/control/logout", @@ -117,6 +121,235 @@ func TestAuth_ServeHTTP_first_run(t *testing.T) { } } +func TestAuth_ServeHTTP(t *testing.T) { + storeGlobals(t) + + const ( + authNone = iota + authBasic + authCookie + ) + + const ( + testTTL = 60 + userName = "name" + userPassword = "password" + ) + + var ( + logger = slogutil.NewDiscardLogger() + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + passwordHash, err := bcrypt.GenerateFromPassword([]byte(userPassword), bcrypt.DefaultCost) + require.NoError(t, err) + + sessionsDB := filepath.Join(t.TempDir(), "sessions.db") + + users := []webUser{{ + Name: userName, + PasswordHash: string(passwordHash), + }} + auth := InitAuth(sessionsDB, users, testTTL, nil, nil) + globalContext.auth = auth + + mux := http.NewServeMux() + globalContext.mux = mux + + tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: logger, + configModified: func() {}, + }) + require.NoError(t, err) + + web, err := initWeb(ctx, options{}, nil, nil, logger, tlsMgr, false) + require.NoError(t, err) + + globalContext.web = web + + creds, err := json.Marshal(&loginJSON{Name: userName, Password: userPassword}) + require.NoError(t, err) + + r := httptest.NewRequest(http.MethodPost, "/control/login", bytes.NewReader(creds)) + r.Header.Set(httphdr.ContentType, aghhttp.HdrValApplicationJSON) + + w := httptest.NewRecorder() + mux.ServeHTTP(w, r) + + var loginCookie *http.Cookie + for _, c := range w.Result().Cookies() { + if c.Name == sessionCookieName { + loginCookie = c + } + } + require.NotNil(t, loginCookie) + + testCases := []struct { + url string + method string + authMethod int + wantCode int + }{{ + url: "/", + method: http.MethodGet, + authMethod: authNone, + wantCode: http.StatusFound, + }, { + url: "/control/i18n/change_language", + method: http.MethodPost, + authMethod: authNone, + wantCode: http.StatusForbidden, + }, { + url: "/control/i18n/change_language", + method: http.MethodPost, + authMethod: authBasic, + wantCode: http.StatusInternalServerError, + }, { + url: "/control/i18n/change_language", + method: http.MethodPost, + authMethod: authCookie, + wantCode: http.StatusInternalServerError, + }, { + url: "/control/i18n/current_language", + method: http.MethodGet, + authMethod: authNone, + wantCode: http.StatusForbidden, + }, { + url: "/control/i18n/current_language", + method: http.MethodGet, + authMethod: authBasic, + wantCode: http.StatusOK, + }, { + url: "/control/i18n/current_language", + method: http.MethodGet, + authMethod: authCookie, + wantCode: http.StatusOK, + }, { + url: "/control/logout", + method: http.MethodGet, + authMethod: authNone, + wantCode: http.StatusForbidden, + }, { + url: "/control/logout", + method: http.MethodGet, + authMethod: authBasic, + wantCode: http.StatusFound, + }, { + url: "/control/profile", + method: http.MethodGet, + authMethod: authNone, + wantCode: http.StatusForbidden, + }, { + url: "/control/profile", + method: http.MethodGet, + authMethod: authBasic, + wantCode: http.StatusOK, + }, { + url: "/control/profile", + method: http.MethodGet, + authMethod: authCookie, + wantCode: http.StatusOK, + }, { + url: "/control/profile/update", + method: http.MethodPut, + authMethod: authNone, + wantCode: http.StatusForbidden, + }, { + url: "/control/profile/update", + method: http.MethodPut, + authMethod: authBasic, + wantCode: http.StatusBadRequest, + }, { + url: "/control/profile/update", + method: http.MethodPut, + authMethod: authCookie, + wantCode: http.StatusBadRequest, + }, { + url: "/control/status", + method: http.MethodGet, + authMethod: authNone, + wantCode: http.StatusForbidden, + }, { + url: "/control/status", + method: http.MethodGet, + authMethod: authBasic, + wantCode: http.StatusOK, + }, { + url: "/control/status", + method: http.MethodGet, + authMethod: authCookie, + wantCode: http.StatusOK, + }, { + url: "/control/update", + method: http.MethodPost, + authMethod: authNone, + wantCode: http.StatusForbidden, + }, { + url: "/control/version.json", + method: http.MethodGet, + authMethod: authNone, + wantCode: http.StatusForbidden, + }, { + url: "/control/version.json", + method: http.MethodGet, + authMethod: authBasic, + wantCode: http.StatusOK, + }, { + url: "/control/version.json", + method: http.MethodGet, + authMethod: authCookie, + wantCode: http.StatusOK, + }} + + for _, tc := range testCases { + t.Run(tc.url, func(t *testing.T) { + r = httptest.NewRequest(tc.method, tc.url, nil) + switch tc.authMethod { + case authNone: + // Go on. + case authBasic: + r.SetBasicAuth(userName, userPassword) + case authCookie: + r.AddCookie(loginCookie) + default: + panic("unrecognized auth method") + } + + h, pattern := mux.Handler(r) + require.NotEmpty(t, pattern) + + w = httptest.NewRecorder() + h.ServeHTTP(w, r) + + assert.Equal(t, tc.wantCode, w.Code) + }) + } + + t.Run("logout", func(t *testing.T) { + r = httptest.NewRequest(http.MethodGet, "/control/status", nil) + r.AddCookie(loginCookie) + w = httptest.NewRecorder() + + mux.ServeHTTP(w, r) + assert.Equal(t, http.StatusOK, w.Code) + + r = httptest.NewRequest(http.MethodGet, "/control/logout", nil) + r.AddCookie(loginCookie) + w = httptest.NewRecorder() + + mux.ServeHTTP(w, r) + assert.Equal(t, http.StatusFound, w.Code) + + r = httptest.NewRequest(http.MethodGet, "/control/status", nil) + r.AddCookie(loginCookie) + w = httptest.NewRecorder() + + mux.ServeHTTP(w, r) + assert.Equal(t, http.StatusForbidden, w.Code) + }) +} + // implements http.ResponseWriter type testResponseWriter struct { hdr http.Header diff --git a/internal/home/tls_internal_test.go b/internal/home/tls_internal_test.go index 1c239c81..a63efd0f 100644 --- a/internal/home/tls_internal_test.go +++ b/internal/home/tls_internal_test.go @@ -114,6 +114,7 @@ func TestValidateCertificates(t *testing.T) { // // The global variables are: // - [configuration] +// - [homeContext.auth] // - [homeContext.clients.storage] // - [homeContext.dnsServer] // - [homeContext.firstRun] @@ -126,6 +127,7 @@ func storeGlobals(tb testing.TB) { tb.Helper() prevConfig := config + auth := globalContext.auth storage := globalContext.clients.storage dnsServer := globalContext.dnsServer firstRun := globalContext.firstRun @@ -134,6 +136,7 @@ func storeGlobals(tb testing.TB) { tb.Cleanup(func() { config = prevConfig + globalContext.auth = auth globalContext.clients.storage = storage globalContext.dnsServer = dnsServer globalContext.firstRun = firstRun