Files
mal/internal/auth/handler_middleware_test.go

256 lines
7.6 KiB
Go

package auth
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"mal/internal/db"
"mal/internal/domain"
"github.com/gin-gonic/gin"
)
// TestHandleAPILogin posts a well-formed JSON login to /api/auth/login and
// expects a 200 whose body carries the fake's token, with the requested token
// name ("phone") forwarded to LoginForAPIToken.
func TestHandleAPILogin(t *testing.T) {
	gin.SetMode(gin.TestMode)

	fake := &fakeAuthService{
		apiToken: "token-1",
		apiUser: &domain.User{User: db.User{
			ID:        "user-1",
			Username:  "alice",
			AvatarUrl: "avatar.png",
		}},
	}
	engine := gin.New()
	NewAuthHandler(fake).Register(engine)

	payload := `{"username":"alice","password":"correct","name":"phone"}`
	request := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/api/auth/login", strings.NewReader(payload))
	request.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, request)

	body := recorder.Body.String()
	switch {
	case recorder.Code != http.StatusOK:
		t.Fatalf("status = %d, want %d; body=%s", recorder.Code, http.StatusOK, body)
	case !strings.Contains(body, `"token":"token-1"`):
		t.Fatalf("response missing token: %s", body)
	case fake.apiLoginName != "phone":
		t.Fatalf("api token name = %q, want phone", fake.apiLoginName)
	}
}
// TestHandleAPILoginRejectsInvalidRequests exercises the login endpoint's
// failure paths. It expects 400 for malformed JSON and for a body without a
// password, and 401 when the service reports ErrWrongPassword.
func TestHandleAPILoginRejectsInvalidRequests(t *testing.T) {
	gin.SetMode(gin.TestMode)

	type loginCase struct {
		name       string
		body       string
		loginErr   error // injected as the fake's apiLoginErr
		wantStatus int
	}
	cases := []loginCase{
		{name: "bad json", body: `{`, wantStatus: http.StatusBadRequest},
		{name: "missing password", body: `{"username":"alice"}`, wantStatus: http.StatusBadRequest},
		{name: "bad credentials", body: `{"username":"alice","password":"wrong"}`, loginErr: ErrWrongPassword, wantStatus: http.StatusUnauthorized},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			engine := gin.New()
			NewAuthHandler(&fakeAuthService{apiLoginErr: tc.loginErr}).Register(engine)

			request := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/api/auth/login", strings.NewReader(tc.body))
			request.Header.Set("Content-Type", "application/json")
			recorder := httptest.NewRecorder()
			engine.ServeHTTP(recorder, request)

			if got := recorder.Code; got != tc.wantStatus {
				t.Fatalf("status = %d, want %d; body=%s", got, tc.wantStatus, recorder.Body.String())
			}
		})
	}
}
// TestAuthMiddlewareAllowsPublicRoutes requests a static asset through
// AuthMiddleware and expects it to be served (200) without either
// ValidateSession or ValidateAPIToken being consulted.
func TestAuthMiddlewareAllowsPublicRoutes(t *testing.T) {
	gin.SetMode(gin.TestMode)

	fake := &fakeAuthService{}
	engine := gin.New()
	engine.Use(AuthMiddleware(fake))
	engine.GET("/static/app.js", func(c *gin.Context) {
		c.String(http.StatusOK, "asset")
	})

	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/static/app.js", nil))

	if got, want := recorder.Code, http.StatusOK; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
	if authenticated := fake.validateSessionCalled || fake.validateAPITokenCalled; authenticated {
		t.Fatalf("public route should not authenticate")
	}
}
// TestAuthMiddlewareAuthenticatesAPIBearerToken sends an
// "Authorization: Bearer api-token" header to an /api route. It expects:
//   - the token is passed to ValidateAPIToken;
//   - the resolved user is stored in the gin context under "User";
//   - RefreshSession is not called.
func TestAuthMiddlewareAuthenticatesAPIBearerToken(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := &fakeAuthService{user: &domain.User{User: db.User{ID: "user-1", Username: "alice"}}}
	router := gin.New()
	router.Use(AuthMiddleware(svc))
	router.GET("/api/me", func(c *gin.Context) {
		// Checked lookups on purpose: gin.New() installs no Recovery
		// middleware, so an unchecked assertion on a missing or mistyped
		// value would panic out of ServeHTTP and abort the whole test binary
		// instead of failing this test with a readable status code.
		value, exists := c.Get("User")
		user, ok := value.(*domain.User)
		if !exists || !ok || user == nil || user.ID != "user-1" {
			c.Status(http.StatusTeapot)
			return
		}
		c.Status(http.StatusOK)
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/api/me", nil)
	req.Header.Set("Authorization", "Bearer api-token")
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d (%d means user-1 was not found in the context)", rec.Code, http.StatusOK, http.StatusTeapot)
	}
	if svc.validatedAPIToken != "api-token" {
		t.Fatalf("validated api token = %q, want api-token", svc.validatedAPIToken)
	}
	if svc.refreshSessionCalled {
		t.Fatalf("bearer token auth should not refresh cookie session")
	}
}
// TestAuthMiddlewareAuthenticatesCookieSessionAndRefreshes sends a
// session_id=session-1 cookie to a page route. It expects:
//   - the session is validated and then refreshed;
//   - the response re-issues the session_id cookie with the same value.
func TestAuthMiddlewareAuthenticatesCookieSessionAndRefreshes(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := &fakeAuthService{user: &domain.User{User: db.User{ID: "user-1", Username: "alice"}}}
	router := gin.New()
	router.Use(AuthMiddleware(svc))
	router.GET("/", func(c *gin.Context) { c.Status(http.StatusOK) })

	rec := httptest.NewRecorder()
	req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil)
	req.AddCookie(&http.Cookie{Name: "session_id", Value: "session-1"})
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
	}
	if svc.validatedSessionID != "session-1" {
		t.Fatalf("validated session id = %q, want session-1", svc.validatedSessionID)
	}
	if svc.refreshedSessionID != "session-1" {
		t.Fatalf("refreshed session id = %q, want session-1", svc.refreshedSessionID)
	}

	// Parse every Set-Cookie header rather than substring-matching only the
	// first one. Header order is incidental if other cookies are set, and a
	// substring match would also accept values such as "session-10".
	// Wrapping the live header map avoids depending on when
	// ResponseRecorder.Result() snapshots headers.
	refreshed := false
	for _, cookie := range (&http.Response{Header: rec.Header()}).Cookies() {
		if cookie.Name == "session_id" && cookie.Value == "session-1" {
			refreshed = true
			break
		}
	}
	if !refreshed {
		t.Fatalf("Set-Cookie = %v, want refreshed session_id=session-1 cookie", rec.Header().Values("Set-Cookie"))
	}
}
// TestAuthMiddlewareRejectsUnauthenticatedRequests sends requests that carry
// neither a bearer token nor a session cookie. It expects 401 for an /api route
// and a 303 redirect to /login for a page route.
func TestAuthMiddlewareRejectsUnauthenticatedRequests(t *testing.T) {
	gin.SetMode(gin.TestMode)

	type rejectCase struct {
		name         string
		method       string
		path         string
		wantStatus   int
		wantLocation string // empty means the Location header is not checked
	}

	for _, tc := range []rejectCase{
		{name: "api", method: http.MethodGet, path: "/api/me", wantStatus: http.StatusUnauthorized},
		{name: "page", method: http.MethodGet, path: "/", wantStatus: http.StatusSeeOther, wantLocation: "/login"},
	} {
		t.Run(tc.name, func(t *testing.T) {
			engine := gin.New()
			engine.Use(AuthMiddleware(&fakeAuthService{validateErr: errors.New("no auth")}))
			engine.Handle(tc.method, tc.path, func(c *gin.Context) { c.Status(http.StatusOK) })

			recorder := httptest.NewRecorder()
			engine.ServeHTTP(recorder, httptest.NewRequestWithContext(context.Background(), tc.method, tc.path, nil))

			if recorder.Code != tc.wantStatus {
				t.Fatalf("status = %d, want %d", recorder.Code, tc.wantStatus)
			}
			if tc.wantLocation == "" {
				return
			}
			if loc := recorder.Header().Get("Location"); loc != tc.wantLocation {
				t.Fatalf("Location = %q, want %q", loc, tc.wantLocation)
			}
		})
	}
}
// fakeAuthService is a hand-written test double passed to NewAuthHandler and
// AuthMiddleware in this file. Its fields fall into four groups: canned results
// the fake returns, injected errors, recorded arguments and recorded calls.
type fakeAuthService struct {
	// Canned results.
	user     *domain.User // returned by ValidateSession and ValidateAPIToken on success
	apiToken string       // token returned by LoginForAPIToken on success
	apiUser  *domain.User // user returned by LoginForAPIToken on success

	// Injected failures (nil means succeed).
	loginErr    error // returned by Login
	apiLoginErr error // returned by LoginForAPIToken
	validateErr error // returned by both ValidateSession and ValidateAPIToken

	// Recorded arguments from the most recent call of each method.
	apiLoginName       string
	validatedSessionID string
	validatedAPIToken  string
	refreshedSessionID string
	loggedOutSessionID string

	// Recorded calls.
	validateSessionCalled  bool
	validateAPITokenCalled bool
	refreshSessionCalled   bool

	revokedAPITokensForUser string // userID passed to RevokeAllAPITokensForUser
}

