diff --git a/internal/database/helpers.go b/internal/database/helpers.go index e80ed12..25e2bee 100644 --- a/internal/database/helpers.go +++ b/internal/database/helpers.go @@ -1,17 +1,43 @@ package database -import "database/sql" +import ( + "context" + "database/sql" + "errors" + "fmt" +) + +func NullStringOr(n sql.NullString, fallback string) string { + if n.Valid && n.String != "" { + return n.String + } + return fallback +} func DisplayTitle(titleEnglish, titleJapanese sql.NullString, titleOriginal string) string { - if titleEnglish.Valid && titleEnglish.String != "" { - return titleEnglish.String - } - if titleJapanese.Valid && titleJapanese.String != "" { - return titleJapanese.String - } - return titleOriginal + return NullStringOr(titleEnglish, NullStringOr(titleJapanese, titleOriginal)) } func (r GetUserWatchListRow) DisplayTitle() string { return DisplayTitle(r.TitleEnglish, r.TitleJapanese, r.TitleOriginal) } + +func BoolPtr(b sql.NullBool) *bool { + if !b.Valid { + return nil + } + return &b.Bool +} + +func BeginTx(ctx context.Context, db *sql.DB) (*Queries, *sql.Tx, error) { + if db == nil { + return nil, nil, errors.New("database unavailable") + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to begin transaction: %w", err) + } + + return New(tx), tx, nil +} diff --git a/internal/features/watchlist/handler.go b/internal/features/watchlist/handler.go index a4d89ff..4cf8d29 100644 --- a/internal/features/watchlist/handler.go +++ b/internal/features/watchlist/handler.go @@ -5,7 +5,7 @@ import ( "errors" "log" "net/http" - "sort" + "slices" "strconv" "mal/internal/database" @@ -122,20 +122,13 @@ func (h *Handler) HandleDeleteWatchlist(w http.ResponseWriter, r *http.Request) return } - titleEnglish := "" - if anime.TitleEnglish.Valid { - titleEnglish = anime.TitleEnglish.String - } - titleJapanese := "" - if anime.TitleJapanese.Valid { - titleJapanese = anime.TitleJapanese.String - } + title := database.DisplayTitle(anime.TitleEnglish, anime.TitleJapanese, anime.TitleOriginal) airing := false if anime.Airing.Valid { airing = anime.Airing.Bool } - templates.WatchlistDropdown(int(animeID), anime.TitleOriginal, titleEnglish, titleJapanese, anime.ImageUrl, "", airing).Render(r.Context(), w) + templates.WatchlistDropdown(int(animeID), anime.TitleOriginal, title, "", anime.ImageUrl, "", airing).Render(r.Context(), w) } func (h *Handler) HandleGetWatchlist(w http.ResponseWriter, r *http.Request) { @@ -301,26 +294,34 @@ func (h *Handler) HandleImportWatchlist(w http.ResponseWriter, r *http.Request) } func (h *Handler) sortEntries(entries []database.GetUserWatchListRow, sortBy, sortOrder string) { - var less func(int, int) bool + isAsc := sortOrder == "asc" switch sortBy { case "title": - less = func(i, j int) bool { - cmp := entries[i].TitleOriginal < entries[j].TitleOriginal - if sortOrder == "asc" { - return cmp + slices.SortFunc(entries, func(a, b database.GetUserWatchListRow) int { + if a.TitleOriginal < b.TitleOriginal { + return -1 } - return !cmp + if a.TitleOriginal > b.TitleOriginal { + return 1 + } + return 0 + }) + if !isAsc { + slices.Reverse(entries) } - default: // "date" - less = func(i, j int) bool { - cmp := entries[i].UpdatedAt.After(entries[j].UpdatedAt) - if sortOrder == "asc" { - return !cmp + case "date": + slices.SortFunc(entries, func(a, b database.GetUserWatchListRow) int { + if a.UpdatedAt.After(b.UpdatedAt) { + return -1 } - return cmp + if a.UpdatedAt.Before(b.UpdatedAt) { + return 1 + } + return 0 + }) + if !isAsc { + slices.Reverse(entries) } } - - sort.SliceStable(entries, less) } diff --git a/internal/features/watchlist/service.go b/internal/features/watchlist/service.go index 66e977e..855c36e 100644 --- a/internal/features/watchlist/service.go +++ b/internal/features/watchlist/service.go @@ -152,13 +152,12 @@ func (s *Service) DeleteContinueWatching(ctx context.Context, userID string, ani return s.db.SaveWatchProgress(ctx, clearProgress) } - tx, err := s.sqlDB.BeginTx(ctx, nil) + txQueries, tx, err := database.BeginTx(ctx, s.sqlDB) if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } defer tx.Rollback() - txQueries := database.New(tx) if err := txQueries.DeleteContinueWatchingEntry(ctx, params); err != nil { return fmt.Errorf("failed to delete continue watching entry: %w", err) }