From 42247214dd69f33809cc98275ccb402932f71c05 Mon Sep 17 00:00:00 2001 From: mkelvers Date: Fri, 10 Apr 2026 17:22:34 +0200 Subject: [PATCH] security: validate watchlist inputs --- internal/features/watchlist/handler.go | 26 +++++++++++++++++++------- internal/features/watchlist/service.go | 26 ++++++++++++++++++++++---- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/internal/features/watchlist/handler.go b/internal/features/watchlist/handler.go index 2804721..e6c07d3 100644 --- a/internal/features/watchlist/handler.go +++ b/internal/features/watchlist/handler.go @@ -2,7 +2,7 @@ package watchlist import ( "encoding/json" - "fmt" + "errors" "log" "net/http" "strconv" @@ -50,7 +50,7 @@ func (h *Handler) HandleUpdateWatchlist(w http.ResponseWriter, r *http.Request) log.Printf("watchlist add: user_id=%s, anime_id=%s, title=%s", user.ID, animeIDStr, animeTitle) animeID, err := strconv.ParseInt(animeIDStr, 10, 64) - if err != nil { + if err != nil || animeID <= 0 { http.Error(w, "invalid anime ID", http.StatusBadRequest) return } @@ -66,7 +66,12 @@ func (h *Handler) HandleUpdateWatchlist(w http.ResponseWriter, r *http.Request) } if err := h.svc.AddEntry(r.Context(), user.ID, req); err != nil { - http.Error(w, fmt.Sprintf("failed to update watchlist: %v", err), http.StatusInternalServerError) + if errors.Is(err, ErrInvalidAnimeID) || errors.Is(err, ErrInvalidStatus) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + log.Printf("watchlist add failed: user_id=%s anime_id=%d err=%v", user.ID, animeID, err) + http.Error(w, "failed to update watchlist", http.StatusInternalServerError) return } @@ -88,14 +93,19 @@ func (h *Handler) HandleDeleteWatchlist(w http.ResponseWriter, r *http.Request) path := r.URL.Path[len("/api/watchlist/"):] animeID, err := strconv.ParseInt(path, 10, 64) - if err != nil { + if err != nil || animeID <= 0 { http.Error(w, "invalid anime ID", http.StatusBadRequest) return } anime, err := h.svc.RemoveEntry(r.Context(), user.ID, animeID) if err != nil { - http.Error(w, fmt.Sprintf("failed to delete from watchlist: %v", err), http.StatusInternalServerError) + if errors.Is(err, ErrInvalidAnimeID) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + log.Printf("watchlist delete failed: user_id=%s anime_id=%d err=%v", user.ID, animeID, err) + http.Error(w, "failed to delete from watchlist", http.StatusInternalServerError) return } @@ -150,7 +160,8 @@ func (h *Handler) HandleGetWatchlist(w http.ResponseWriter, r *http.Request) { entries, err := h.svc.GetUserWatchlist(r.Context(), user.ID) if err != nil { - http.Error(w, fmt.Sprintf("failed to fetch watchlist: %v", err), http.StatusInternalServerError) + log.Printf("watchlist fetch failed: user_id=%s err=%v", user.ID, err) + http.Error(w, "failed to fetch watchlist", http.StatusInternalServerError) return } @@ -193,7 +204,8 @@ func (h *Handler) HandleExportWatchlist(w http.ResponseWriter, r *http.Request) export, err := h.svc.Export(r.Context(), user.ID) if err != nil { - http.Error(w, fmt.Sprintf("failed to export: %v", err), http.StatusInternalServerError) + log.Printf("watchlist export failed: user_id=%s err=%v", user.ID, err) + http.Error(w, "failed to export", http.StatusInternalServerError) return } diff --git a/internal/features/watchlist/service.go b/internal/features/watchlist/service.go index 718d701..f8e172c 100644 --- a/internal/features/watchlist/service.go +++ b/internal/features/watchlist/service.go @@ -3,6 +3,7 @@ package watchlist import ( "context" "database/sql" + "errors" "fmt" "time" @@ -15,6 +16,19 @@ type Service struct { db database.Querier } +var ( + ErrInvalidAnimeID = errors.New("invalid anime ID") + ErrInvalidStatus = errors.New("invalid watchlist status") +) + +var validStatuses = map[string]struct{}{ + "watching": {}, + "completed": {}, + "on_hold": {}, + "dropped": {}, + "plan_to_watch": {}, +} + func NewService(db database.Querier) *Service { return &Service{db: db} } @@ -30,8 +44,12 @@ type AddRequest struct { } func (s *Service) AddEntry(ctx context.Context, userID string, req AddRequest) error { - if req.AnimeID == 0 { - return fmt.Errorf("invalid anime ID") + if req.AnimeID <= 0 { + return ErrInvalidAnimeID + } + + if _, ok := validStatuses[req.Status]; !ok { + return ErrInvalidStatus } _, err := s.db.UpsertAnime(ctx, database.UpsertAnimeParams{ @@ -62,8 +80,8 @@ func (s *Service) AddEntry(ctx context.Context, userID string, req AddRequest) e } func (s *Service) RemoveEntry(ctx context.Context, userID string, animeID int64) (database.Anime, error) { - if animeID == 0 { - return database.Anime{}, fmt.Errorf("invalid anime ID") + if animeID <= 0 { + return database.Anime{}, ErrInvalidAnimeID } anime, err := s.db.GetAnime(ctx, animeID)