// Login ignores the supplied credentials. It fails with loginErr when set;
// otherwise it returns session "session-1" owned by "user-1".
func (f *fakeAuthService) Login(_ context.Context, _, _ string) (*domain.Session, error) {
	if err := f.loginErr; err != nil {
		return nil, err
	}
	session := &domain.Session{Session: db.Session{ID: "session-1", UserID: "user-1"}}
	return session, nil
}

// LoginForAPIToken records the requested token name before consulting
// apiLoginErr, so the name is captured even on failure. On success it returns
// the canned apiToken and apiUser.
func (f *fakeAuthService) LoginForAPIToken(_ context.Context, _, _, name string) (string, *domain.User, error) {
	f.apiLoginName = name
	if err := f.apiLoginErr; err != nil {
		return "", nil, err
	}
	return f.apiToken, f.apiUser, nil
}

// ValidateSession records that it was called and with which ID. It then
// returns validateErr or the canned user.
func (f *fakeAuthService) ValidateSession(_ context.Context, sessionID string) (*domain.User, error) {
	f.validateSessionCalled, f.validatedSessionID = true, sessionID
	if err := f.validateErr; err != nil {
		return nil, err
	}
	return f.user, nil
}

// RefreshSession records the call and the session ID, and always succeeds.
func (f *fakeAuthService) RefreshSession(_ context.Context, sessionID string) error {
	f.refreshSessionCalled, f.refreshedSessionID = true, sessionID
	return nil
}

// ValidateAPIToken records that it was called and with which token. It then
// returns validateErr or the canned user. The error is shared with
// ValidateSession.
func (f *fakeAuthService) ValidateAPIToken(_ context.Context, token string) (*domain.User, error) {
	f.validateAPITokenCalled, f.validatedAPIToken = true, token
	if err := f.validateErr; err != nil {
		return nil, err
	}
	return f.user, nil
}

// Logout records the session ID it was asked to end and always succeeds.
func (f *fakeAuthService) Logout(_ context.Context, sessionID string) error {
	f.loggedOutSessionID = sessionID
	return nil
}

// RevokeAllAPITokensForUser records the user ID and always succeeds.
func (f *fakeAuthService) RevokeAllAPITokensForUser(_ context.Context, userID string) error {
	f.revokedAPITokensForUser = userID
	return nil
}