refactor(watchlist): sortEntries, null checks, and beginTx
This commit is contained in:
@@ -1,17 +1,43 @@
|
||||
package database
|
||||
|
||||
import "database/sql"
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func NullStringOr(n sql.NullString, fallback string) string {
|
||||
if n.Valid && n.String != "" {
|
||||
return n.String
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func DisplayTitle(titleEnglish, titleJapanese sql.NullString, titleOriginal string) string {
|
||||
if titleEnglish.Valid && titleEnglish.String != "" {
|
||||
return titleEnglish.String
|
||||
}
|
||||
if titleJapanese.Valid && titleJapanese.String != "" {
|
||||
return titleJapanese.String
|
||||
}
|
||||
return titleOriginal
|
||||
return NullStringOr(titleEnglish, NullStringOr(titleJapanese, titleOriginal))
|
||||
}
|
||||
|
||||
func (r GetUserWatchListRow) DisplayTitle() string {
|
||||
return DisplayTitle(r.TitleEnglish, r.TitleJapanese, r.TitleOriginal)
|
||||
}
|
||||
|
||||
func BoolPtr(b sql.NullBool) *bool {
|
||||
if !b.Valid {
|
||||
return nil
|
||||
}
|
||||
return &b.Bool
|
||||
}
|
||||
|
||||
func BeginTx(ctx context.Context, db *sql.DB) (*Queries, *sql.Tx, error) {
|
||||
if db == nil {
|
||||
return nil, nil, errors.New("database unavailable")
|
||||
}
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
return New(tx), tx, nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"sort"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"mal/internal/database"
|
||||
@@ -122,20 +122,13 @@ func (h *Handler) HandleDeleteWatchlist(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
titleEnglish := ""
|
||||
if anime.TitleEnglish.Valid {
|
||||
titleEnglish = anime.TitleEnglish.String
|
||||
}
|
||||
titleJapanese := ""
|
||||
if anime.TitleJapanese.Valid {
|
||||
titleJapanese = anime.TitleJapanese.String
|
||||
}
|
||||
title := database.DisplayTitle(anime.TitleEnglish, anime.TitleJapanese, anime.TitleOriginal)
|
||||
airing := false
|
||||
if anime.Airing.Valid {
|
||||
airing = anime.Airing.Bool
|
||||
}
|
||||
|
||||
templates.WatchlistDropdown(int(animeID), anime.TitleOriginal, titleEnglish, titleJapanese, anime.ImageUrl, "", airing).Render(r.Context(), w)
|
||||
templates.WatchlistDropdown(int(animeID), anime.TitleOriginal, title, "", anime.ImageUrl, "", airing).Render(r.Context(), w)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleGetWatchlist(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -301,26 +294,34 @@ func (h *Handler) HandleImportWatchlist(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func (h *Handler) sortEntries(entries []database.GetUserWatchListRow, sortBy, sortOrder string) {
|
||||
var less func(int, int) bool
|
||||
isAsc := sortOrder == "asc"
|
||||
|
||||
switch sortBy {
|
||||
case "title":
|
||||
less = func(i, j int) bool {
|
||||
cmp := entries[i].TitleOriginal < entries[j].TitleOriginal
|
||||
if sortOrder == "asc" {
|
||||
return cmp
|
||||
slices.SortFunc(entries, func(a, b database.GetUserWatchListRow) int {
|
||||
if a.TitleOriginal < b.TitleOriginal {
|
||||
return -1
|
||||
}
|
||||
return !cmp
|
||||
if a.TitleOriginal > b.TitleOriginal {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
if !isAsc {
|
||||
slices.Reverse(entries)
|
||||
}
|
||||
default: // "date"
|
||||
less = func(i, j int) bool {
|
||||
cmp := entries[i].UpdatedAt.After(entries[j].UpdatedAt)
|
||||
if sortOrder == "asc" {
|
||||
return !cmp
|
||||
case "date":
|
||||
slices.SortFunc(entries, func(a, b database.GetUserWatchListRow) int {
|
||||
if a.UpdatedAt.After(b.UpdatedAt) {
|
||||
return -1
|
||||
}
|
||||
return cmp
|
||||
if a.UpdatedAt.Before(b.UpdatedAt) {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
if !isAsc {
|
||||
slices.Reverse(entries)
|
||||
}
|
||||
}
|
||||
|
||||
sort.SliceStable(entries, less)
|
||||
}
|
||||
|
||||
@@ -152,13 +152,12 @@ func (s *Service) DeleteContinueWatching(ctx context.Context, userID string, ani
|
||||
return s.db.SaveWatchProgress(ctx, clearProgress)
|
||||
}
|
||||
|
||||
tx, err := s.sqlDB.BeginTx(ctx, nil)
|
||||
txQueries, tx, err := database.BeginTx(ctx, s.sqlDB)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
txQueries := database.New(tx)
|
||||
if err := txQueries.DeleteContinueWatchingEntry(ctx, params); err != nil {
|
||||
return fmt.Errorf("failed to delete continue watching entry: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user