From 81017516dd7653e0561cc34dbbb5c7d7a297102e Mon Sep 17 00:00:00 2001 From: mkelvers Date: Sun, 19 Apr 2026 21:06:00 +0200 Subject: [PATCH] refactor: extract access policy --- internal/shared/middleware/access.go | 67 +++++++++++++++++++ internal/shared/middleware/access_test.go | 79 +++++++++++++++++++++++ internal/shared/middleware/auth.go | 23 +------ 3 files changed, 147 insertions(+), 22 deletions(-) create mode 100644 internal/shared/middleware/access.go create mode 100644 internal/shared/middleware/access_test.go diff --git a/internal/shared/middleware/access.go b/internal/shared/middleware/access.go new file mode 100644 index 0000000..f088bd7 --- /dev/null +++ b/internal/shared/middleware/access.go @@ -0,0 +1,67 @@ +package middleware + +import ( + "net/http" + "strings" + + "mal/internal/database" +) + +type AccessPolicy struct { + PublicPaths map[string]struct{} + PublicHeads []string +} + +func NewAccessPolicy() AccessPolicy { + return AccessPolicy{ + PublicPaths: map[string]struct{}{ + "/": {}, + "/login": {}, + "/search": {}, + "/api/search": {}, + "/api/search-quick": {}, + }, + PublicHeads: []string{ + "/static/", + "/dist/", + }, + } +} + +func (p AccessPolicy) IsPublicPath(path string) bool { + if _, ok := p.PublicPaths[path]; ok { + return true + } + + for _, head := range p.PublicHeads { + if strings.HasPrefix(path, head) { + return true + } + } + + return false +} + +func RequireGlobalAuthWithPolicy(policy AccessPolicy) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if policy.IsPublicPath(r.URL.Path) { + next.ServeHTTP(w, r) + return + } + + user, ok := r.Context().Value(UserContextKey).(*database.User) + if !ok || user == nil { + if strings.HasPrefix(r.URL.Path, "/api/") || r.Header.Get("HX-Request") == "true" { + w.Header().Set("HX-Redirect", "/login") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } else { + http.Redirect(w, r, "/login", http.StatusFound) + } + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/shared/middleware/access_test.go b/internal/shared/middleware/access_test.go new file mode 100644 index 0000000..6954311 --- /dev/null +++ b/internal/shared/middleware/access_test.go @@ -0,0 +1,79 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "mal/internal/database" +) + +func TestAccessPolicy_IsPublicPath(t *testing.T) { + t.Parallel() + + policy := NewAccessPolicy() + + if !policy.IsPublicPath("/") { + t.Fatal("expected / to be public") + } + + if !policy.IsPublicPath("/api/search") { + t.Fatal("expected /api/search to be public") + } + + if !policy.IsPublicPath("/static/app.css") { + t.Fatal("expected /static/app.css to be public") + } + + if policy.IsPublicPath("/watchlist") { + t.Fatal("expected /watchlist to be private") + } +} + +func TestRequireGlobalAuthWithPolicy_ProtectedPath(t *testing.T) { + t.Parallel() + + policy := AccessPolicy{ + PublicPaths: map[string]struct{}{"/public": {}}, + } + + h := RequireGlobalAuthWithPolicy(policy)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + + req := httptest.NewRequest(http.MethodGet, "/private", nil) + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("expected redirect status, got %d", rec.Code) + } + + if location := rec.Header().Get("Location"); location != "/login" { + t.Fatalf("expected redirect to /login, got %q", location) + } +} + +func TestRequireGlobalAuthWithPolicy_AllowsAuthenticatedUser(t *testing.T) { + t.Parallel() + + policy := AccessPolicy{ + PublicPaths: map[string]struct{}{}, + } + + h := RequireGlobalAuthWithPolicy(policy)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + + req := httptest.NewRequest(http.MethodGet, "/private", nil) + ctx := context.WithValue(req.Context(), UserContextKey, &database.User{ID: "user-1"}) + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req.WithContext(ctx)) + + if rec.Code != http.StatusNoContent { + t.Fatalf("expected status %d, got %d", http.StatusNoContent, rec.Code) + } +} diff --git a/internal/shared/middleware/auth.go b/internal/shared/middleware/auth.go index 7d8fc9c..cf618cb 100644 --- a/internal/shared/middleware/auth.go +++ b/internal/shared/middleware/auth.go @@ -56,28 +56,7 @@ func RequireAuth(next http.Handler) http.Handler { } func RequireGlobalAuth(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Allow unauthenticated access to auth pages, search, and static files - if r.URL.Path == "/login" || - strings.HasPrefix(r.URL.Path, "/static/") || strings.HasPrefix(r.URL.Path, "/dist/") || - r.URL.Path == "/search" || r.URL.Path == "/api/search" || r.URL.Path == "/api/search-quick" || - r.URL.Path == "/" { - next.ServeHTTP(w, r) - return - } - - user, ok := r.Context().Value(UserContextKey).(*database.User) - if !ok || user == nil { - if strings.HasPrefix(r.URL.Path, "/api/") || r.Header.Get("HX-Request") == "true" { - w.Header().Set("HX-Redirect", "/login") - http.Error(w, "Unauthorized", http.StatusUnauthorized) - } else { - http.Redirect(w, r, "/login", http.StatusFound) - } - return - } - next.ServeHTTP(w, r) - }) + return RequireGlobalAuthWithPolicy(NewAccessPolicy())(next) } func GetUser(ctx context.Context) *database.User {