Files
AdGuardHome/internal/home/authhttp_internal_test.go
Stanislav Chzhen 6109e3575f home: add tests
2025-05-13 22:23:41 +03:00

512 lines
12 KiB
Go

package home
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"net/textproto"
"net/url"
"path/filepath"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
)
// TODO(s.chzhen): !! Add more tests.

// TestAuth_ServeHTTP_first_run checks the status codes returned for a set of
// web and API paths while globalContext.firstRun is set.  Every request is
// served by the handler that the shared mux resolves for it, and the test
// requires that some pattern matches.
//
// NOTE(review): the handlers are presumably registered on globalContext.mux by
// initWeb; confirm against its implementation.
func TestAuth_ServeHTTP_first_run(t *testing.T) {
	storeGlobals(t)

	globalContext.firstRun = true

	mux := http.NewServeMux()
	globalContext.mux = mux

	l := slogutil.NewDiscardLogger()
	ctx := testutil.ContextWithTimeout(t, testTimeout)

	web, err := initWeb(ctx, options{}, nil, nil, l, nil, false)
	require.NoError(t, err)

	globalContext.web = web

	// Each case is a request path, the method to use, and the expected
	// response status.
	type testCase struct {
		path     string
		method   string
		wantCode int
	}

	cases := []testCase{
		{path: "/", method: http.MethodGet, wantCode: http.StatusFound},
		{path: "/apple/doh.mobileconfig", method: http.MethodGet, wantCode: http.StatusFound},
		{path: "/apple/dot.mobileconfig", method: http.MethodGet, wantCode: http.StatusFound},
		{path: "/control/i18n/change_language", method: http.MethodGet, wantCode: http.StatusFound},
		{path: "/control/i18n/current_language", method: http.MethodGet, wantCode: http.StatusFound},
		{path: "/control/install/check_config", method: http.MethodPost, wantCode: http.StatusBadRequest},
		{path: "/control/install/configure", method: http.MethodPost, wantCode: http.StatusBadRequest},
		{path: "/control/install/get_addresses", method: http.MethodGet, wantCode: http.StatusOK},
		{path: "/control/login", method: http.MethodPost, wantCode: http.StatusFound},
		{path: "/control/logout", method: http.MethodGet, wantCode: http.StatusFound},
		{path: "/control/profile", method: http.MethodGet, wantCode: http.StatusFound},
		{path: "/control/profile/update", method: http.MethodGet, wantCode: http.StatusFound},
		{path: "/control/status", method: http.MethodGet, wantCode: http.StatusFound},
		{path: "/control/update", method: http.MethodGet, wantCode: http.StatusFound},
		{path: "/control/version.json", method: http.MethodGet, wantCode: http.StatusFound},
	}

	for _, tc := range cases {
		t.Run(tc.path, func(t *testing.T) {
			req := httptest.NewRequest(tc.method, tc.path, nil)

			// Resolve the handler explicitly so that an unregistered path
			// fails the test instead of silently hitting a fallback.
			handler, matched := mux.Handler(req)
			require.NotEmpty(t, matched)

			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, req)

			assert.Equal(t, tc.wantCode, rec.Code)
		})
	}
}
// TestAuth_ServeHTTP checks the status codes returned for web and API paths
// when a request carries no credentials, HTTP basic credentials, or the
// session cookie obtained by logging in through /control/login.  The final
// "logout" subtest verifies that the cookie stops being accepted once
// /control/logout has been requested with it.
func TestAuth_ServeHTTP(t *testing.T) {
	storeGlobals(t)

	const (
		authNone = iota
		authBasic
		authCookie
	)

	const (
		testTTL      = 60
		userName     = "name"
		userPassword = "password"
	)

	var (
		logger = slogutil.NewDiscardLogger()
		ctx    = testutil.ContextWithTimeout(t, testTimeout)
		err    error
	)

	passwordHash, err := bcrypt.GenerateFromPassword([]byte(userPassword), bcrypt.DefaultCost)
	require.NoError(t, err)

	sessionsDB := filepath.Join(t.TempDir(), "sessions.db")
	users := []webUser{{
		Name:         userName,
		PasswordHash: string(passwordHash),
	}}

	auth := InitAuth(sessionsDB, users, testTTL, nil, nil)
	require.NotNil(t, auth)

	// Close auth when the test ends.  Cleanups run in last-in-first-out
	// order, so this runs before the directory created by t.TempDir above is
	// removed, and nothing created for sessionsDB is still held open then.
	t.Cleanup(func() { auth.Close() })

	globalContext.auth = auth

	mux := http.NewServeMux()
	globalContext.mux = mux

	tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{
		logger:         logger,
		configModified: func() {},
	})
	require.NoError(t, err)

	web, err := initWeb(ctx, options{}, nil, nil, logger, tlsMgr, false)
	require.NoError(t, err)

	globalContext.web = web

	// Log in once and keep the session cookie for the cookie-based cases.
	creds, err := json.Marshal(&loginJSON{Name: userName, Password: userPassword})
	require.NoError(t, err)

	r := httptest.NewRequest(http.MethodPost, "/control/login", bytes.NewReader(creds))
	r.Header.Set(httphdr.ContentType, aghhttp.HdrValApplicationJSON)
	w := httptest.NewRecorder()
	mux.ServeHTTP(w, r)

	var loginCookie *http.Cookie
	for _, c := range w.Result().Cookies() {
		if c.Name == sessionCookieName {
			loginCookie = c
		}
	}

	require.NotNil(t, loginCookie)

	testCases := []struct {
		url        string
		method     string
		authMethod int
		wantCode   int
	}{{
		url:        "/",
		method:     http.MethodGet,
		authMethod: authNone,
		wantCode:   http.StatusFound,
	}, {
		url:        "/control/i18n/change_language",
		method:     http.MethodPost,
		authMethod: authNone,
		wantCode:   http.StatusForbidden,
	}, {
		url:        "/control/i18n/change_language",
		method:     http.MethodPost,
		authMethod: authBasic,
		wantCode:   http.StatusInternalServerError,
	}, {
		url:        "/control/i18n/change_language",
		method:     http.MethodPost,
		authMethod: authCookie,
		wantCode:   http.StatusInternalServerError,
	}, {
		url:        "/control/i18n/current_language",
		method:     http.MethodGet,
		authMethod: authNone,
		wantCode:   http.StatusForbidden,
	}, {
		url:        "/control/i18n/current_language",
		method:     http.MethodGet,
		authMethod: authBasic,
		wantCode:   http.StatusOK,
	}, {
		url:        "/control/i18n/current_language",
		method:     http.MethodGet,
		authMethod: authCookie,
		wantCode:   http.StatusOK,
	}, {
		url:        "/control/logout",
		method:     http.MethodGet,
		authMethod: authNone,
		wantCode:   http.StatusForbidden,
	}, {
		url:        "/control/logout",
		method:     http.MethodGet,
		authMethod: authBasic,
		wantCode:   http.StatusFound,
	}, {
		url:        "/control/profile",
		method:     http.MethodGet,
		authMethod: authNone,
		wantCode:   http.StatusForbidden,
	}, {
		url:        "/control/profile",
		method:     http.MethodGet,
		authMethod: authBasic,
		wantCode:   http.StatusOK,
	}, {
		url:        "/control/profile",
		method:     http.MethodGet,
		authMethod: authCookie,
		wantCode:   http.StatusOK,
	}, {
		url:        "/control/profile/update",
		method:     http.MethodPut,
		authMethod: authNone,
		wantCode:   http.StatusForbidden,
	}, {
		url:        "/control/profile/update",
		method:     http.MethodPut,
		authMethod: authBasic,
		wantCode:   http.StatusBadRequest,
	}, {
		url:        "/control/profile/update",
		method:     http.MethodPut,
		authMethod: authCookie,
		wantCode:   http.StatusBadRequest,
	}, {
		url:        "/control/status",
		method:     http.MethodGet,
		authMethod: authNone,
		wantCode:   http.StatusForbidden,
	}, {
		url:        "/control/status",
		method:     http.MethodGet,
		authMethod: authBasic,
		wantCode:   http.StatusOK,
	}, {
		url:        "/control/status",
		method:     http.MethodGet,
		authMethod: authCookie,
		wantCode:   http.StatusOK,
	}, {
		url:        "/control/update",
		method:     http.MethodPost,
		authMethod: authNone,
		wantCode:   http.StatusForbidden,
	}, {
		url:        "/control/version.json",
		method:     http.MethodGet,
		authMethod: authNone,
		wantCode:   http.StatusForbidden,
	}, {
		url:        "/control/version.json",
		method:     http.MethodGet,
		authMethod: authBasic,
		wantCode:   http.StatusOK,
	}, {
		url:        "/control/version.json",
		method:     http.MethodGet,
		authMethod: authCookie,
		wantCode:   http.StatusOK,
	}}

	for _, tc := range testCases {
		t.Run(tc.url, func(t *testing.T) {
			r = httptest.NewRequest(tc.method, tc.url, nil)

			switch tc.authMethod {
			case authNone:
				// Go on.
			case authBasic:
				r.SetBasicAuth(userName, userPassword)
			case authCookie:
				r.AddCookie(loginCookie)
			default:
				panic("unrecognized auth method")
			}

			// Resolve the handler explicitly so that an unregistered path
			// fails the test.
			h, pattern := mux.Handler(r)
			require.NotEmpty(t, pattern)

			w = httptest.NewRecorder()
			h.ServeHTTP(w, r)

			assert.Equal(t, tc.wantCode, w.Code)
		})
	}

	// This subtest must stay last: once it has requested /control/logout
	// with loginCookie, the cookie is expected to be rejected.
	t.Run("logout", func(t *testing.T) {
		// The cookie is accepted before logging out.
		r = httptest.NewRequest(http.MethodGet, "/control/status", nil)
		r.AddCookie(loginCookie)
		w = httptest.NewRecorder()
		mux.ServeHTTP(w, r)

		assert.Equal(t, http.StatusOK, w.Code)

		// Log out using the same cookie.
		r = httptest.NewRequest(http.MethodGet, "/control/logout", nil)
		r.AddCookie(loginCookie)
		w = httptest.NewRecorder()
		mux.ServeHTTP(w, r)

		assert.Equal(t, http.StatusFound, w.Code)

		// The same cookie is rejected afterwards.
		r = httptest.NewRequest(http.MethodGet, "/control/status", nil)
		r.AddCookie(loginCookie)
		w = httptest.NewRecorder()
		mux.ServeHTTP(w, r)

		assert.Equal(t, http.StatusForbidden, w.Code)
	})
}
// testResponseWriter is a minimal [http.ResponseWriter] for tests.  It returns
// the header map stored in hdr, records the last status code passed to
// WriteHeader, and discards response bodies.
type testResponseWriter struct {
	hdr        http.Header
	statusCode int
}

