From b16b3edf4dcc5613ef78628829256d6a2550e53a Mon Sep 17 00:00:00 2001 From: mkelvers Date: Wed, 24 Jun 2026 16:08:45 +0200 Subject: [PATCH] test: add auth handler middleware and service unit tests --- internal/auth/handler_middleware_test.go | 255 +++++++++++++++++++++++ internal/auth/service_test.go | 243 +++++++++++++++++++++ 2 files changed, 498 insertions(+) create mode 100644 internal/auth/handler_middleware_test.go create mode 100644 internal/auth/service_test.go diff --git a/internal/auth/handler_middleware_test.go b/internal/auth/handler_middleware_test.go new file mode 100644 index 0000000..149975a --- /dev/null +++ b/internal/auth/handler_middleware_test.go @@ -0,0 +1,255 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "mal/internal/db" + "mal/internal/domain" + + "github.com/gin-gonic/gin" +) + +func TestHandleAPILogin(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := &fakeAuthService{ + apiToken: "token-1", + apiUser: &domain.User{User: db.User{ID: "user-1", Username: "alice", AvatarUrl: "avatar.png"}}, + } + router := gin.New() + NewAuthHandler(svc).Register(router) + + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/api/auth/login", strings.NewReader(`{"username":"alice","password":"correct","name":"phone"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"token":"token-1"`) { + t.Fatalf("response missing token: %s", rec.Body.String()) + } + if svc.apiLoginName != "phone" { + t.Fatalf("api token name = %q, want phone", svc.apiLoginName) + } +} + +func TestHandleAPILoginRejectsInvalidRequests(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + body string + loginErr error + wantStatus int + }{ + {name: "bad json", body: `{`, wantStatus: http.StatusBadRequest}, + {name: "missing password", body: `{"username":"alice"}`, wantStatus: http.StatusBadRequest}, + {name: "bad credentials", body: `{"username":"alice","password":"wrong"}`, loginErr: ErrWrongPassword, wantStatus: http.StatusUnauthorized}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := &fakeAuthService{apiLoginErr: tt.loginErr} + router := gin.New() + NewAuthHandler(svc).Register(router) + + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/api/auth/login", strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Fatalf("status = %d, want %d; body=%s", rec.Code, tt.wantStatus, rec.Body.String()) + } + }) + } +} + +func TestAuthMiddlewareAllowsPublicRoutes(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := &fakeAuthService{} + router := gin.New() + router.Use(AuthMiddleware(svc)) + router.GET("/static/app.js", func(c *gin.Context) { c.String(http.StatusOK, "asset") }) + + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/static/app.js", nil) + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + if svc.validateSessionCalled || svc.validateAPITokenCalled { + t.Fatalf("public route should not authenticate") + } +} + +func TestAuthMiddlewareAuthenticatesAPIBearerToken(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := &fakeAuthService{user: &domain.User{User: db.User{ID: "user-1", Username: "alice"}}} + router := gin.New() + router.Use(AuthMiddleware(svc)) + router.GET("/api/me", func(c *gin.Context) { + user, _ := c.Get("User") + if user.(*domain.User).ID != "user-1" { + c.Status(http.StatusTeapot) + return + } + c.Status(http.StatusOK) + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/api/me", nil) + req.Header.Set("Authorization", "Bearer api-token") + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + if svc.validatedAPIToken != "api-token" { + t.Fatalf("validated api token = %q, want api-token", svc.validatedAPIToken) + } + if svc.refreshSessionCalled { + t.Fatalf("bearer token auth should not refresh cookie session") + } +} + +func TestAuthMiddlewareAuthenticatesCookieSessionAndRefreshes(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := &fakeAuthService{user: &domain.User{User: db.User{ID: "user-1", Username: "alice"}}} + router := gin.New() + router.Use(AuthMiddleware(svc)) + router.GET("/", func(c *gin.Context) { c.Status(http.StatusOK) }) + + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "session_id", Value: "session-1"}) + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + if svc.validatedSessionID != "session-1" { + t.Fatalf("validated session id = %q, want session-1", svc.validatedSessionID) + } + if svc.refreshedSessionID != "session-1" { + t.Fatalf("refreshed session id = %q, want session-1", svc.refreshedSessionID) + } + if got := rec.Header().Values("Set-Cookie"); len(got) == 0 || !strings.Contains(got[0], "session_id=session-1") { + t.Fatalf("Set-Cookie = %v, want refreshed session cookie", got) + } +} + +func TestAuthMiddlewareRejectsUnauthenticatedRequests(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + method string + path string + wantStatus int + wantHeader string + }{ + {name: "api", method: http.MethodGet, path: "/api/me", wantStatus: http.StatusUnauthorized}, + {name: "page", method: http.MethodGet, path: "/", wantStatus: http.StatusSeeOther, wantHeader: "/login"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router := gin.New() + router.Use(AuthMiddleware(&fakeAuthService{validateErr: errors.New("no auth")})) + router.Handle(tt.method, tt.path, func(c *gin.Context) { c.Status(http.StatusOK) }) + + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(context.Background(), tt.method, tt.path, nil) + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Fatalf("status = %d, want %d", rec.Code, tt.wantStatus) + } + if tt.wantHeader != "" && rec.Header().Get("Location") != tt.wantHeader { + t.Fatalf("Location = %q, want %q", rec.Header().Get("Location"), tt.wantHeader) + } + }) + } +} + +type fakeAuthService struct { + user *domain.User + + apiToken string + apiUser *domain.User + + loginErr error + apiLoginErr error + validateErr error + + apiLoginName string + validatedSessionID string + validatedAPIToken string + refreshedSessionID string + loggedOutSessionID string + validateSessionCalled bool + validateAPITokenCalled bool + refreshSessionCalled bool + revokedAPITokensForUser string +} + +func (s *fakeAuthService) Login(_ context.Context, _, _ string) (*domain.Session, error) { + if s.loginErr != nil { + return nil, s.loginErr + } + return &domain.Session{Session: db.Session{ID: "session-1", UserID: "user-1"}}, nil +} + +func (s *fakeAuthService) LoginForAPIToken(_ context.Context, _, _, name string) (string, *domain.User, error) { + s.apiLoginName = name + if s.apiLoginErr != nil { + return "", nil, s.apiLoginErr + } + return s.apiToken, s.apiUser, nil +} + +func (s *fakeAuthService) ValidateSession(_ context.Context, sessionID string) (*domain.User, error) { + s.validateSessionCalled = true + s.validatedSessionID = sessionID + if s.validateErr != nil { + return nil, s.validateErr + } + return s.user, nil +} + +func (s *fakeAuthService) RefreshSession(_ context.Context, sessionID string) error { + s.refreshSessionCalled = true + s.refreshedSessionID = sessionID + return nil +} + +func (s *fakeAuthService) ValidateAPIToken(_ context.Context, token string) (*domain.User, error) { + s.validateAPITokenCalled = true + s.validatedAPIToken = token + if s.validateErr != nil { + return nil, s.validateErr + } + return s.user, nil +} + +func (s *fakeAuthService) Logout(_ context.Context, sessionID string) error { + s.loggedOutSessionID = sessionID + return nil +} + +func (s *fakeAuthService) RevokeAllAPITokensForUser(_ context.Context, userID string) error { + s.revokedAPITokensForUser = userID + return nil +} diff --git a/internal/auth/service_test.go b/internal/auth/service_test.go new file mode 100644 index 0000000..8c5bafb --- /dev/null +++ b/internal/auth/service_test.go @@ -0,0 +1,243 @@ +package auth + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "testing" + "time" + + "mal/internal/db" + "mal/internal/domain" + + "golang.org/x/crypto/bcrypt" +) + +func TestAuthServiceLogin(t *testing.T) { + passwordHash := hashPassword(t, "correct") + repo := &fakeAuthRepository{ + usersByUsername: map[string]*domain.User{ + "alice": {User: db.User{ID: "user-1", Username: "alice", PasswordHash: passwordHash}}, + }, + } + svc := NewAuthService(repo, &fakeAuditService{}) + + session, err := svc.Login(context.Background(), "alice", "correct") + if err != nil { + t.Fatalf("Login: %v", err) + } + if session.UserID != "user-1" { + t.Fatalf("session user id = %q, want %q", session.UserID, "user-1") + } + if session.ID == "" { + t.Fatalf("expected generated session id") + } + if repo.createdSessionUserID != "user-1" { + t.Fatalf("created session user id = %q, want user-1", repo.createdSessionUserID) + } +} + +func TestAuthServiceLoginRejectsMissingUserAndWrongPassword(t *testing.T) { + passwordHash := hashPassword(t, "correct") + repo := &fakeAuthRepository{ + usersByUsername: map[string]*domain.User{ + "alice": {User: db.User{ID: "user-1", Username: "alice", PasswordHash: passwordHash}}, + }, + } + svc := NewAuthService(repo, &fakeAuditService{}) + + if _, err := svc.Login(context.Background(), "missing", "correct"); !errors.Is(err, ErrUserNotFound) { + t.Fatalf("missing user error = %v, want %v", err, ErrUserNotFound) + } + if _, err := svc.Login(context.Background(), "alice", "wrong"); !errors.Is(err, ErrWrongPassword) { + t.Fatalf("wrong password error = %v, want %v", err, ErrWrongPassword) + } +} + +func TestAuthServiceValidateSession(t *testing.T) { + repo := &fakeAuthRepository{ + usersByID: map[string]*domain.User{ + "user-1": {User: db.User{ID: "user-1", Username: "alice"}}, + }, + sessions: map[string]*domain.Session{ + "fresh": {Session: db.Session{ID: "fresh", UserID: "user-1", ExpiresAt: time.Now().Add(time.Hour)}}, + "expired": {Session: db.Session{ID: "expired", UserID: "user-1", ExpiresAt: time.Now().Add(-time.Hour)}}, + }, + } + svc := NewAuthService(repo, &fakeAuditService{}) + + user, err := svc.ValidateSession(context.Background(), "fresh") + if err != nil { + t.Fatalf("ValidateSession fresh: %v", err) + } + if user == nil || user.ID != "user-1" { + t.Fatalf("validated user = %#v, want user-1", user) + } + + if _, err := svc.ValidateSession(context.Background(), "expired"); err == nil || err.Error() != "session expired" { + t.Fatalf("expired session error = %v, want session expired", err) + } + if !repo.deletedSessions["expired"] { + t.Fatalf("expected expired session to be deleted") + } +} + +func TestAuthServiceLoginForAPITokenCreatesTokenAndAuditEvent(t *testing.T) { + passwordHash := hashPassword(t, "correct") + repo := &fakeAuthRepository{ + usersByUsername: map[string]*domain.User{ + "alice": {User: db.User{ID: "user-1", Username: "alice", PasswordHash: passwordHash}}, + }, + } + auditSvc := &fakeAuditService{} + svc := NewAuthService(repo, auditSvc) + + token, user, err := svc.LoginForAPIToken(context.Background(), "alice", "correct", " phone ") + if err != nil { + t.Fatalf("LoginForAPIToken: %v", err) + } + if token == "" { + t.Fatalf("expected raw token") + } + if user == nil || user.ID != "user-1" { + t.Fatalf("user = %#v, want user-1", user) + } + if repo.createdAPITokenName != "phone" { + t.Fatalf("api token name = %q, want phone", repo.createdAPITokenName) + } + if repo.createdAPITokenHash == "" || repo.createdAPITokenHash == token { + t.Fatalf("expected stored token hash, got %q", repo.createdAPITokenHash) + } + if len(auditSvc.events) != 1 || auditSvc.events[0].Action != "api_token_created" { + t.Fatalf("audit events = %#v, want api_token_created", auditSvc.events) + } +} + +func TestAuthServiceValidateAPIToken(t *testing.T) { + rawToken := "secret-token" + sum := sha256.Sum256([]byte(rawToken)) + tokenHash := hex.EncodeToString(sum[:]) + repo := &fakeAuthRepository{ + usersByID: map[string]*domain.User{ + "user-1": {User: db.User{ID: "user-1", Username: "alice"}}, + }, + apiTokensByHash: map[string]*domain.APIToken{ + tokenHash: {ApiToken: db.ApiToken{ID: "token-1", UserID: "user-1", TokenHash: tokenHash}}, + }, + } + svc := NewAuthService(repo, &fakeAuditService{}) + + user, err := svc.ValidateAPIToken(context.Background(), " "+rawToken+" ") + if err != nil { + t.Fatalf("ValidateAPIToken: %v", err) + } + if user == nil || user.ID != "user-1" { + t.Fatalf("user = %#v, want user-1", user) + } + if repo.touchedTokenID != "token-1" { + t.Fatalf("touched token id = %q, want token-1", repo.touchedTokenID) + } +} + +func TestAuthServiceRevokeAllAPITokensForUser(t *testing.T) { + repo := &fakeAuthRepository{} + auditSvc := &fakeAuditService{} + svc := NewAuthService(repo, auditSvc) + + if err := svc.RevokeAllAPITokensForUser(context.Background(), "user-1"); err != nil { + t.Fatalf("RevokeAllAPITokensForUser: %v", err) + } + if repo.revokedUserID != "user-1" { + t.Fatalf("revoked user id = %q, want user-1", repo.revokedUserID) + } + if len(auditSvc.events) != 1 || auditSvc.events[0].Action != "api_token_revoked_all" { + t.Fatalf("audit events = %#v, want api_token_revoked_all", auditSvc.events) + } + + if err := svc.RevokeAllAPITokensForUser(context.Background(), " "); err == nil || err.Error() != "user id missing" { + t.Fatalf("blank user id error = %v, want user id missing", err) + } +} + +func hashPassword(t *testing.T, password string) string { + t.Helper() + + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) + if err != nil { + t.Fatalf("GenerateFromPassword: %v", err) + } + return string(hash) +} + +type fakeAuthRepository struct { + usersByUsername map[string]*domain.User + usersByID map[string]*domain.User + sessions map[string]*domain.Session + apiTokensByHash map[string]*domain.APIToken + + createdSessionUserID string + createdAPITokenHash string + createdAPITokenName string + touchedTokenID string + revokedUserID string + deletedSessions map[string]bool +} + +func (r *fakeAuthRepository) GetUserByUsername(_ context.Context, username string) (*domain.User, error) { + return r.usersByUsername[username], nil +} + +func (r *fakeAuthRepository) GetUserByID(_ context.Context, id string) (*domain.User, error) { + return r.usersByID[id], nil +} + +func (r *fakeAuthRepository) CreateSession(_ context.Context, userID string, sessionID string) (*domain.Session, error) { + r.createdSessionUserID = userID + return &domain.Session{Session: db.Session{ID: sessionID, UserID: userID, ExpiresAt: time.Now().Add(domain.SessionLifetime)}}, nil +} + +func (r *fakeAuthRepository) GetSession(_ context.Context, sessionID string) (*domain.Session, error) { + return r.sessions[sessionID], nil +} + +func (r *fakeAuthRepository) RefreshSession(_ context.Context, _ string, _ time.Time) error { + return nil +} + +func (r *fakeAuthRepository) DeleteSession(_ context.Context, sessionID string) error { + if r.deletedSessions == nil { + r.deletedSessions = make(map[string]bool) + } + r.deletedSessions[sessionID] = true + return nil +} + +func (r *fakeAuthRepository) CreateAPIToken(_ context.Context, userID, tokenHash, name string) (*domain.APIToken, error) { + r.createdAPITokenHash = tokenHash + r.createdAPITokenName = name + return &domain.APIToken{ApiToken: db.ApiToken{ID: "token-1", UserID: userID, TokenHash: tokenHash, Name: name}}, nil +} + +func (r *fakeAuthRepository) GetAPITokenByHash(_ context.Context, tokenHash string) (*domain.APIToken, error) { + return r.apiTokensByHash[tokenHash], nil +} + +func (r *fakeAuthRepository) TouchAPITokenLastUsedAt(_ context.Context, tokenID string) error { + r.touchedTokenID = tokenID + return nil +} + +func (r *fakeAuthRepository) RevokeAllAPITokensForUser(_ context.Context, userID string) error { + r.revokedUserID = userID + return nil +} + +type fakeAuditService struct { + events []domain.AuditEvent +} + +func (s *fakeAuditService) Record(_ context.Context, event domain.AuditEvent) error { + s.events = append(s.events, event) + return nil +}