refactor: extract CurrentUser and CurrentUserID helpers

This commit is contained in:
2026-05-28 12:51:11 +02:00
committed by Milas Holsting
parent 8454d01b09
commit bf28c307c9
4 changed files with 67 additions and 102 deletions

View File

@@ -187,7 +187,7 @@ func (h *AnimeHandler) HandleProducers(c *gin.Context) {
} }
func (h *AnimeHandler) HandleCatalog(c *gin.Context) { func (h *AnimeHandler) HandleCatalog(c *gin.Context) {
user, _ := c.Get("User") user := server.CurrentUser(c)
c.HTML(http.StatusOK, "index.gohtml", gin.H{ c.HTML(http.StatusOK, "index.gohtml", gin.H{
"CurrentPath": "/", "CurrentPath": "/",
@@ -209,11 +209,7 @@ func (h *AnimeHandler) HandleCatalogContinue(c *gin.Context) {
} }
func (h *AnimeHandler) renderCatalogSection(c *gin.Context, section string) { func (h *AnimeHandler) renderCatalogSection(c *gin.Context, section string) {
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := ""
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
data, err := h.svc.GetCatalogSection(c.Request.Context(), userID, section) data, err := h.svc.GetCatalogSection(c.Request.Context(), userID, section)
if err != nil { if err != nil {
observability.Warn( observability.Warn(
@@ -239,7 +235,7 @@ func (h *AnimeHandler) renderCatalogSection(c *gin.Context, section string) {
} }
func (h *AnimeHandler) HandleDiscover(c *gin.Context) { func (h *AnimeHandler) HandleDiscover(c *gin.Context) {
user, _ := c.Get("User") user := server.CurrentUser(c)
c.HTML(http.StatusOK, "discover.gohtml", gin.H{ c.HTML(http.StatusOK, "discover.gohtml", gin.H{
"CurrentPath": "/discover", "CurrentPath": "/discover",
"User": user, "User": user,
@@ -259,11 +255,7 @@ func (h *AnimeHandler) HandleDiscoverTop(c *gin.Context) {
} }
func (h *AnimeHandler) HandleDiscoverForYou(c *gin.Context) { func (h *AnimeHandler) HandleDiscoverForYou(c *gin.Context) {
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := ""
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
data, err := h.svc.GetDiscoverForYou(c.Request.Context(), userID) data, err := h.svc.GetDiscoverForYou(c.Request.Context(), userID)
if err != nil { if err != nil {
@@ -289,11 +281,7 @@ func (h *AnimeHandler) HandleDiscoverForYou(c *gin.Context) {
} }
func (h *AnimeHandler) renderDiscoverSection(c *gin.Context, section string) { func (h *AnimeHandler) renderDiscoverSection(c *gin.Context, section string) {
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := ""
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
data, err := h.svc.GetDiscoverSection(c.Request.Context(), userID, section) data, err := h.svc.GetDiscoverSection(c.Request.Context(), userID, section)
if err != nil { if err != nil {
observability.Warn( observability.Warn(
@@ -319,7 +307,7 @@ func (h *AnimeHandler) renderDiscoverSection(c *gin.Context, section string) {
} }
func (h *AnimeHandler) HandleSchedule(c *gin.Context) { func (h *AnimeHandler) HandleSchedule(c *gin.Context) {
user, _ := c.Get("User") user := server.CurrentUser(c)
year, week := parseYearWeek(c) year, week := parseYearWeek(c)
c.HTML(http.StatusOK, "schedule.gohtml", gin.H{ c.HTML(http.StatusOK, "schedule.gohtml", gin.H{
"CurrentPath": "/schedule", "CurrentPath": "/schedule",
@@ -515,11 +503,8 @@ func (h *AnimeHandler) HandleBrowse(c *gin.Context) {
return return
} }
user, _ := c.Get("User") user := server.CurrentUser(c)
userID := "" userID := server.CurrentUserID(c)
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
animes := wrapAnimes(res.Animes) animes := wrapAnimes(res.Animes)
watchlistMap := h.watchlistMapForAnimes(c.Request.Context(), userID, animes) watchlistMap := h.watchlistMapForAnimes(c.Request.Context(), userID, animes)
@@ -655,19 +640,19 @@ func (h *AnimeHandler) HandleAnimeDetails(c *gin.Context) {
return return
} }
user, _ := c.Get("User") user := server.CurrentUser(c)
status := "" status := ""
var watchlistIDs []int64 var watchlistIDs []int64
ep := 0 ep := 0
var cwSeconds float64 var cwSeconds float64
if u, ok := user.(*domain.User); ok { if user != nil {
entry, err := h.watchlistSvc.GetWatchListEntry(c.Request.Context(), u.ID, int64(id)) entry, err := h.watchlistSvc.GetWatchListEntry(c.Request.Context(), user.ID, int64(id))
if err == nil { if err == nil {
status = entry.Status status = entry.Status
watchlistIDs = []int64{entry.AnimeID} watchlistIDs = []int64{entry.AnimeID}
} }
cwEntry, err := h.watchlistSvc.GetContinueWatchingEntry(c.Request.Context(), u.ID, int64(id)) cwEntry, err := h.watchlistSvc.GetContinueWatchingEntry(c.Request.Context(), user.ID, int64(id))
if err == nil && cwEntry.CurrentEpisode.Valid { if err == nil && cwEntry.CurrentEpisode.Valid {
ep = int(cwEntry.CurrentEpisode.Int64) ep = int(cwEntry.CurrentEpisode.Int64)
cwSeconds = cwEntry.CurrentTimeSeconds cwSeconds = cwEntry.CurrentTimeSeconds
@@ -692,11 +677,7 @@ func (h *AnimeHandler) HandleHTMLWatchOrder(c *gin.Context) {
return return
} }
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := ""
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
relationsCtx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second) relationsCtx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
defer cancel() defer cancel()
@@ -745,11 +726,7 @@ func (h *AnimeHandler) HandleQuickSearch(c *gin.Context) {
return return
} }
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := ""
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
animes := wrapAnimes(res.Animes) animes := wrapAnimes(res.Animes)
watchlistMap := h.watchlistMapForAnimes(c.Request.Context(), userID, animes) watchlistMap := h.watchlistMapForAnimes(c.Request.Context(), userID, animes)
@@ -787,9 +764,8 @@ type commandPaletteItem struct {
} }
func (h *AnimeHandler) HandleCommandPalette(c *gin.Context) { func (h *AnimeHandler) HandleCommandPalette(c *gin.Context) {
user, _ := c.Get("User") user := server.CurrentUser(c)
u, ok := user.(*domain.User) if user == nil {
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return return
} }
@@ -812,15 +788,15 @@ func (h *AnimeHandler) HandleCommandPalette(c *gin.Context) {
} }
items = append(items, h.commandPaletteNavigationItems(query)...) items = append(items, h.commandPaletteNavigationItems(query)...)
items = append(items, h.commandPaletteContinueItems(c, u.ID, query)...) items = append(items, h.commandPaletteContinueItems(c, user.ID, query)...)
items = append(items, h.commandPalettePersonalItems(c, u.ID, query)...) items = append(items, h.commandPalettePersonalItems(c, user.ID, query)...)
c.JSON(http.StatusOK, items) c.JSON(http.StatusOK, items)
return return
} }
items = append(items, h.commandPaletteContinueItems(c, u.ID, query)...) items = append(items, h.commandPaletteContinueItems(c, user.ID, query)...)
items = append(items, h.commandPaletteNavigationItems(query)...) items = append(items, h.commandPaletteNavigationItems(query)...)
items = append(items, h.commandPalettePersonalItems(c, u.ID, query)...) items = append(items, h.commandPalettePersonalItems(c, user.ID, query)...)
c.JSON(http.StatusOK, items) c.JSON(http.StatusOK, items)
} }
@@ -983,10 +959,10 @@ func (h *AnimeHandler) HandleRandomAnime(c *gin.Context) {
return return
} }
user, _ := c.Get("User")
inWatchlist := false inWatchlist := false
if u, ok := user.(*domain.User); ok { userID := server.CurrentUserID(c)
watchlistMap := h.watchlistMapForIDs(c.Request.Context(), u.ID, []int64{int64(anime.MalID)}) if userID != "" {
watchlistMap := h.watchlistMapForIDs(c.Request.Context(), userID, []int64{int64(anime.MalID)})
inWatchlist = watchlistMap[int64(anime.MalID)] inWatchlist = watchlistMap[int64(anime.MalID)]
} }
@@ -1026,7 +1002,7 @@ func (h *AnimeHandler) HandleAnimeReviews(c *gin.Context) {
return return
} }
user, _ := c.Get("User") user := server.CurrentUser(c)
if c.GetHeader("HX-Request") == "true" && page > 1 { if c.GetHeader("HX-Request") == "true" && page > 1 {
c.HTML(http.StatusOK, "reviews.gohtml", gin.H{ c.HTML(http.StatusOK, "reviews.gohtml", gin.H{

View File

@@ -50,11 +50,8 @@ func (h *PlaybackHandler) HandleWatchPage(c *gin.Context) {
ep := c.DefaultQuery("ep", "1") ep := c.DefaultQuery("ep", "1")
mode := c.DefaultQuery("mode", "sub") mode := c.DefaultQuery("mode", "sub")
user, _ := c.Get("User") user := server.CurrentUser(c)
userID := "" userID := server.CurrentUserID(c)
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
data, err := h.svc.BuildWatchData(c.Request.Context(), id, []string{}, ep, mode, userID) data, err := h.svc.BuildWatchData(c.Request.Context(), id, []string{}, ep, mode, userID)
if err != nil { if err != nil {
@@ -64,7 +61,7 @@ func (h *PlaybackHandler) HandleWatchPage(c *gin.Context) {
Anime: anime, Anime: anime,
Episodes: []domain.CanonicalEpisode{}, Episodes: []domain.CanonicalEpisode{},
CurrentPath: c.Request.URL.Path, CurrentPath: c.Request.URL.Path,
User: currentUser(user), User: user,
CurrentEpID: ep, CurrentEpID: ep,
WatchData: domain.WatchData{ WatchData: domain.WatchData{
Episodes: []domain.CanonicalEpisode{}, Episodes: []domain.CanonicalEpisode{},
@@ -74,7 +71,7 @@ func (h *PlaybackHandler) HandleWatchPage(c *gin.Context) {
return return
} }
data.User = currentUser(user) data.User = user
data.CurrentPath = c.Request.URL.Path data.CurrentPath = c.Request.URL.Path
c.HTML(http.StatusOK, "watch.gohtml", data) c.HTML(http.StatusOK, "watch.gohtml", data)
@@ -97,11 +94,7 @@ func (h *PlaybackHandler) HandleEpisodeData(c *gin.Context) {
mode := c.DefaultQuery("mode", "sub") mode := c.DefaultQuery("mode", "sub")
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := ""
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
data, err := h.svc.BuildWatchData(c.Request.Context(), animeID, []string{}, episode, mode, userID) data, err := h.svc.BuildWatchData(c.Request.Context(), animeID, []string{}, episode, mode, userID)
if err != nil { if err != nil {
@@ -140,19 +133,8 @@ func (h *PlaybackHandler) HandleEpisodeData(c *gin.Context) {
}) })
} }
func currentUser(value any) *domain.User {
if user, ok := value.(*domain.User); ok {
return user
}
return nil
}
func (h *PlaybackHandler) HandleSaveProgress(c *gin.Context) { func (h *PlaybackHandler) HandleSaveProgress(c *gin.Context) {
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := ""
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
if userID == "" { if userID == "" {
// Avoid spamming 500s for anonymous playback; progress is user-scoped. // Avoid spamming 500s for anonymous playback; progress is user-scoped.
server.RespondHTMLOrJSONError(c, http.StatusUnauthorized, "unauthorized") server.RespondHTMLOrJSONError(c, http.StatusUnauthorized, "unauthorized")
@@ -188,10 +170,10 @@ func (h *PlaybackHandler) HandleSaveProgress(c *gin.Context) {
} }
func (h *PlaybackHandler) HandleWatchComplete(c *gin.Context) { func (h *PlaybackHandler) HandleWatchComplete(c *gin.Context) {
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := "" if userID == "" {
if u, ok := user.(*domain.User); ok { server.RespondHTMLOrJSONError(c, http.StatusUnauthorized, "unauthorized")
userID = u.ID return
} }
var req struct { var req struct {
@@ -222,13 +204,9 @@ func (h *PlaybackHandler) HandleWatchComplete(c *gin.Context) {
} }
func (h *PlaybackHandler) HandleUpsertSkipSegment(c *gin.Context) { func (h *PlaybackHandler) HandleUpsertSkipSegment(c *gin.Context) {
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := ""
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
if userID == "" { if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "login required"}) server.RespondHTMLOrJSONError(c, http.StatusUnauthorized, "unauthorized")
return return
} }

26
internal/server/user.go Normal file
View File

@@ -0,0 +1,26 @@
package server
import (
"mal/internal/domain"
"github.com/gin-gonic/gin"
)
func CurrentUser(c *gin.Context) *domain.User {
if c == nil {
return nil
}
user, _ := c.Get("User")
if u, ok := user.(*domain.User); ok {
return u
}
return nil
}
func CurrentUserID(c *gin.Context) string {
u := CurrentUser(c)
if u == nil {
return ""
}
return u.ID
}

View File

@@ -25,11 +25,7 @@ func (h *WatchlistHandler) Register(r *gin.Engine) {
} }
func (h *WatchlistHandler) HandleUpdateWatchlist(c *gin.Context) { func (h *WatchlistHandler) HandleUpdateWatchlist(c *gin.Context) {
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := ""
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
var body struct { var body struct {
AnimeID int64 `json:"animeId"` AnimeID int64 `json:"animeId"`
@@ -58,11 +54,7 @@ func (h *WatchlistHandler) HandleUpdateWatchlist(c *gin.Context) {
} }
func (h *WatchlistHandler) HandleDeleteWatchlist(c *gin.Context) { func (h *WatchlistHandler) HandleDeleteWatchlist(c *gin.Context) {
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := ""
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
animeID, err := strconv.ParseInt(c.Param("id"), 10, 64) animeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -89,11 +81,7 @@ func (h *WatchlistHandler) HandleDeleteWatchlist(c *gin.Context) {
} }
func (h *WatchlistHandler) HandleDeleteContinueWatching(c *gin.Context) { func (h *WatchlistHandler) HandleDeleteContinueWatching(c *gin.Context) {
user, _ := c.Get("User") userID := server.CurrentUserID(c)
userID := ""
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
animeID, err := strconv.ParseInt(c.Param("id"), 10, 64) animeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -120,11 +108,8 @@ func (h *WatchlistHandler) HandleDeleteContinueWatching(c *gin.Context) {
} }
func (h *WatchlistHandler) HandleGetWatchlist(c *gin.Context) { func (h *WatchlistHandler) HandleGetWatchlist(c *gin.Context) {
user, _ := c.Get("User") user := server.CurrentUser(c)
userID := "" userID := server.CurrentUserID(c)
if u, ok := user.(*domain.User); ok {
userID = u.ID
}
entries, err := h.svc.GetWatchlist(c.Request.Context(), userID) entries, err := h.svc.GetWatchlist(c.Request.Context(), userID)
if err != nil { if err != nil {