diff --git a/api/anime/handler.go b/api/anime/handler.go index ac8f6d7..9342ae5 100644 --- a/api/anime/handler.go +++ b/api/anime/handler.go @@ -11,6 +11,7 @@ import ( "mal/integrations/jikan" ctxpkg "mal/internal/context" "mal/internal/db" + "mal/internal/middleware" "mal/templates" ) @@ -85,8 +86,9 @@ func (h *Handler) HandleCatalog(w http.ResponseWriter, r *http.Request) { var cw []database.GetContinueWatchingEntriesRow watchlistMap := make(map[int64]bool) var watchlistIDs []int64 - user, userOk := r.Context().Value(ctxpkg.UserKey).(*database.User) - if userOk && user != nil { + user := middleware.GetUser(r.Context()) + userOk := user != nil + if userOk { cw, _ = h.db.GetContinueWatchingEntries(r.Context(), user.ID) watchlist, _ := h.db.GetUserWatchList(r.Context(), user.ID) watchlistIDs = make([]int64, len(watchlist)) @@ -111,7 +113,7 @@ func (h *Handler) HandleCatalog(w http.ResponseWriter, r *http.Request) { } func (h *Handler) HandleBrowse(w http.ResponseWriter, r *http.Request) { - user, _ := r.Context().Value(ctxpkg.UserKey).(*database.User) + user := middleware.GetUser(r.Context()) q := r.URL.Query().Get("q") animeType := r.URL.Query().Get("type") @@ -230,7 +232,7 @@ func (h *Handler) HandleAnimeDetails(w http.ResponseWriter, r *http.Request) { return } - user, _ := r.Context().Value(ctxpkg.UserKey).(*database.User) + user := middleware.GetUser(r.Context()) var status string var watchlistIDs []int64 @@ -276,7 +278,7 @@ func (h *Handler) HandleHTMLWatchOrder(w http.ResponseWriter, r *http.Request) { return } - user, _ := r.Context().Value(ctxpkg.UserKey).(*database.User) + user := middleware.GetUser(r.Context()) watchlistMap := make(map[int64]bool) if user != nil { watchlist, _ := h.db.GetUserWatchList(r.Context(), user.ID) diff --git a/api/playback/handler.go b/api/playback/handler.go index a09623c..07cad6e 100644 --- a/api/playback/handler.go +++ b/api/playback/handler.go @@ -12,8 +12,7 @@ import ( "sync" "mal/integrations/jikan" - ctxpkg "mal/internal/context" - database "mal/internal/db" + "mal/internal/middleware" "mal/templates" ) @@ -82,7 +81,7 @@ func (h *Handler) HandleWatchPage(w http.ResponseWriter, r *http.Request) { return episodes.Data[i].MalID < episodes.Data[j].MalID }) - user, _ := r.Context().Value(ctxpkg.UserKey).(*database.User) + user := middleware.GetUser(r.Context()) currentEpID := r.URL.Query().Get("ep") if currentEpID == "" { @@ -238,7 +237,7 @@ func (h *Handler) HandleEpisodeData(w http.ResponseWriter, r *http.Request) { episodeID := parts[5] - user, _ := r.Context().Value(ctxpkg.UserKey).(*database.User) + user := middleware.GetUser(r.Context()) userID := "" if user != nil { userID = user.ID diff --git a/api/watchlist/handler.go b/api/watchlist/handler.go index 650faa7..2b13c9e 100644 --- a/api/watchlist/handler.go +++ b/api/watchlist/handler.go @@ -6,8 +6,8 @@ import ( "net/http" "strconv" - ctxpkg "mal/internal/context" database "mal/internal/db" + "mal/internal/middleware" "mal/templates" ) @@ -29,7 +29,7 @@ func (h *Handler) HandleUpdateWatchlist(w http.ResponseWriter, r *http.Request) return } - user, _ := r.Context().Value(ctxpkg.UserKey).(*database.User) + user := middleware.GetUser(r.Context()) if user == nil { http.Error(w, "unauthorized", http.StatusUnauthorized) return @@ -58,7 +58,7 @@ func (h *Handler) HandleUpdateWatchlist(w http.ResponseWriter, r *http.Request) } func (h *Handler) HandleDeleteWatchlist(w http.ResponseWriter, r *http.Request) { - user, _ := r.Context().Value(ctxpkg.UserKey).(*database.User) + user := middleware.GetUser(r.Context()) if user == nil { http.Error(w, "unauthorized", http.StatusUnauthorized) return @@ -86,7 +86,7 @@ func (h *Handler) HandleDeleteContinueWatching(w http.ResponseWriter, r *http.Re } func (h *Handler) HandleGetWatchlist(w http.ResponseWriter, r *http.Request) { - user, _ := r.Context().Value(ctxpkg.UserKey).(*database.User) + user := middleware.GetUser(r.Context()) if user == nil { http.Redirect(w, r, "/login", http.StatusSeeOther) return diff --git a/internal/middleware/access.go b/internal/middleware/access.go index 03264f3..6abba55 100644 --- a/internal/middleware/access.go +++ b/internal/middleware/access.go @@ -3,9 +3,6 @@ package middleware import ( "net/http" "strings" - - "mal/internal/context" - "mal/internal/db" ) type AccessPolicy struct { @@ -47,7 +44,8 @@ func RequireGlobalAuthWithPolicy(policy AccessPolicy) func(http.Handler) http.Ha return } - user, ok := r.Context().Value(context.UserKey).(*database.User) + user := GetUser(r.Context()) + ok := user != nil if !ok || user == nil { if strings.HasPrefix(r.URL.Path, "/api/") || r.Header.Get("HX-Request") == "true" { w.Header().Set("HX-Redirect", "/login") diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index fcf225f..80a5b13 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -58,8 +58,8 @@ func RequireAuth(next http.Handler) http.Handler { } } - user, ok := r.Context().Value(ctxpkg.UserKey).(*database.User) - if !ok || user == nil { + user := GetUser(r.Context()) + if user == nil { if strings.HasPrefix(r.URL.Path, "/api/") { w.Header().Set("HX-Redirect", "/login") http.Error(w, "Unauthorized", http.StatusUnauthorized) @@ -73,8 +73,8 @@ func RequireAuth(next http.Handler) http.Handler { }) } -func GetUser(ctx interface{}) *database.User { - user, ok := ctx.(*database.User) +func GetUser(ctx context.Context) *database.User { + user, ok := ctx.Value(ctxpkg.UserKey).(*database.User) if !ok { return nil }