// type check
var _ http.ResponseWriter = (*testResponseWriter)(nil)

// Header implements the [http.ResponseWriter] interface for
// *testResponseWriter.  It returns the stored header map, which is nil unless
// the test has set it.
func (w *testResponseWriter) Header() http.Header {
	return w.hdr
}

// Write implements the [http.ResponseWriter] interface for
// *testResponseWriter.  The data is discarded, but it is reported as fully
// written: the [io.Writer] contract requires a non-nil error whenever fewer
// than len(b) bytes are reported, so returning 0 with a nil error would make
// callers such as [io.Copy] fail with a short-write error.
func (w *testResponseWriter) Write(b []byte) (n int, err error) {
	return len(b), nil
}

// WriteHeader implements the [http.ResponseWriter] interface for
// *testResponseWriter.  It only records statusCode.
func (w *testResponseWriter) WriteHeader(statusCode int) {
	w.statusCode = statusCode
}
func TestAuthHTTP(t *testing.T) {
dir := t.TempDir()
fn := filepath.Join(dir, "sessions.db")
users := []webUser{
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
}
globalContext.auth = InitAuth(fn, users, 60, nil, nil)
handlerCalled := false
handler := func(_ http.ResponseWriter, _ *http.Request) {
handlerCalled = true
}
handler2 := optionalAuth(handler)
w := testResponseWriter{}
w.hdr = make(http.Header)
r := http.Request{}
r.Header = make(http.Header)
r.Method = http.MethodGet
// get / - we're redirected to login page
r.URL = &url.URL{Path: "/"}
handlerCalled = false
handler2(&w, &r)
assert.Equal(t, http.StatusFound, w.statusCode)
assert.NotEmpty(t, w.hdr.Get(httphdr.Location))
assert.False(t, handlerCalled)
// go to login page
loginURL := w.hdr.Get(httphdr.Location)
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
// perform login
cookie, err := globalContext.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
require.NoError(t, err)
require.NotNil(t, cookie)
// get /
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.Header.Set(httphdr.Cookie, cookie.String())
r.URL = &url.URL{Path: "/"}
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
r.Header.Del(httphdr.Cookie)
// get / with basic auth
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.URL = &url.URL{Path: "/"}
r.SetBasicAuth("name", "password")
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
r.Header.Del(httphdr.Authorization)
// get login page with a valid cookie - we're redirected to /
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.Header.Set(httphdr.Cookie, cookie.String())
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)
assert.NotEmpty(t, w.hdr.Get(httphdr.Location))
assert.False(t, handlerCalled)
r.Header.Del(httphdr.Cookie)
// get login page with an invalid cookie
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.Header.Set(httphdr.Cookie, "bad")
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
r.Header.Del(httphdr.Cookie)
globalContext.auth.Close()
}
// TestRealIP checks the address and error message returned by realIP for
// requests without proxy headers, with an X-Real-IP header, with an
// X-Forwarded-For header listing several addresses, and with a malformed
// remote address.
func TestRealIP(t *testing.T) {
	const clientAddr = "1.2.3.4:5678"

	// Header maps are keyed by the canonical form of the header name.
	var (
		realIPKey       = textproto.CanonicalMIMEHeaderKey(httphdr.XRealIP)
		forwardedForKey = textproto.CanonicalMIMEHeaderKey(httphdr.XForwardedFor)
	)

	type testCase struct {
		header     http.Header
		wantIP     netip.Addr
		name       string
		remoteAddr string
		wantErrMsg string
	}

	// Fields left out of a case keep their zero values: a nil header, an
	// empty expected error message, and an invalid expected address.
	cases := []testCase{
		{
			name:       "success_no_proxy",
			remoteAddr: clientAddr,
			wantIP:     netip.MustParseAddr("1.2.3.4"),
		},
		{
			name:       "success_proxy",
			header:     http.Header{realIPKey: []string{"1.2.3.5"}},
			remoteAddr: clientAddr,
			wantIP:     netip.MustParseAddr("1.2.3.5"),
		},
		{
			name:       "success_proxy_multiple",
			header:     http.Header{forwardedForKey: []string{"1.2.3.6, 1.2.3.5"}},
			remoteAddr: clientAddr,
			wantIP:     netip.MustParseAddr("1.2.3.6"),
		},
		{
			name:       "error_no_proxy",
			remoteAddr: "1:::2",
			wantErrMsg: `getting ip from client addr: address 1:::2: ` +
				`too many colons in address`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := &http.Request{
				Header:     tc.header,
				RemoteAddr: tc.remoteAddr,
			}

			got, err := realIP(req)
			assert.Equal(t, tc.wantIP, got)
			testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
		})
	}
}