From d528f6b372be5187dd11ddbfab50317d0ea79cc1 Mon Sep 17 00:00:00 2001 From: mkelvers Date: Thu, 28 May 2026 12:17:19 +0200 Subject: [PATCH] feat: add transactional InTx to playback and watchlist repos --- internal/domain/playback.go | 1 + internal/domain/watchlist.go | 1 + internal/playback/repository/repository.go | 25 ++++++- internal/playback/service/service.go | 53 ++++++++------- internal/watchlist/repository/repository.go | 25 ++++++- internal/watchlist/service/service.go | 72 +++++++++++---------- 6 files changed, 115 insertions(+), 62 deletions(-) diff --git a/internal/domain/playback.go b/internal/domain/playback.go index 0e1953e..67999a8 100644 --- a/internal/domain/playback.go +++ b/internal/domain/playback.go @@ -89,6 +89,7 @@ type EpisodeData struct { } type PlaybackRepository interface { + InTx(ctx context.Context, fn func(ctx context.Context, repo PlaybackRepository) error) error GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error) GetContinueWatchingEntry(ctx context.Context, params db.GetContinueWatchingEntryParams) (db.ContinueWatchingEntry, error) SaveWatchProgress(ctx context.Context, params db.SaveWatchProgressParams) error diff --git a/internal/domain/watchlist.go b/internal/domain/watchlist.go index b29013e..9643515 100644 --- a/internal/domain/watchlist.go +++ b/internal/domain/watchlist.go @@ -21,6 +21,7 @@ type WatchlistService interface { } type WatchlistRepository interface { + InTx(ctx context.Context, fn func(ctx context.Context, repo WatchlistRepository) error) error UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error) GetAnime(ctx context.Context, id int64) (db.Anime, error) UpsertWatchListEntry(ctx context.Context, arg db.UpsertWatchListEntryParams) (db.WatchListEntry, error) diff --git a/internal/playback/repository/repository.go b/internal/playback/repository/repository.go index ba69758..941c7d4 100644 --- a/internal/playback/repository/repository.go +++ b/internal/playback/repository/repository.go @@ -2,16 +2,37 @@ package repository import ( "context" + "database/sql" "mal/internal/db" "mal/internal/domain" ) type playbackRepository struct { + sqlDB *sql.DB queries *db.Queries } -func NewPlaybackRepository(queries *db.Queries) domain.PlaybackRepository { - return &playbackRepository{queries: queries} +func NewPlaybackRepository(sqlDB *sql.DB, queries *db.Queries) domain.PlaybackRepository { + return &playbackRepository{sqlDB: sqlDB, queries: queries} +} + +func (r *playbackRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.PlaybackRepository) error) error { + if r.sqlDB == nil { + return fn(ctx, r) + } + + tx, err := r.sqlDB.BeginTx(ctx, nil) + if err != nil { + return err + } + + txRepo := &playbackRepository{sqlDB: nil, queries: r.queries.WithTx(tx)} + if err := fn(ctx, txRepo); err != nil { + _ = tx.Rollback() + return err + } + + return tx.Commit() } func (r *playbackRepository) GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error) { diff --git a/internal/playback/service/service.go b/internal/playback/service/service.go index 06c9567..eb21848 100644 --- a/internal/playback/service/service.go +++ b/internal/playback/service/service.go @@ -301,38 +301,41 @@ func (s *playbackService) BuildWatchData(ctx context.Context, animeID int, title } func (s *playbackService) CompleteAnime(ctx context.Context, userID string, animeID int64) error { - entry, err := s.repo.GetWatchListEntry(ctx, db.GetWatchListEntryParams{ - UserID: userID, - AnimeID: animeID, - }) - if err != nil || entry.Status != "completed" { - _, err = s.repo.UpsertWatchListEntry(ctx, db.UpsertWatchListEntryParams{ - ID: uuid.New().String(), + if err := s.repo.InTx(ctx, func(txCtx context.Context, repo domain.PlaybackRepository) error { + entry, err := repo.GetWatchListEntry(txCtx, db.GetWatchListEntryParams{ + UserID: userID, + AnimeID: animeID, + }) + if err != nil || entry.Status != "completed" { + _, err = repo.UpsertWatchListEntry(txCtx, db.UpsertWatchListEntryParams{ + ID: uuid.New().String(), + UserID: userID, + AnimeID: animeID, + Status: "completed", + CurrentEpisode: sql.NullInt64{Valid: false}, + CurrentTimeSeconds: 0, + }) + if err != nil { + return err + } + } + + if err := repo.DeleteContinueWatchingEntry(txCtx, db.DeleteContinueWatchingEntryParams{ + UserID: userID, + AnimeID: animeID, + }); err != nil { + return err + } + return repo.SaveWatchProgress(txCtx, db.SaveWatchProgressParams{ UserID: userID, AnimeID: animeID, - Status: "completed", CurrentEpisode: sql.NullInt64{Valid: false}, CurrentTimeSeconds: 0, }) - if err != nil { - return err - } + }); err != nil { + return err } - if err := s.repo.DeleteContinueWatchingEntry(ctx, db.DeleteContinueWatchingEntryParams{ - UserID: userID, - AnimeID: animeID, - }); err != nil { - return err - } - if err := s.repo.SaveWatchProgress(ctx, db.SaveWatchProgressParams{ - UserID: userID, - AnimeID: animeID, - CurrentEpisode: sql.NullInt64{Valid: false}, - CurrentTimeSeconds: 0, - }); err != nil { - return err - } if err := s.auditSvc.Record(ctx, domain.AuditEvent{ UserID: userID, Action: "watch_completed", diff --git a/internal/watchlist/repository/repository.go b/internal/watchlist/repository/repository.go index 0136ee4..d772318 100644 --- a/internal/watchlist/repository/repository.go +++ b/internal/watchlist/repository/repository.go @@ -2,16 +2,37 @@ package repository import ( "context" + "database/sql" "mal/internal/db" "mal/internal/domain" ) type watchlistRepository struct { + sqlDB *sql.DB queries *db.Queries } -func NewWatchlistRepository(queries *db.Queries) domain.WatchlistRepository { - return &watchlistRepository{queries: queries} +func NewWatchlistRepository(sqlDB *sql.DB, queries *db.Queries) domain.WatchlistRepository { + return &watchlistRepository{sqlDB: sqlDB, queries: queries} +} + +func (r *watchlistRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.WatchlistRepository) error) error { + if r.sqlDB == nil { + return fn(ctx, r) + } + + tx, err := r.sqlDB.BeginTx(ctx, nil) + if err != nil { + return err + } + + txRepo := &watchlistRepository{sqlDB: nil, queries: r.queries.WithTx(tx)} + if err := fn(ctx, txRepo); err != nil { + _ = tx.Rollback() + return err + } + + return tx.Commit() } func (r *watchlistRepository) UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error) { diff --git a/internal/watchlist/service/service.go b/internal/watchlist/service/service.go index 7c85e3f..6720e16 100644 --- a/internal/watchlist/service/service.go +++ b/internal/watchlist/service/service.go @@ -20,31 +20,35 @@ func NewWatchlistService(repo domain.WatchlistRepository, jikan *jikan.Client) d } func (s *watchlistService) UpdateEntry(ctx context.Context, userID string, animeID int64, status string) error { - _, err := s.repo.GetAnime(ctx, animeID) - if err != nil { - anime, err := s.jikan.GetAnimeByID(ctx, int(animeID)) - if err != nil { - return err - } - if _, err := s.repo.UpsertAnime(ctx, db.UpsertAnimeParams{ - ID: int64(anime.MalID), - TitleOriginal: anime.Title, - TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""}, - TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""}, - ImageUrl: anime.ImageURL(), - Airing: sql.NullBool{Bool: anime.Airing, Valid: true}, - }); err != nil { - return err - } + anime, fetchErr := s.jikan.GetAnimeByID(ctx, int(animeID)) + if fetchErr != nil { + // still allow status updates for already-known anime rows + anime = jikan.Anime{} } - _, err = s.repo.UpsertWatchListEntry(ctx, db.UpsertWatchListEntryParams{ - ID: uuid.New().String(), - UserID: userID, - AnimeID: animeID, - Status: status, + return s.repo.InTx(ctx, func(txCtx context.Context, repo domain.WatchlistRepository) error { + _, err := repo.GetAnime(txCtx, animeID) + if err != nil && fetchErr == nil { + if _, err := repo.UpsertAnime(txCtx, db.UpsertAnimeParams{ + ID: int64(anime.MalID), + TitleOriginal: anime.Title, + TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""}, + TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""}, + ImageUrl: anime.ImageURL(), + Airing: sql.NullBool{Bool: anime.Airing, Valid: true}, + }); err != nil { + return err + } + } + + _, err = repo.UpsertWatchListEntry(txCtx, db.UpsertWatchListEntryParams{ + ID: uuid.New().String(), + UserID: userID, + AnimeID: animeID, + Status: status, + }) + return err }) - return err } func (s *watchlistService) RemoveEntry(ctx context.Context, userID string, animeID int64) error { @@ -99,16 +103,18 @@ func (s *watchlistService) GetContinueWatchingEntry(ctx context.Context, userID } func (s *watchlistService) DeleteContinueWatching(ctx context.Context, userID string, animeID int64) error { - if err := s.repo.DeleteContinueWatchingEntry(ctx, db.DeleteContinueWatchingEntryParams{ - UserID: userID, - AnimeID: animeID, - }); err != nil { - return err - } - return s.repo.SaveWatchProgress(ctx, db.SaveWatchProgressParams{ - UserID: userID, - AnimeID: animeID, - CurrentEpisode: sql.NullInt64{Valid: false}, - CurrentTimeSeconds: 0, + return s.repo.InTx(ctx, func(txCtx context.Context, repo domain.WatchlistRepository) error { + if err := repo.DeleteContinueWatchingEntry(txCtx, db.DeleteContinueWatchingEntryParams{ + UserID: userID, + AnimeID: animeID, + }); err != nil { + return err + } + return repo.SaveWatchProgress(txCtx, db.SaveWatchProgressParams{ + UserID: userID, + AnimeID: animeID, + CurrentEpisode: sql.NullInt64{Valid: false}, + CurrentTimeSeconds: 0, + }) }) }