feat: add transactional InTx to playback and watchlist repos

This commit is contained in:
2026-05-28 12:17:19 +02:00
parent 4329bce4a7
commit dd4c7f80f3
6 changed files with 115 additions and 62 deletions

View File

@@ -89,6 +89,7 @@ type EpisodeData struct {
} }
type PlaybackRepository interface { type PlaybackRepository interface {
InTx(ctx context.Context, fn func(ctx context.Context, repo PlaybackRepository) error) error
GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error) GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error)
GetContinueWatchingEntry(ctx context.Context, params db.GetContinueWatchingEntryParams) (db.ContinueWatchingEntry, error) GetContinueWatchingEntry(ctx context.Context, params db.GetContinueWatchingEntryParams) (db.ContinueWatchingEntry, error)
SaveWatchProgress(ctx context.Context, params db.SaveWatchProgressParams) error SaveWatchProgress(ctx context.Context, params db.SaveWatchProgressParams) error

View File

@@ -21,6 +21,7 @@ type WatchlistService interface {
} }
type WatchlistRepository interface { type WatchlistRepository interface {
InTx(ctx context.Context, fn func(ctx context.Context, repo WatchlistRepository) error) error
UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error) UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error)
GetAnime(ctx context.Context, id int64) (db.Anime, error) GetAnime(ctx context.Context, id int64) (db.Anime, error)
UpsertWatchListEntry(ctx context.Context, arg db.UpsertWatchListEntryParams) (db.WatchListEntry, error) UpsertWatchListEntry(ctx context.Context, arg db.UpsertWatchListEntryParams) (db.WatchListEntry, error)

View File

