From 5dd49e585a9d28fcf3f57142055e5732bd7157c2 Mon Sep 17 00:00:00 2001 From: mkelvers Date: Thu, 28 May 2026 12:51:11 +0200 Subject: [PATCH] refactor: extract CurrentUser and CurrentUserID helpers --- internal/anime/handler.go | 72 ++++++++++------------------ internal/playback/handler/handler.go | 46 +++++------------- internal/server/user.go | 26 ++++++++++ internal/watchlist/handler.go | 25 ++-------- 4 files changed, 67 insertions(+), 102 deletions(-) create mode 100644 internal/server/user.go diff --git a/internal/anime/handler.go b/internal/anime/handler.go index ec10210..e3165b0 100644 --- a/internal/anime/handler.go +++ b/internal/anime/handler.go @@ -187,7 +187,7 @@ func (h *AnimeHandler) HandleProducers(c *gin.Context) { } func (h *AnimeHandler) HandleCatalog(c *gin.Context) { - user, _ := c.Get("User") + user := server.CurrentUser(c) c.HTML(http.StatusOK, "index.gohtml", gin.H{ "CurrentPath": "/", @@ -209,11 +209,7 @@ func (h *AnimeHandler) HandleCatalogContinue(c *gin.Context) { } func (h *AnimeHandler) renderCatalogSection(c *gin.Context, section string) { - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + userID := server.CurrentUserID(c) data, err := h.svc.GetCatalogSection(c.Request.Context(), userID, section) if err != nil { observability.Warn( @@ -239,7 +235,7 @@ func (h *AnimeHandler) renderCatalogSection(c *gin.Context, section string) { } func (h *AnimeHandler) HandleDiscover(c *gin.Context) { - user, _ := c.Get("User") + user := server.CurrentUser(c) c.HTML(http.StatusOK, "discover.gohtml", gin.H{ "CurrentPath": "/discover", "User": user, @@ -259,11 +255,7 @@ func (h *AnimeHandler) HandleDiscoverTop(c *gin.Context) { } func (h *AnimeHandler) HandleDiscoverForYou(c *gin.Context) { - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + userID := server.CurrentUserID(c) data, err := h.svc.GetDiscoverForYou(c.Request.Context(), userID) if err != nil { @@ -289,11 +281,7 @@ func (h *AnimeHandler) HandleDiscoverForYou(c *gin.Context) { } func (h *AnimeHandler) renderDiscoverSection(c *gin.Context, section string) { - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + userID := server.CurrentUserID(c) data, err := h.svc.GetDiscoverSection(c.Request.Context(), userID, section) if err != nil { observability.Warn( @@ -319,7 +307,7 @@ func (h *AnimeHandler) renderDiscoverSection(c *gin.Context, section string) { } func (h *AnimeHandler) HandleSchedule(c *gin.Context) { - user, _ := c.Get("User") + user := server.CurrentUser(c) year, week := parseYearWeek(c) c.HTML(http.StatusOK, "schedule.gohtml", gin.H{ "CurrentPath": "/schedule", @@ -515,11 +503,8 @@ func (h *AnimeHandler) HandleBrowse(c *gin.Context) { return } - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + user := server.CurrentUser(c) + userID := server.CurrentUserID(c) animes := wrapAnimes(res.Animes) watchlistMap := h.watchlistMapForAnimes(c.Request.Context(), userID, animes) @@ -655,19 +640,19 @@ func (h *AnimeHandler) HandleAnimeDetails(c *gin.Context) { return } - user, _ := c.Get("User") + user := server.CurrentUser(c) status := "" var watchlistIDs []int64 ep := 0 var cwSeconds float64 - if u, ok := user.(*domain.User); ok { - entry, err := h.watchlistSvc.GetWatchListEntry(c.Request.Context(), u.ID, int64(id)) + if user != nil { + entry, err := h.watchlistSvc.GetWatchListEntry(c.Request.Context(), user.ID, int64(id)) if err == nil { status = entry.Status watchlistIDs = []int64{entry.AnimeID} } - cwEntry, err := h.watchlistSvc.GetContinueWatchingEntry(c.Request.Context(), u.ID, int64(id)) + cwEntry, err := h.watchlistSvc.GetContinueWatchingEntry(c.Request.Context(), user.ID, int64(id)) if err == nil && cwEntry.CurrentEpisode.Valid { ep = int(cwEntry.CurrentEpisode.Int64) cwSeconds = cwEntry.CurrentTimeSeconds @@ -692,11 +677,7 @@ func (h *AnimeHandler) HandleHTMLWatchOrder(c *gin.Context) { return } - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + userID := server.CurrentUserID(c) relationsCtx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second) defer cancel() @@ -745,11 +726,7 @@ func (h *AnimeHandler) HandleQuickSearch(c *gin.Context) { return } - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + userID := server.CurrentUserID(c) animes := wrapAnimes(res.Animes) watchlistMap := h.watchlistMapForAnimes(c.Request.Context(), userID, animes) @@ -787,9 +764,8 @@ type commandPaletteItem struct { } func (h *AnimeHandler) HandleCommandPalette(c *gin.Context) { - user, _ := c.Get("User") - u, ok := user.(*domain.User) - if !ok { + user := server.CurrentUser(c) + if user == nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) return } @@ -812,15 +788,15 @@ func (h *AnimeHandler) HandleCommandPalette(c *gin.Context) { } items = append(items, h.commandPaletteNavigationItems(query)...) - items = append(items, h.commandPaletteContinueItems(c, u.ID, query)...) - items = append(items, h.commandPalettePersonalItems(c, u.ID, query)...) + items = append(items, h.commandPaletteContinueItems(c, user.ID, query)...) + items = append(items, h.commandPalettePersonalItems(c, user.ID, query)...) c.JSON(http.StatusOK, items) return } - items = append(items, h.commandPaletteContinueItems(c, u.ID, query)...) + items = append(items, h.commandPaletteContinueItems(c, user.ID, query)...) items = append(items, h.commandPaletteNavigationItems(query)...) - items = append(items, h.commandPalettePersonalItems(c, u.ID, query)...) + items = append(items, h.commandPalettePersonalItems(c, user.ID, query)...) c.JSON(http.StatusOK, items) } @@ -983,10 +959,10 @@ func (h *AnimeHandler) HandleRandomAnime(c *gin.Context) { return } - user, _ := c.Get("User") inWatchlist := false - if u, ok := user.(*domain.User); ok { - watchlistMap := h.watchlistMapForIDs(c.Request.Context(), u.ID, []int64{int64(anime.MalID)}) + userID := server.CurrentUserID(c) + if userID != "" { + watchlistMap := h.watchlistMapForIDs(c.Request.Context(), userID, []int64{int64(anime.MalID)}) inWatchlist = watchlistMap[int64(anime.MalID)] } @@ -1026,7 +1002,7 @@ func (h *AnimeHandler) HandleAnimeReviews(c *gin.Context) { return } - user, _ := c.Get("User") + user := server.CurrentUser(c) if c.GetHeader("HX-Request") == "true" && page > 1 { c.HTML(http.StatusOK, "reviews.gohtml", gin.H{ diff --git a/internal/playback/handler/handler.go b/internal/playback/handler/handler.go index 3022dfb..397e790 100644 --- a/internal/playback/handler/handler.go +++ b/internal/playback/handler/handler.go @@ -50,11 +50,8 @@ func (h *PlaybackHandler) HandleWatchPage(c *gin.Context) { ep := c.DefaultQuery("ep", "1") mode := c.DefaultQuery("mode", "sub") - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + user := server.CurrentUser(c) + userID := server.CurrentUserID(c) data, err := h.svc.BuildWatchData(c.Request.Context(), id, []string{}, ep, mode, userID) if err != nil { @@ -64,7 +61,7 @@ func (h *PlaybackHandler) HandleWatchPage(c *gin.Context) { Anime: anime, Episodes: []domain.CanonicalEpisode{}, CurrentPath: c.Request.URL.Path, - User: currentUser(user), + User: user, CurrentEpID: ep, WatchData: domain.WatchData{ Episodes: []domain.CanonicalEpisode{}, @@ -74,7 +71,7 @@ func (h *PlaybackHandler) HandleWatchPage(c *gin.Context) { return } - data.User = currentUser(user) + data.User = user data.CurrentPath = c.Request.URL.Path c.HTML(http.StatusOK, "watch.gohtml", data) @@ -97,11 +94,7 @@ func (h *PlaybackHandler) HandleEpisodeData(c *gin.Context) { mode := c.DefaultQuery("mode", "sub") - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + userID := server.CurrentUserID(c) data, err := h.svc.BuildWatchData(c.Request.Context(), animeID, []string{}, episode, mode, userID) if err != nil { @@ -140,19 +133,8 @@ func (h *PlaybackHandler) HandleEpisodeData(c *gin.Context) { }) } -func currentUser(value any) *domain.User { - if user, ok := value.(*domain.User); ok { - return user - } - return nil -} - func (h *PlaybackHandler) HandleSaveProgress(c *gin.Context) { - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + userID := server.CurrentUserID(c) if userID == "" { // Avoid spamming 500s for anonymous playback; progress is user-scoped. server.RespondHTMLOrJSONError(c, http.StatusUnauthorized, "unauthorized") @@ -188,10 +170,10 @@ func (h *PlaybackHandler) HandleSaveProgress(c *gin.Context) { } func (h *PlaybackHandler) HandleWatchComplete(c *gin.Context) { - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID + userID := server.CurrentUserID(c) + if userID == "" { + server.RespondHTMLOrJSONError(c, http.StatusUnauthorized, "unauthorized") + return } var req struct { @@ -222,13 +204,9 @@ func (h *PlaybackHandler) HandleWatchComplete(c *gin.Context) { } func (h *PlaybackHandler) HandleUpsertSkipSegment(c *gin.Context) { - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + userID := server.CurrentUserID(c) if userID == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "login required"}) + server.RespondHTMLOrJSONError(c, http.StatusUnauthorized, "unauthorized") return } diff --git a/internal/server/user.go b/internal/server/user.go new file mode 100644 index 0000000..0ee1104 --- /dev/null +++ b/internal/server/user.go @@ -0,0 +1,26 @@ +package server + +import ( + "mal/internal/domain" + + "github.com/gin-gonic/gin" +) + +func CurrentUser(c *gin.Context) *domain.User { + if c == nil { + return nil + } + user, _ := c.Get("User") + if u, ok := user.(*domain.User); ok { + return u + } + return nil +} + +func CurrentUserID(c *gin.Context) string { + u := CurrentUser(c) + if u == nil { + return "" + } + return u.ID +} diff --git a/internal/watchlist/handler.go b/internal/watchlist/handler.go index b95235c..3d641cb 100644 --- a/internal/watchlist/handler.go +++ b/internal/watchlist/handler.go @@ -25,11 +25,7 @@ func (h *WatchlistHandler) Register(r *gin.Engine) { } func (h *WatchlistHandler) HandleUpdateWatchlist(c *gin.Context) { - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + userID := server.CurrentUserID(c) var body struct { AnimeID int64 `json:"animeId"` @@ -58,11 +54,7 @@ func (h *WatchlistHandler) HandleUpdateWatchlist(c *gin.Context) { } func (h *WatchlistHandler) HandleDeleteWatchlist(c *gin.Context) { - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + userID := server.CurrentUserID(c) animeID, err := strconv.ParseInt(c.Param("id"), 10, 64) @@ -89,11 +81,7 @@ func (h *WatchlistHandler) HandleDeleteWatchlist(c *gin.Context) { } func (h *WatchlistHandler) HandleDeleteContinueWatching(c *gin.Context) { - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + userID := server.CurrentUserID(c) animeID, err := strconv.ParseInt(c.Param("id"), 10, 64) @@ -120,11 +108,8 @@ func (h *WatchlistHandler) HandleDeleteContinueWatching(c *gin.Context) { } func (h *WatchlistHandler) HandleGetWatchlist(c *gin.Context) { - user, _ := c.Get("User") - userID := "" - if u, ok := user.(*domain.User); ok { - userID = u.ID - } + user := server.CurrentUser(c) + userID := server.CurrentUserID(c) entries, err := h.svc.GetWatchlist(c.Request.Context(), userID) if err != nil {