@@ -2,16 +2,37 @@ package repository
import ( import (
"context" "context"
"database/sql"
"mal/internal/db" "mal/internal/db"
"mal/internal/domain" "mal/internal/domain"
) )
type playbackRepository struct { type playbackRepository struct {
sqlDB *sql.DB
queries *db.Queries queries *db.Queries
} }
func NewPlaybackRepository(queries *db.Queries) domain.PlaybackRepository { func NewPlaybackRepository(sqlDB *sql.DB, queries *db.Queries) domain.PlaybackRepository {
return &playbackRepository{queries: queries} return &playbackRepository{sqlDB: sqlDB, queries: queries}
}
func (r *playbackRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.PlaybackRepository) error) error {
if r.sqlDB == nil {
return fn(ctx, r)
}
tx, err := r.sqlDB.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := &playbackRepository{sqlDB: nil, queries: r.queries.WithTx(tx)}
if err := fn(ctx, txRepo); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
} }
func (r *playbackRepository) GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error) { func (r *playbackRepository) GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error) {

View File

@@ -301,38 +301,41 @@ func (s *playbackService) BuildWatchData(ctx context.Context, animeID int, title
} }
func (s *playbackService) CompleteAnime(ctx context.Context, userID string, animeID int64) error { func (s *playbackService) CompleteAnime(ctx context.Context, userID string, animeID int64) error {
entry, err := s.repo.GetWatchListEntry(ctx, db.GetWatchListEntryParams{ if err := s.repo.InTx(ctx, func(txCtx context.Context, repo domain.PlaybackRepository) error {
UserID: userID, entry, err := repo.GetWatchListEntry(txCtx, db.GetWatchListEntryParams{
AnimeID: animeID, UserID: userID,
}) AnimeID: animeID,
if err != nil || entry.Status != "completed" { })
_, err = s.repo.UpsertWatchListEntry(ctx, db.UpsertWatchListEntryParams{ if err != nil || entry.Status != "completed" {
ID: uuid.New().String(), _, err = repo.UpsertWatchListEntry(txCtx, db.UpsertWatchListEntryParams{
ID: uuid.New().String(),
UserID: userID,
AnimeID: animeID,
Status: "completed",
CurrentEpisode: sql.NullInt64{Valid: false},
CurrentTimeSeconds: 0,
})
if err != nil {
return err
}
}
if err := repo.DeleteContinueWatchingEntry(txCtx, db.DeleteContinueWatchingEntryParams{
UserID: userID,
AnimeID: animeID,
}); err != nil {
return err
}
return repo.SaveWatchProgress(txCtx, db.SaveWatchProgressParams{
UserID: userID, UserID: userID,
AnimeID: animeID, AnimeID: animeID,
Status: "completed",
CurrentEpisode: sql.NullInt64{Valid: false}, CurrentEpisode: sql.NullInt64{Valid: false},
CurrentTimeSeconds: 0, CurrentTimeSeconds: 0,
}) })
if err != nil { }); err != nil {
return err return err
}
} }
if err := s.repo.DeleteContinueWatchingEntry(ctx, db.DeleteContinueWatchingEntryParams{
UserID: userID,
AnimeID: animeID,
}); err != nil {
return err
}
if err := s.repo.SaveWatchProgress(ctx, db.SaveWatchProgressParams{
UserID: userID,
AnimeID: animeID,
CurrentEpisode: sql.NullInt64{Valid: false},
CurrentTimeSeconds: 0,
}); err != nil {
return err
}
if err := s.auditSvc.Record(ctx, domain.AuditEvent{ if err := s.auditSvc.Record(ctx, domain.AuditEvent{
UserID: userID, UserID: userID,
Action: "watch_completed", Action: "watch_completed",

View File

@@ -2,16 +2,37 @@ package repository
import ( import (
"context" "context"
"database/sql"
"mal/internal/db" "mal/internal/db"
"mal/internal/domain" "mal/internal/domain"
) )
type watchlistRepository struct { type watchlistRepository struct {
sqlDB *sql.DB
queries *db.Queries queries *db.Queries
} }
func NewWatchlistRepository(queries *db.Queries) domain.WatchlistRepository { func NewWatchlistRepository(sqlDB *sql.DB, queries *db.Queries) domain.WatchlistRepository {
return &watchlistRepository{queries: queries} return &watchlistRepository{sqlDB: sqlDB, queries: queries}
}
func (r *watchlistRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.WatchlistRepository) error) error {
if r.sqlDB == nil {
return fn(ctx, r)
}
tx, err := r.sqlDB.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := &watchlistRepository{sqlDB: nil, queries: r.queries.WithTx(tx)}
if err := fn(ctx, txRepo); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
} }
func (r *watchlistRepository) UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error) { func (r *watchlistRepository) UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error) {

View File

@@ -20,31 +20,35 @@ func NewWatchlistService(repo domain.WatchlistRepository, jikan *jikan.Client) d
} }
func (s *watchlistService) UpdateEntry(ctx context.Context, userID string, animeID int64, status string) error { func (s *watchlistService) UpdateEntry(ctx context.Context, userID string, animeID int64, status string) error {
_, err := s.repo.GetAnime(ctx, animeID) anime, fetchErr := s.jikan.GetAnimeByID(ctx, int(animeID))
if err != nil { if fetchErr != nil {
anime, err := s.jikan.GetAnimeByID(ctx, int(animeID)) // still allow status updates for already-known anime rows
if err != nil { anime = jikan.Anime{}
return err
}
if _, err := s.repo.UpsertAnime(ctx, db.UpsertAnimeParams{
ID: int64(anime.MalID),
TitleOriginal: anime.Title,
TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""},
TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""},
ImageUrl: anime.ImageURL(),
Airing: sql.NullBool{Bool: anime.Airing, Valid: true},
}); err != nil {
return err
}
} }
_, err = s.repo.UpsertWatchListEntry(ctx, db.UpsertWatchListEntryParams{ return s.repo.InTx(ctx, func(txCtx context.Context, repo domain.WatchlistRepository) error {
ID: uuid.New().String(), _, err := repo.GetAnime(txCtx, animeID)
UserID: userID, if err != nil && fetchErr == nil {
AnimeID: animeID, if _, err := repo.UpsertAnime(txCtx, db.UpsertAnimeParams{
Status: status, ID: int64(anime.MalID),
TitleOriginal: anime.Title,
TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""},
TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""},
ImageUrl: anime.ImageURL(),
Airing: sql.NullBool{Bool: anime.Airing, Valid: true},
}); err != nil {
return err
}
}
_, err = repo.UpsertWatchListEntry(txCtx, db.UpsertWatchListEntryParams{
ID: uuid.New().String(),
UserID: userID,
AnimeID: animeID,
Status: status,
})
return err
}) })
return err
} }
func (s *watchlistService) RemoveEntry(ctx context.Context, userID string, animeID int64) error { func (s *watchlistService) RemoveEntry(ctx context.Context, userID string, animeID int64) error {
@@ -99,16 +103,18 @@ func (s *watchlistService) GetContinueWatchingEntry(ctx context.Context, userID
} }
func (s *watchlistService) DeleteContinueWatching(ctx context.Context, userID string, animeID int64) error { func (s *watchlistService) DeleteContinueWatching(ctx context.Context, userID string, animeID int64) error {
if err := s.repo.DeleteContinueWatchingEntry(ctx, db.DeleteContinueWatchingEntryParams{ return s.repo.InTx(ctx, func(txCtx context.Context, repo domain.WatchlistRepository) error {
UserID: userID, if err := repo.DeleteContinueWatchingEntry(txCtx, db.DeleteContinueWatchingEntryParams{
AnimeID: animeID, UserID: userID,
}); err != nil { AnimeID: animeID,
return err }); err != nil {
} return err
return s.repo.SaveWatchProgress(ctx, db.SaveWatchProgressParams{ }
UserID: userID, return repo.SaveWatchProgress(txCtx, db.SaveWatchProgressParams{
AnimeID: animeID, UserID: userID,
CurrentEpisode: sql.NullInt64{Valid: false}, AnimeID: animeID,
CurrentTimeSeconds: 0, CurrentEpisode: sql.NullInt64{Valid: false},
CurrentTimeSeconds: 0,
})
}) })
} }